From 5a3d86f8dd2c0e4b284925a6da38ef9014c8628a Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 17:01:08 -0400 Subject: [PATCH 01/12] feat(gcp): add GCS checkpoint path support for model weight pulling Co-Authored-By: Claude Opus 4.6 --- .../use_cases/llm_model_endpoint_use_cases.py | 54 +++++++++++++++---- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 426d1644..74e6473c 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -330,9 +330,10 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: not checkpoint_path.startswith("s3://") and not checkpoint_path.startswith("azure://") and "blob.core.windows.net" not in checkpoint_path + and not checkpoint_path.startswith("gs://") ): raise ObjectHasInvalidValueException( - f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." + f"Only S3, Azure Blob Storage, and GCS paths are supported. Given checkpoint path: {checkpoint_path}." ) if checkpoint_path.endswith(".tar"): raise ObjectHasInvalidValueException( @@ -623,9 +624,15 @@ def load_model_weights_sub_commands( final_weights_folder, trust_remote_code, ) + elif checkpoint_path.startswith("gs://"): + return self.load_model_weights_sub_commands_gcs( + checkpoint_path, + final_weights_folder, + trust_remote_code, + ) else: raise ObjectHasInvalidValueException( - f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." + f"Only S3, Azure Blob Storage, and GCS paths are supported. Given checkpoint path: {checkpoint_path}." ) def load_model_weights_sub_commands_s3( @@ -701,6 +708,30 @@ def load_model_weights_sub_commands_abs( return subcommands + def load_model_weights_sub_commands_gcs( + self, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool, + ): + subcommands = [] + + checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) + validate_checkpoint_files(checkpoint_files) + + # Install gcloud CLI on-the-fly for GCS access (similar to azcopy install for Azure) + subcommands.append( + "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" + " | tar -xz -C /opt" + " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null" + ) + + subcommands.append( + f"/opt/google-cloud-sdk/bin/gcloud storage cp -r" + f" {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + return subcommands + def load_model_files_sub_commands_trt_llm( self, checkpoint_path, @@ -717,14 +748,19 @@ def load_model_files_sub_commands_trt_llm( subcommands = [ f"./s5cmd {s3_endpoint_flag} --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" ] + elif checkpoint_path.startswith("gs://"): + subcommands = [ + "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" + " | tar -xz -C /opt" + " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", + f"/opt/google-cloud-sdk/bin/gcloud storage cp -r {os.path.join(checkpoint_path, '*')} ./", + ] else: - subcommands.extend( - [ - "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", - "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", - f"azcopy copy --recursive {os.path.join(checkpoint_path, '*')} ./", - ] - ) + subcommands = [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + f"azcopy copy --recursive {os.path.join(checkpoint_path, '*')} ./", + ] return subcommands async def create_deepspeed_bundle( From e8a4f5f5432668860d8b34fb35fd06d2fe118a2a Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 17:08:38 -0400 Subject: [PATCH 02/12] greptile fixes --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 74e6473c..c5876563 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -719,6 +719,10 @@ def load_model_weights_sub_commands_gcs( checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) validate_checkpoint_files(checkpoint_files) + # Auth: On GKE with Workload Identity, gcloud automatically obtains credentials + # from the node metadata server — no explicit auth step needed (unlike Azure's + # AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD). The pod's KSA must be bound to a GSA with + # storage.objects.list and storage.objects.get permissions on the checkpoint bucket. # Install gcloud CLI on-the-fly for GCS access (similar to azcopy install for Azure) subcommands.append( "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" @@ -726,9 +730,13 @@ def load_model_weights_sub_commands_gcs( " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null" ) + file_selection_str = '--include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' + if trust_remote_code: + file_selection_str += ' --include="*.py"' + subcommands.append( f"/opt/google-cloud-sdk/bin/gcloud storage cp -r" - f" {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + f" {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" ) return subcommands From 5626f4a3a652b5dc90c70cbdf6501c76a788e33f Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 17:31:05 -0400 Subject: [PATCH 03/12] tests --- model-engine/tests/unit/conftest.py | 14 +++- .../tests/unit/domain/test_llm_use_cases.py | 76 +++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index f66b5945..235cf52a 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -786,6 +786,10 @@ def __init__(self): "model-fake.bin, model-fake2.bin", "model-fake.safetensors", ], + "fake-bucket/fake-checkpoint": [ + "model-fake.bin, model-fake2.bin", + "model-fake.safetensors", + ], "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"], "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"], "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], @@ -866,13 +870,19 @@ def __init__(self): def _add_model(self, owner: str, model_name: str): self.existing_models.append((owner, model_name)) + def _strip_cloud_prefix(self, path: str) -> str: + for prefix in ("s3://", "gs://", "azure://"): + if path.startswith(prefix): + return path[len(prefix):] + return path + def list_files(self, path: str, **kwargs) -> List[str]: - path = path.lstrip("s3://") + path = self._strip_cloud_prefix(path) if path in self.s3_bucket: return self.s3_bucket[path] def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: - path = path.lstrip("s3://") + path = self._strip_cloud_prefix(path) if path in self.s3_bucket: return self.s3_bucket[path] diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 1738bd1c..4aeb4f97 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -64,6 +64,7 @@ validate_and_update_completion_params, validate_chat_template, validate_checkpoint_files, + validate_checkpoint_path_uri, ) from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase @@ -642,6 +643,69 @@ def test_load_model_weights_sub_commands( ] assert expected_result == subcommands + # GCS + framework = LLMInferenceFramework.VLLM + framework_image_tag = "0.2.7" + checkpoint_path = "gs://fake-bucket/fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" + " | tar -xz -C /opt" + " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", + '/opt/google-cloud-sdk/bin/gcloud storage cp -r --include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*" gs://fake-bucket/fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + trust_remote_code = True + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + ) + + expected_result = [ + "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" + " | tar -xz -C /opt" + " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", + '/opt/google-cloud-sdk/bin/gcloud storage cp -r --include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*" --include="*.py" gs://fake-bucket/fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + +def test_load_model_files_sub_commands_trt_llm_gcs( + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + + checkpoint_path = "gs://fake-bucket/fake-checkpoint" + subcommands = llm_bundle_use_case.load_model_files_sub_commands_trt_llm(checkpoint_path) + + expected_result = [ + "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" + " | tar -xz -C /opt" + " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", + "/opt/google-cloud-sdk/bin/gcloud storage cp -r gs://fake-bucket/fake-checkpoint/* ./", + ] + assert expected_result == subcommands + @pytest.mark.asyncio async def test_create_model_endpoint_trt_llm_use_case_success( @@ -2103,6 +2167,18 @@ async def test_delete_public_inference_model_raises_not_authorized( ) +def test_validate_checkpoint_path_uri_gcs(): + # Should not raise for gs:// paths + validate_checkpoint_path_uri("gs://my-bucket/models/weights/") + validate_checkpoint_path_uri("gs://bucket/path") + + # Should still reject unsupported schemes + with pytest.raises(ObjectHasInvalidValueException): + validate_checkpoint_path_uri("/local/path/to/model") + with pytest.raises(ObjectHasInvalidValueException): + validate_checkpoint_path_uri("hdfs://cluster/model") + + @pytest.mark.asyncio async def test_validate_checkpoint_files_no_safetensors(): fake_model_files = ["model-fake.bin", "model.json", "optimizer.pt"] From 19f70d8385778630d22e7658a221b4edf8c7d481 Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 17:42:51 -0400 Subject: [PATCH 04/12] lint fixes --- .../use_cases/llm_model_endpoint_use_cases.py | 474 +++++++++++++----- model-engine/tests/unit/conftest.py | 292 ++++++++--- .../tests/unit/domain/test_llm_use_cases.py | 172 +++++-- 3 files changed, 679 insertions(+), 259 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index c5876563..f4e51940 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -16,7 +16,9 @@ import yaml from model_engine_server.common.config import hmi_config -from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.batch_jobs import ( + CreateDockerImageBatchJobResourceRequests, +) from model_engine_server.common.dtos.llms import ( ChatCompletionV2Request, ChatCompletionV2StreamSuccessChunk, @@ -55,10 +57,16 @@ CompletionV2SyncResponse, ) from model_engine_server.common.dtos.llms.sglang import SGLangEndpointAdditionalArgs -from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs, VLLMModelConfig +from model_engine_server.common.dtos.llms.vllm import ( + VLLMEndpointAdditionalArgs, + VLLMModelConfig, +) from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, + TaskStatus, +) from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.config import infra_config @@ -112,7 +120,10 @@ ModelBundleRepository, TokenizerRepository, ) -from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService +from model_engine_server.domain.services import ( + LLMModelEndpointService, + ModelEndpointService, +) from model_engine_server.domain.services.llm_batch_completions_service import ( LLMBatchCompletionsService, ) @@ -141,7 +152,9 @@ def _get_s3_endpoint_flag() -> str: """Get S3 endpoint flag for s5cmd, determined at command construction time.""" - s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv("S3_ENDPOINT_URL") + s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv( + "S3_ENDPOINT_URL" + ) if s3_endpoint: return f"--endpoint-url {s3_endpoint}" return "" @@ -187,20 +200,26 @@ def _get_s3_endpoint_flag() -> str: } -NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes +NUM_DOWNSTREAM_REQUEST_RETRIES = ( + 80 # has to be high enough so that the retries take the 5 minutes +) DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes DEFAULT_BATCH_COMPLETIONS_NODES_PER_WORKER = 1 SERVICE_NAME = "model-engine" -LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" +LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = ( + f"{SERVICE_NAME}-inference-framework-latest-config" +) RECOMMENDED_HARDWARE_CONFIG_MAP_NAME = f"{SERVICE_NAME}-recommended-hardware-config" SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") if SERVICE_IDENTIFIER: SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" -def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: +def count_tokens( + input: str, model_name: str, tokenizer_repository: TokenizerRepository +) -> int: """ Count the number of tokens in the input string. """ @@ -280,7 +299,9 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( return response -def validate_model_name(_model_name: str, _inference_framework: LLMInferenceFramework) -> None: +def validate_model_name( + _model_name: str, _inference_framework: LLMInferenceFramework +) -> None: # TODO: replace this logic to check if the model architecture is supported instead pass @@ -304,7 +325,10 @@ def validate_num_shards( def validate_quantization( quantize: Optional[Quantization], inference_framework: LLMInferenceFramework ) -> None: - if quantize is not None and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework]: + if ( + quantize is not None + and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework] + ): raise ObjectHasInvalidValueException( f"Quantization {quantize} is not supported for inference framework {inference_framework}. Supported quantization types are {_SUPPORTED_QUANTIZATIONS[inference_framework]}." ) @@ -341,7 +365,9 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: ) -def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: +def get_checkpoint_path( + model_name: str, checkpoint_path_override: Optional[str] +) -> str: checkpoint_path = None models_info = SUPPORTED_MODELS_INFO.get(model_name, None) if checkpoint_path_override: @@ -350,7 +376,9 @@ def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str] checkpoint_path = get_models_s3_uri(models_info.s3_repo, "") # pragma: no cover if not checkpoint_path: - raise InvalidRequestException(f"No checkpoint path found for model {model_name}") + raise InvalidRequestException( + f"No checkpoint path found for model {model_name}" + ) validate_checkpoint_path_uri(checkpoint_path) return checkpoint_path @@ -361,7 +389,9 @@ def validate_checkpoint_files(checkpoint_files: List[str]) -> None: model_files = [f for f in checkpoint_files if "model" in f] num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) if num_safetensors == 0: - raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") + raise ObjectHasInvalidValueException( + "No safetensors found in the checkpoint path." + ) def encode_template(chat_template: str) -> str: @@ -501,7 +531,9 @@ async def execute( ) case LLMInferenceFramework.SGLANG: # pragma: no cover if not hmi_config.sglang_repository: - raise ObjectHasInvalidValueException("SGLang repository is not set.") + raise ObjectHasInvalidValueException( + "SGLang repository is not set." + ) additional_sglang_args = ( SGLangEndpointAdditionalArgs.model_validate(additional_args) @@ -526,7 +558,9 @@ async def execute( model_bundle = await self.model_bundle_repository.get_model_bundle(bundle_id) if model_bundle is None: - raise ObjectNotFoundException(f"Model bundle {bundle_id} was not found after creation.") + raise ObjectNotFoundException( + f"Model bundle {bundle_id} was not found after creation." + ) return model_bundle async def create_text_generation_inference_bundle( @@ -616,7 +650,10 @@ def load_model_weights_sub_commands( final_weights_folder, trust_remote_code, ) - elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path: + elif ( + checkpoint_path.startswith("azure://") + or "blob.core.windows.net" in checkpoint_path + ): return self.load_model_weights_sub_commands_abs( framework, framework_image_tag, @@ -652,7 +689,9 @@ def load_model_weights_sub_commands_s3( framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE and framework_image_tag != "0.9.3-launch_s3" ): - subcommands.append(f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}") + subcommands.append( + f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" + ) else: s5cmd = "./s5cmd" @@ -987,7 +1026,11 @@ def _create_vllm_bundle_command( exclude_none=True ) ), - **(additional_args.model_dump(exclude_none=True) if additional_args else {}), + **( + additional_args.model_dump(exclude_none=True) + if additional_args + else {} + ), } ) @@ -1041,7 +1084,9 @@ def _create_vllm_bundle_command( vllm_args.disable_log_requests = True # Use wrapper if startup metrics enabled, otherwise use vllm_server directly - server_module = "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server" + server_module = ( + "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server" + ) vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "::"' for field in VLLMEndpointAdditionalArgs.model_fields.keys(): config_value = getattr(vllm_args, field, None) @@ -1401,9 +1446,13 @@ async def execute( validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) - validate_num_shards(request.num_shards, request.inference_framework, request.gpus) + validate_num_shards( + request.num_shards, request.inference_framework, request.gpus + ) validate_quantization(request.quantize, request.inference_framework) - validate_chat_template(request.chat_template_override, request.inference_framework) + validate_chat_template( + request.chat_template_override, request.inference_framework + ) if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, @@ -1546,8 +1595,10 @@ async def execute( Returns: A response object that contains the model endpoints. """ - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=name, order_by=order_by + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=name, order_by=order_by + ) ) return ListLLMModelEndpointsV1Response( model_endpoints=[ @@ -1566,7 +1617,9 @@ def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() - async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndpointV1Response: + async def execute( + self, user: User, model_endpoint_name: str + ) -> GetLLMModelEndpointV1Response: """ Runs the use case to get the LLM endpoint with the given name. @@ -1635,7 +1688,9 @@ async def execute( ) if not model_endpoint: raise ObjectNotFoundException - if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + if not self.authz_module.check_access_write_owned_entity( + user, model_endpoint.record + ): raise ObjectNotAuthorizedException endpoint_record = model_endpoint.record @@ -1663,11 +1718,15 @@ async def execute( or request.checkpoint_path or request.chat_template_override ): - llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {}) + llm_metadata = (model_endpoint.record.metadata or {}).get( + LLM_METADATA_KEY, {} + ) inference_framework = llm_metadata["inference_framework"] if request.inference_framework_image_tag == "latest": - inference_framework_image_tag = await _get_latest_tag(inference_framework) + inference_framework_image_tag = await _get_latest_tag( + inference_framework + ) else: inference_framework_image_tag = ( request.inference_framework_image_tag @@ -1678,7 +1737,9 @@ async def execute( source = request.source or llm_metadata["source"] num_shards = request.num_shards or llm_metadata["num_shards"] quantize = request.quantize or llm_metadata.get("quantize") - checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") + checkpoint_path = request.checkpoint_path or llm_metadata.get( + "checkpoint_path" + ) validate_model_name(model_name, inference_framework) validate_num_shards( @@ -1802,7 +1863,9 @@ def __init__( self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() - async def execute(self, user: User, model_endpoint_name: str) -> DeleteLLMEndpointResponse: + async def execute( + self, user: User, model_endpoint_name: str + ) -> DeleteLLMEndpointResponse: """ Runs the use case to delete the LLM endpoint owned by the user with the given name. @@ -1817,15 +1880,21 @@ async def execute(self, user: User, model_endpoint_name: str) -> DeleteLLMEndpoi ObjectNotFoundException: If a model endpoint with the given name could not be found. ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.user_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.user_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) != 1: raise ObjectNotFoundException model_endpoint = model_endpoints[0] - if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + if not self.authz_module.check_access_write_owned_entity( + user, model_endpoint.record + ): raise ObjectNotAuthorizedException - await self.model_endpoint_service.delete_model_endpoint(model_endpoint.record.id) + await self.model_endpoint_service.delete_model_endpoint( + model_endpoint.record.id + ) return DeleteLLMEndpointResponse(deleted=True) @@ -1936,7 +2005,9 @@ def validate_and_update_completion_params( or request.guided_json is not None or request.guided_grammar is not None ) and not inference_framework == LLMInferenceFramework.VLLM: - raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") + raise ObjectHasInvalidValueException( + "Guided decoding is only supported in vllm." + ) return request @@ -1964,7 +2035,9 @@ def model_output_to_completion_output( prompt: str, with_token_probs: Optional[bool], ) -> CompletionOutput: - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: completion_token_count = len(model_output["token_probs"]["tokens"]) tokens = None @@ -1980,7 +2053,10 @@ def model_output_to_completion_output( num_completion_tokens=completion_token_count, tokens=tokens, ) - elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + elif ( + model_content.inference_framework + == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ): try: tokens = None if with_token_probs: @@ -1995,9 +2071,13 @@ def model_output_to_completion_output( tokens=tokens, ) except Exception: - logger.exception(f"Error parsing text-generation-inference output {model_output}.") + logger.exception( + f"Error parsing text-generation-inference output {model_output}." + ) if model_output.get("error_type") == "validation": - raise InvalidRequestException(model_output.get("error")) # trigger a 400 + raise InvalidRequestException( + model_output.get("error") + ) # trigger a 400 else: raise UpstreamServiceError( status_code=500, content=bytes(model_output["error"], "utf-8") @@ -2066,14 +2146,18 @@ def model_output_to_completion_output( f"Invalid endpoint {model_content.name} has no base model" ) if not prompt: - raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") + raise InvalidRequestException( + "Prompt must be provided for TensorRT-LLM models." + ) num_prompt_tokens = count_tokens( prompt, model_content.model_name, self.tokenizer_repository ) if "token_ids" in model_output: # TensorRT 23.10 has this field, TensorRT 24.03 does not # For backwards compatibility with pre-2024/05/02 - num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens + num_completion_tokens = ( + len(model_output["token_ids"]) - num_prompt_tokens + ) # Output is " prompt output" text = model_output["text_output"][(len(prompt) + 4) :] elif "output_log_probs" in model_output: @@ -2120,8 +2204,10 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: @@ -2151,14 +2237,18 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + inference_gateway = ( + self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + ) autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -2201,7 +2291,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: + if ( + predict_result.status == TaskStatus.SUCCESS + and predict_result.result is not None + ): return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( @@ -2221,7 +2314,8 @@ async def execute( ), ) elif ( - endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + endpoint_content.inference_framework + == LLMInferenceFramework.TEXT_GENERATION_INFERENCE ): tgi_args: Any = { "inputs": request.prompt, @@ -2252,7 +2346,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -2289,7 +2386,9 @@ async def execute( if request.return_token_log_probs: vllm_args["logprobs"] = 1 if request.include_stop_str_in_output is not None: - vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output + vllm_args["include_stop_str_in_output"] = ( + request.include_stop_str_in_output + ) if request.guided_choice is not None: vllm_args["guided_choice"] = request.guided_choice if request.guided_regex is not None: @@ -2313,7 +2412,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -2365,7 +2467,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -2410,7 +2515,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -2483,12 +2591,16 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: - raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + raise ObjectNotFoundException( + f"Model endpoint {model_endpoint_name} not found." + ) if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -2521,7 +2633,9 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) validated_request = validate_and_update_completion_params( model_content.inference_framework, request ) @@ -2558,7 +2672,10 @@ async def execute( model_content.model_name, self.tokenizer_repository, ) - elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + elif ( + model_content.inference_framework + == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ): args = { "inputs": request.prompt, "parameters": { @@ -2701,7 +2818,9 @@ async def _response_chunk_generator( raise UpstreamServiceError( status_code=500, content=( - res.traceback.encode("utf-8") if res.traceback is not None else b"" + res.traceback.encode("utf-8") + if res.traceback is not None + else b"" ), ) # Otherwise, yield empty response chunk for unsuccessful or empty results @@ -2760,7 +2879,9 @@ async def _response_chunk_generator( output=CompletionStreamOutput( text=result["result"]["token"]["text"], finished=finished, - num_prompt_tokens=(num_prompt_tokens if finished else None), + num_prompt_tokens=( + num_prompt_tokens if finished else None + ), num_completion_tokens=num_completion_tokens, token=token, ), @@ -2792,7 +2913,9 @@ async def _response_chunk_generator( num_completion_tokens = usage.get("completion_tokens", 0) if request.return_token_log_probs and choice.get("logprobs"): logprobs = choice["logprobs"] - if logprobs.get("tokens") and logprobs.get("token_logprobs"): + if logprobs.get("tokens") and logprobs.get( + "token_logprobs" + ): # Get the last token from the logprobs idx = len(logprobs["tokens"]) - 1 token = TokenOutput( @@ -2805,7 +2928,9 @@ async def _response_chunk_generator( finished = vllm_output["finished"] num_prompt_tokens = vllm_output["count_prompt_tokens"] num_completion_tokens = vllm_output["count_output_tokens"] - if request.return_token_log_probs and vllm_output.get("log_probs"): + if request.return_token_log_probs and vllm_output.get( + "log_probs" + ): token = TokenOutput( token=vllm_output["text"], log_prob=list(vllm_output["log_probs"].values())[0], @@ -2821,7 +2946,9 @@ async def _response_chunk_generator( ), ) # LIGHTLLM - elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + elif ( + model_content.inference_framework == LLMInferenceFramework.LIGHTLLM + ): token = None num_completion_tokens += 1 if request.return_token_log_probs: @@ -2841,7 +2968,10 @@ async def _response_chunk_generator( ), ) # TENSORRT_LLM - elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + elif ( + model_content.inference_framework + == LLMInferenceFramework.TENSORRT_LLM + ): num_completion_tokens += 1 yield CompletionStreamV1Response( request_id=request_id, @@ -2860,7 +2990,10 @@ async def _response_chunk_generator( def validate_endpoint_supports_openai_completion( endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response ): # pragma: no cover - if endpoint_content.inference_framework not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS: + if ( + endpoint_content.inference_framework + not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS + ): raise EndpointUnsupportedInferenceTypeException( f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support openai compatible completion." ) @@ -2916,8 +3049,10 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: @@ -2955,14 +3090,18 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + inference_gateway = ( + self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + ) autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -2990,7 +3129,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -3033,12 +3175,16 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: - raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + raise ObjectNotFoundException( + f"Model endpoint {model_endpoint_name} not found." + ) if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -3079,7 +3225,9 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3140,7 +3288,11 @@ async def _response_chunk_generator( if not res.status == TaskStatus.SUCCESS or res.result is None: raise UpstreamServiceError( status_code=500, - content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + content=( + res.traceback.encode("utf-8") + if res.traceback is not None + else b"" + ), ) else: result = res.result["result"] @@ -3160,12 +3312,16 @@ def validate_endpoint_supports_chat_completion( ) if not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike): - raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") + raise EndpointUnsupportedRequestException( + "Endpoint does not support chat completion" + ) flavor = endpoint.record.current_model_bundle.flavor all_routes = flavor.extra_routes + flavor.routes if OPENAI_CHAT_COMPLETION_PATH not in all_routes: - raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") + raise EndpointUnsupportedRequestException( + "Endpoint does not support chat completion" + ) class ChatCompletionSyncV2UseCase: @@ -3206,8 +3362,10 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: @@ -3245,14 +3403,18 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + inference_gateway = ( + self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + ) autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3280,7 +3442,10 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + if ( + predict_result.status != TaskStatus.SUCCESS + or predict_result.result is None + ): raise UpstreamServiceError( status_code=500, content=( @@ -3323,12 +3488,16 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None + model_endpoints = ( + await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) ) if len(model_endpoints) == 0: - raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + raise ObjectNotFoundException( + f"Model endpoint {model_endpoint_name} not found." + ) if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -3369,7 +3538,9 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint + ) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3429,7 +3600,11 @@ async def _response_chunk_generator( if not res.status == TaskStatus.SUCCESS or res.result is None: raise UpstreamServiceError( status_code=500, - content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + content=( + res.traceback.encode("utf-8") + if res.traceback is not None + else b"" + ), ) else: result = res.result["result"] @@ -3451,7 +3626,9 @@ def __init__( self.model_endpoint_service = model_endpoint_service self.llm_artifact_gateway = llm_artifact_gateway - async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownloadResponse: + async def execute( + self, user: User, request: ModelDownloadRequest + ) -> ModelDownloadResponse: model_endpoints = await self.model_endpoint_service.list_model_endpoints( owner=user.team_id, name=request.model_name, order_by=None ) @@ -3469,7 +3646,9 @@ async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownl for model_file in model_files: # don't want to make s3 bucket full keys public, so trim to just keep file name public_file_name = model_file.rsplit("/", 1)[-1] - urls[public_file_name] = self.filesystem_gateway.generate_signed_url(model_file) + urls[public_file_name] = self.filesystem_gateway.generate_signed_url( + model_file + ) return ModelDownloadResponse(urls=urls) @@ -3495,7 +3674,9 @@ async def _fill_hardware_info( raise ObjectHasInvalidValueException( "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing." ) - checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) + checkpoint_path = get_checkpoint_path( + request.model_name, request.checkpoint_path + ) hardware_info = await _infer_hardware( llm_artifact_gateway, request.model_name, checkpoint_path ) @@ -3566,14 +3747,18 @@ async def _infer_hardware( model_param_count_b = get_model_param_count_b(model_name) model_weights_size = dtype_size * model_param_count_b * 1_000_000_000 - min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) + min_memory_gb = math.ceil( + (min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9 + ) logger.info( f"Memory calculation result: {min_memory_gb=} for {model_name} context_size: {max_position_embeddings}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" ) config_map = await _get_recommended_hardware_config_map() - by_model_name = {item["name"]: item for item in yaml.safe_load(config_map["byModelName"])} + by_model_name = { + item["name"]: item for item in yaml.safe_load(config_map["byModelName"]) + } by_gpu_memory_gb = yaml.safe_load(config_map["byGpuMemoryGb"]) if model_name in by_model_name: cpus = by_model_name[model_name]["cpus"] @@ -3594,7 +3779,9 @@ async def _infer_hardware( nodes_per_worker = recs["nodes_per_worker"] break else: - raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") + raise ObjectHasInvalidValueException( + f"Unable to infer hardware for {model_name}." + ) return CreateDockerImageBatchJobResourceRequests( cpus=cpus, @@ -3656,37 +3843,33 @@ async def create_batch_job_bundle( ) -> DockerImageBatchJobBundle: assert hardware.gpu_type is not None - bundle_name = ( - f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" - ) + bundle_name = f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) config_file_path = "/opt/config.json" - batch_bundle = ( - await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( - name=bundle_name, - created_by=user.user_id, - owner=user.team_id, - image_repository=hmi_config.batch_inference_vllm_repository, - image_tag=image_tag, - command=[ - "dumb-init", - "--", - "/bin/bash", - "-c", - "ddtrace-run python vllm_batch.py", - ], - env={"CONFIG_FILE": config_file_path}, - mount_location=config_file_path, - cpus=str(hardware.cpus), - memory=str(hardware.memory), - storage=str(hardware.storage), - gpus=hardware.gpus, - gpu_type=hardware.gpu_type, - public=False, - ) + batch_bundle = await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( + name=bundle_name, + created_by=user.user_id, + owner=user.team_id, + image_repository=hmi_config.batch_inference_vllm_repository, + image_tag=image_tag, + command=[ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ], + env={"CONFIG_FILE": config_file_path}, + mount_location=config_file_path, + cpus=str(hardware.cpus), + memory=str(hardware.memory), + storage=str(hardware.storage), + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + public=False, ) return batch_bundle @@ -3714,7 +3897,10 @@ async def execute( engine_request = CreateBatchCompletionsEngineRequest.from_api_v1(request) engine_request.model_cfg.num_shards = hardware.gpus - if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": + if ( + engine_request.tool_config + and engine_request.tool_config.name != "code_evaluator" + ): raise ObjectHasInvalidValueException( "Only code_evaluator tool is supported for batch completions." ) @@ -3723,10 +3909,14 @@ async def execute( engine_request.model_cfg.model ) - engine_request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization + engine_request.max_gpu_memory_utilization = ( + additional_engine_args.gpu_memory_utilization + ) engine_request.attention_backend = additional_engine_args.attention_backend - batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) + batch_bundle = await self.create_batch_job_bundle( + user, engine_request, hardware + ) validate_resource_requests( bundle=batch_bundle, @@ -3740,21 +3930,25 @@ async def execute( if ( engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1 ): # pragma: no cover - raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + raise ObjectHasInvalidValueException( + "max_runtime_sec must be a positive integer." + ) - job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( - created_by=user.user_id, - owner=user.team_id, - job_config=engine_request.model_dump(by_alias=True), - env=batch_bundle.env, - command=batch_bundle.command, - repo=batch_bundle.image_repository, - tag=batch_bundle.image_tag, - resource_requests=hardware, - labels=engine_request.labels, - mount_location=batch_bundle.mount_location, - override_job_max_runtime_s=engine_request.max_runtime_sec, - num_workers=engine_request.data_parallelism, + job_id = ( + await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=engine_request.model_dump(by_alias=True), + env=batch_bundle.env, + command=batch_bundle.command, + repo=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + resource_requests=hardware, + labels=engine_request.labels, + mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=engine_request.max_runtime_sec, + num_workers=engine_request.data_parallelism, + ) ) return CreateBatchCompletionsV1Response(job_id=job_id) @@ -3824,7 +4018,9 @@ async def execute( ) if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: - raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + raise ObjectHasInvalidValueException( + "max_runtime_sec must be a positive integer." + ) # Right now we only support VLLM for batch inference. Refactor this if we support more inference frameworks. image_repo = hmi_config.batch_inference_vllm_repository @@ -3869,7 +4065,9 @@ async def execute( ) if not job: - raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + raise ObjectNotFoundException( + f"Batch completion {batch_completion_id} not found." + ) return GetBatchCompletionV2Response(job=job) @@ -3890,7 +4088,9 @@ async def execute( request=request, ) if not result: - raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + raise ObjectNotFoundException( + f"Batch completion {batch_completion_id} not found." + ) return UpdateBatchCompletionsV2Response( **result.model_dump(by_alias=True, exclude_none=True), diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 235cf52a..4e674d2c 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -21,8 +21,13 @@ import pytest from model_engine_server.api.dependencies import ExternalInterfaces from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.batch_jobs import ( + CreateDockerImageBatchJobResourceRequests, +) +from model_engine_server.common.dtos.docker_repository import ( + BuildImageRequest, + BuildImageResponse, +) from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.common.dtos.model_endpoints import ( @@ -32,7 +37,9 @@ ModelEndpointOrderBy, StorageSpecificationType, ) -from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.dtos.resource_manager import ( + CreateOrUpdateResourcesRequest, +) from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, @@ -152,7 +159,10 @@ translate_kwargs_to_model_bundle_orm, translate_model_bundle_orm_to_model_bundle, ) -from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services import ( + LiveBatchJobService, + LiveModelEndpointService, +) from model_engine_server.infra.services.fake_llm_batch_completions_service import ( FakeLLMBatchCompletionsService, ) @@ -211,7 +221,9 @@ def __init__(self, contents: Optional[Dict[str, ModelBundle]] = None): self.db = contents self.unique_owner_name_versions = set() for model_bundle in self.db.values(): - self.unique_owner_name_versions.add((model_bundle.owner, model_bundle.name)) + self.unique_owner_name_versions.add( + (model_bundle.owner, model_bundle.name) + ) else: self.db = {} self.unique_owner_name_versions = set() @@ -260,7 +272,9 @@ async def list_model_bundles( self, owner: str, name: Optional[str], order_by: Optional[ModelBundleOrderBy] ) -> Sequence[ModelBundle]: model_bundles = [ - mb for mb in self.db.values() if mb.owner == owner and (not name or mb.name == name) + mb + for mb in self.db.values() + if mb.owner == owner and (not name or mb.name == name) ] if order_by == ModelBundleOrderBy.NEWEST: @@ -270,8 +284,12 @@ async def list_model_bundles( return model_bundles - async def get_latest_model_bundle_by_name(self, owner: str, name: str) -> Optional[ModelBundle]: - model_bundles = await self.list_model_bundles(owner, name, ModelBundleOrderBy.NEWEST) + async def get_latest_model_bundle_by_name( + self, owner: str, name: str + ) -> Optional[ModelBundle]: + model_bundles = await self.list_model_bundles( + owner, name, ModelBundleOrderBy.NEWEST + ) if not model_bundles: return None return model_bundles[0] @@ -316,7 +334,9 @@ async def create_batch_job_record( model_bundle_id=model_bundle_id, ) orm_batch_job.created_at = datetime.now() - model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) + model_bundle = await self.model_bundle_repository.get_model_bundle( + model_bundle_id + ) assert model_bundle is not None batch_job = _translate_fake_batch_job_orm_to_batch_job_record( orm_batch_job, model_bundle=model_bundle @@ -355,14 +375,18 @@ async def update_batch_job_record( async def get_batch_job_record(self, batch_job_id: str) -> Optional[BatchJobRecord]: return self.db.get(batch_job_id) - async def list_batch_job_records(self, owner: Optional[str]) -> List[BatchJobRecord]: + async def list_batch_job_records( + self, owner: Optional[str] + ) -> List[BatchJobRecord]: def filter_fn(m: BatchJobRecord) -> bool: return not owner or m.owner == owner batch_jobs = list(filter(filter_fn, self.db.values())) return batch_jobs - async def unset_model_endpoint_id(self, batch_job_id: str) -> Optional[BatchJobRecord]: + async def unset_model_endpoint_id( + self, batch_job_id: str + ) -> Optional[BatchJobRecord]: batch_job_record = await self.get_batch_job_record(batch_job_id) if batch_job_record: batch_job_record.model_endpoint_id = None @@ -382,7 +406,9 @@ def __init__( self.db = contents self.unique_owner_name_versions = set() for model_endpoint in self.db.values(): - self.unique_owner_name_versions.add((model_endpoint.owner, model_endpoint.name)) + self.unique_owner_name_versions.add( + (model_endpoint.owner, model_endpoint.name) + ) else: self.db = {} self.unique_owner_name_versions = set() @@ -471,7 +497,9 @@ async def create_model_endpoint_record( ) orm_model_endpoint.created_at = datetime.now() orm_model_endpoint.last_updated_at = datetime.now() - model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) + model_bundle = await self.model_bundle_repository.get_model_bundle( + model_bundle_id + ) assert model_bundle is not None model_endpoint = _translate_fake_model_endpoint_orm_to_model_endpoint_record( orm_model_endpoint, current_model_bundle=model_bundle @@ -602,10 +630,14 @@ def _get_new_id(self): self.next_id += 1 return str(self.next_id) - def add_docker_image_batch_job_bundle(self, batch_bundle: DockerImageBatchJobBundle): + def add_docker_image_batch_job_bundle( + self, batch_bundle: DockerImageBatchJobBundle + ): new_id = batch_bundle.id if new_id in {bun.id for bun in self.db.values()}: - raise ValueError(f"Error in test set up, batch bundle with {new_id} already present") + raise ValueError( + f"Error in test set up, batch bundle with {new_id} already present" + ) self.db[new_id] = batch_bundle async def create_docker_image_batch_job_bundle( @@ -767,7 +799,9 @@ async def write_job_template_for_model( class FakeLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): def __init__(self): self.initialized_events = [] - self.all_events_list = [LLMFineTuneEvent(timestamp=1, message="message", level="info")] + self.all_events_list = [ + LLMFineTuneEvent(timestamp=1, message="message", level="info") + ] async def get_fine_tune_events(self, user_id: str, model_endpoint_name: str): if (user_id, model_endpoint_name) in self.initialized_events: @@ -798,7 +832,9 @@ def __init__(self): "llama-3-70b": ["model-fake.safetensors"], "llama-3-1-405b-instruct": ["model-fake.safetensors"], } - self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} + self.urls = { + "filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz" + } self.model_config = { "_name_or_path": "meta-llama/Llama-2-7b-hf", "architectures": ["LlamaForCausalLM"], @@ -873,7 +909,7 @@ def _add_model(self, owner: str, model_name: str): def _strip_cloud_prefix(self, path: str) -> str: for prefix in ("s3://", "gs://", "azure://"): if path.startswith(prefix): - return path[len(prefix):] + return path[len(prefix) :] return path def list_files(self, path: str, **kwargs) -> List[str]: @@ -881,7 +917,9 @@ def list_files(self, path: str, **kwargs) -> List[str]: if path in self.s3_bucket: return self.s3_bucket[path] - def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + def download_files( + self, path: str, target_path: str, overwrite=False, **kwargs + ) -> List[str]: path = self._strip_cloud_prefix(path) if path in self.s3_bucket: return self.s3_bucket[path] @@ -1084,7 +1122,9 @@ class FakeModelEndpointInfraGateway(ModelEndpointInfraGateway): def __init__( self, contents: Optional[Dict[str, ModelEndpointInfraState]] = None, - model_endpoint_record_repository: Optional[ModelEndpointRecordRepository] = None, + model_endpoint_record_repository: Optional[ + ModelEndpointRecordRepository + ] = None, ): self.db = contents if contents else {} self.in_flight_infra = {} @@ -1181,9 +1221,9 @@ def update_model_endpoint_infra_in_place( if kwargs["per_worker"] is not None: model_endpoint_infra.deployment_state.per_worker = kwargs["per_worker"] if kwargs["concurrent_requests_per_worker"] is not None: - model_endpoint_infra.deployment_state.concurrent_requests_per_worker = kwargs[ - "concurrent_requests_per_worker" - ] + model_endpoint_infra.deployment_state.concurrent_requests_per_worker = ( + kwargs["concurrent_requests_per_worker"] + ) if kwargs["cpus"] is not None: model_endpoint_infra.resource_state.cpus = kwargs["cpus"] if kwargs["gpus"] is not None: @@ -1239,7 +1279,9 @@ async def update_model_endpoint_infra( assert model_endpoint_infra is not None model_endpoint_infra = model_endpoint_infra.copy() self.update_model_endpoint_infra_in_place(**locals()) - self.in_flight_infra[model_endpoint_infra.deployment_name] = model_endpoint_infra + self.in_flight_infra[model_endpoint_infra.deployment_name] = ( + model_endpoint_infra + ) return "test_creation_task_id" async def get_model_endpoint_infra( @@ -1265,7 +1307,9 @@ async def promote_in_flight_infra(self, owner: str, model_endpoint_name: str): model_endpoint_records[0].status = ModelEndpointStatus.READY del self.in_flight_infra[deployment_name] - async def delete_model_endpoint_infra(self, model_endpoint_record: ModelEndpointRecord) -> bool: + async def delete_model_endpoint_infra( + self, model_endpoint_record: ModelEndpointRecord + ) -> bool: deployment_name = self._get_deployment_name( model_endpoint_record.created_by, model_endpoint_record.name ) @@ -1288,7 +1332,9 @@ def __init__(self): self.db: Dict[str, ModelEndpointInfraState] = {} # type: ignore def add_resource(self, endpoint_id: str, infra_state: ModelEndpointInfraState): - infra_state.labels.update({"user_id": "user_id", "endpoint_name": "endpoint_name"}) + infra_state.labels.update( + {"user_id": "user_id", "endpoint_name": "endpoint_name"} + ) self.db[endpoint_id] = infra_state async def create_queue( @@ -1307,7 +1353,9 @@ async def create_or_update_resources( build_endpoint_request = request.build_endpoint_request endpoint_id = build_endpoint_request.model_endpoint_record.id model_endpoint_record = build_endpoint_request.model_endpoint_record - q = await self.create_queue(model_endpoint_record, build_endpoint_request.labels) + q = await self.create_queue( + model_endpoint_record, build_endpoint_request.labels + ) infra_state = ModelEndpointInfraState( deployment_name=build_endpoint_request.deployment_name, aws_role=build_endpoint_request.aws_role, @@ -1344,7 +1392,9 @@ async def create_or_update_resources( ) # self.db[build_endpoint_request.deployment_name] = infra_state self.db[endpoint_id] = infra_state - return EndpointResourceGatewayCreateOrUpdateResourcesResponse(destination=q.queue_name) + return EndpointResourceGatewayCreateOrUpdateResourcesResponse( + destination=q.queue_name + ) async def get_resources( self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType @@ -1413,13 +1463,19 @@ async def create_docker_image_batch_job( return job_id - async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[DockerImageBatchJob]: + async def get_docker_image_batch_job( + self, batch_job_id: str + ) -> Optional[DockerImageBatchJob]: return self.db.get(batch_job_id) - async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatchJob]: + async def list_docker_image_batch_jobs( + self, owner: str + ) -> List[DockerImageBatchJob]: return [job for job in self.db.values() if job["owner"] == owner] - async def update_docker_image_batch_job(self, batch_job_id: str, cancel: bool) -> bool: + async def update_docker_image_batch_job( + self, batch_job_id: str, cancel: bool + ) -> bool: if batch_job_id not in self.db: return False @@ -1535,7 +1591,9 @@ async def create_fine_tune( return job_id - async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: + async def get_fine_tune( + self, owner: str, fine_tune_id: str + ) -> Optional[DockerImageBatchJob]: di_batch_job = self.db.get(fine_tune_id) if di_batch_job is None or di_batch_job.owner != owner: return None @@ -1560,7 +1618,9 @@ async def get_fine_tune_model_name_from_id( return None -class FakeStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): +class FakeStreamingModelEndpointInferenceGateway( + StreamingModelEndpointInferenceGateway +): def __init__(self): self.responses = [ SyncEndpointPredictV1Response( @@ -1729,12 +1789,18 @@ def __init__( self, contents: Optional[Dict[str, ModelEndpoint]] = None, model_bundle_repository: Optional[ModelBundleRepository] = None, - async_model_endpoint_inference_gateway: Optional[AsyncModelEndpointInferenceGateway] = None, + async_model_endpoint_inference_gateway: Optional[ + AsyncModelEndpointInferenceGateway + ] = None, streaming_model_endpoint_inference_gateway: Optional[ StreamingModelEndpointInferenceGateway ] = None, - sync_model_endpoint_inference_gateway: Optional[SyncModelEndpointInferenceGateway] = None, - inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway] = None, + sync_model_endpoint_inference_gateway: Optional[ + SyncModelEndpointInferenceGateway + ] = None, + inference_autoscaling_metrics_gateway: Optional[ + InferenceAutoscalingMetricsGateway + ] = None, can_scale_http_endpoint_from_zero_flag: bool = True, ): if contents: @@ -1753,28 +1819,44 @@ def __init__( self.model_bundle_repository = model_bundle_repository if async_model_endpoint_inference_gateway is None: - async_model_endpoint_inference_gateway = FakeAsyncModelEndpointInferenceGateway() - self.async_model_endpoint_inference_gateway = async_model_endpoint_inference_gateway + async_model_endpoint_inference_gateway = ( + FakeAsyncModelEndpointInferenceGateway() + ) + self.async_model_endpoint_inference_gateway = ( + async_model_endpoint_inference_gateway + ) if streaming_model_endpoint_inference_gateway is None: streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - self.streaming_model_endpoint_inference_gateway = streaming_model_endpoint_inference_gateway + self.streaming_model_endpoint_inference_gateway = ( + streaming_model_endpoint_inference_gateway + ) if sync_model_endpoint_inference_gateway is None: - sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() - self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway + sync_model_endpoint_inference_gateway = ( + FakeSyncModelEndpointInferenceGateway() + ) + self.sync_model_endpoint_inference_gateway = ( + sync_model_endpoint_inference_gateway + ) if inference_autoscaling_metrics_gateway is None: - inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() - self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway + inference_autoscaling_metrics_gateway = ( + FakeInferenceAutoscalingMetricsGateway() + ) + self.inference_autoscaling_metrics_gateway = ( + inference_autoscaling_metrics_gateway + ) self.model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway() ) - self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag + self.can_scale_http_endpoint_from_zero_flag = ( + can_scale_http_endpoint_from_zero_flag + ) def get_async_model_endpoint_inference_gateway( self, @@ -1838,7 +1920,9 @@ async def create_model_endpoint( endpoint_name=name, endpoint_type=endpoint_type, ) - current_model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) + current_model_bundle = await self.model_bundle_repository.get_model_bundle( + model_bundle_id + ) assert current_model_bundle is not None model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -1929,7 +2013,9 @@ async def update_model_endpoint( task_expires_seconds: Optional[int] = None, public_inference: Optional[bool] = None, ) -> ModelEndpointRecord: - model_endpoint = await self.get_model_endpoint(model_endpoint_id=model_endpoint_id) + model_endpoint = await self.get_model_endpoint( + model_endpoint_id=model_endpoint_id + ) if model_endpoint is None: raise ObjectNotFoundException current_model_bundle = None @@ -1957,11 +2043,15 @@ async def update_model_endpoint( ) return model_endpoint.record - async def get_model_endpoint(self, model_endpoint_id: str) -> Optional[ModelEndpoint]: + async def get_model_endpoint( + self, model_endpoint_id: str + ) -> Optional[ModelEndpoint]: return self.db.get(model_endpoint_id) async def get_model_endpoints_schema(self, owner: str) -> ModelEndpointsSchema: - endpoints = await self.list_model_endpoints(owner=owner, name=None, order_by=None) + endpoints = await self.list_model_endpoints( + owner=owner, name=None, order_by=None + ) records = [endpoint.record for endpoint in endpoints] return self.model_endpoints_schema_gateway.get_model_endpoints_schema( model_endpoint_records=records @@ -1978,7 +2068,9 @@ async def get_model_endpoint_record( def _filter_by_name_owner( record: ModelEndpointRecord, owner: Optional[str], name: Optional[str] ): - return (not owner or record.owner == owner) and (not name or record.name == name) + return (not owner or record.owner == owner) and ( + not name or record.name == name + ) async def list_model_endpoints( self, @@ -2102,7 +2194,9 @@ def fake_docker_repository_image_never_exists() -> FakeDockerRepository: @pytest.fixture -def fake_docker_repository_image_never_exists_and_builds_dont_work() -> FakeDockerRepository: +def fake_docker_repository_image_never_exists_and_builds_dont_work() -> ( + FakeDockerRepository +): repo = FakeDockerRepository(image_always_exists=False, raises_error=True) return repo @@ -2132,7 +2226,9 @@ def fake_model_endpoint_record_repository() -> FakeModelEndpointRecordRepository @pytest.fixture -def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBundleRepository: +def fake_docker_image_batch_job_bundle_repository() -> ( + FakeDockerImageBatchJobBundleRepository +): repo = FakeDockerImageBatchJobBundleRepository() return repo @@ -2221,25 +2317,33 @@ def fake_model_primitive_gateway() -> FakeModelPrimitiveGateway: @pytest.fixture -def fake_async_model_endpoint_inference_gateway() -> FakeAsyncModelEndpointInferenceGateway: +def fake_async_model_endpoint_inference_gateway() -> ( + FakeAsyncModelEndpointInferenceGateway +): gateway = FakeAsyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_streaming_model_endpoint_inference_gateway() -> FakeStreamingModelEndpointInferenceGateway: +def fake_streaming_model_endpoint_inference_gateway() -> ( + FakeStreamingModelEndpointInferenceGateway +): gateway = FakeStreamingModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferenceGateway: +def fake_sync_model_endpoint_inference_gateway() -> ( + FakeSyncModelEndpointInferenceGateway +): gateway = FakeSyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: +def fake_inference_autoscaling_metrics_gateway() -> ( + FakeInferenceAutoscalingMetricsGateway +): gateway = FakeInferenceAutoscalingMetricsGateway() return gateway @@ -2336,14 +2440,18 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: model_endpoint_record_repository=fake_model_endpoint_record_repository, ) fake_model_endpoint_cache_repository = FakeModelEndpointCacheRepository() - async_model_endpoint_inference_gateway = FakeAsyncModelEndpointInferenceGateway() + async_model_endpoint_inference_gateway = ( + FakeAsyncModelEndpointInferenceGateway() + ) streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway( - fake_sync_inference_content + sync_model_endpoint_inference_gateway = ( + FakeSyncModelEndpointInferenceGateway(fake_sync_inference_content) + ) + inference_autoscaling_metrics_gateway = ( + FakeInferenceAutoscalingMetricsGateway() ) - inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway(), ) @@ -2371,8 +2479,10 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: ), ), ) - fake_docker_image_batch_job_bundle_repository = FakeDockerImageBatchJobBundleRepository( - contents=fake_docker_image_batch_job_bundle_repository_contents + fake_docker_image_batch_job_bundle_repository = ( + FakeDockerImageBatchJobBundleRepository( + contents=fake_docker_image_batch_job_bundle_repository_contents + ) ) fake_trigger_repository = FakeTriggerRepository( contents=fake_trigger_repository_contents @@ -2393,7 +2503,9 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_llm_fine_tuning_service_contents ) fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() - fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) + fake_file_storage_gateway = FakeFileStorageGateway( + fake_file_storage_gateway_contents + ) fake_tokenizer_repository = FakeTokenizerRepository() fake_streaming_storage_gateway = FakeStreamingStorageGateway() @@ -2973,7 +3085,9 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd @pytest.fixture -def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: +def model_endpoint_public( + test_api_key: str, model_bundle_1: ModelBundle +) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_1", @@ -3040,7 +3154,9 @@ def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> Mod @pytest.fixture -def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: +def model_endpoint_public_sync( + test_api_key: str, model_bundle_1: ModelBundle +) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_1", @@ -3107,7 +3223,9 @@ def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) - @pytest.fixture -def model_endpoint_runnable(test_api_key: str, model_bundle_4: ModelBundle) -> ModelEndpoint: +def model_endpoint_runnable( + test_api_key: str, model_bundle_4: ModelBundle +) -> ModelEndpoint: # model_bundle_4 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -3165,7 +3283,9 @@ def model_endpoint_runnable(test_api_key: str, model_bundle_4: ModelBundle) -> M @pytest.fixture -def model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: +def model_endpoint_streaming( + test_api_key: str, model_bundle_5: ModelBundle +) -> ModelEndpoint: # model_bundle_5 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -3223,7 +3343,9 @@ def model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> @pytest.fixture -def model_endpoint_multinode(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: +def model_endpoint_multinode( + test_api_key: str, model_bundle_1: ModelBundle +) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_multinode", @@ -3289,7 +3411,9 @@ def model_endpoint_multinode(test_api_key: str, model_bundle_1: ModelBundle) -> @pytest.fixture -def batch_job_1(model_bundle_1: ModelBundle, model_endpoint_1: ModelEndpoint) -> BatchJob: +def batch_job_1( + model_bundle_1: ModelBundle, model_endpoint_1: ModelEndpoint +) -> BatchJob: batch_job = BatchJob( record=BatchJobRecord( id="test_batch_job_id_1", @@ -3444,7 +3568,9 @@ def build_endpoint_request_async_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth( + kind="basic", username="username", password="password" + ) ), ) return build_endpoint_request @@ -3489,7 +3615,9 @@ def build_endpoint_request_streaming_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth( + kind="basic", username="username", password="password" + ) ), ) return build_endpoint_request @@ -3534,7 +3662,9 @@ def build_endpoint_request_sync_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth( + kind="basic", username="username", password="password" + ) ), ) return build_endpoint_request @@ -3579,7 +3709,9 @@ def build_endpoint_request_sync_pytorch( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth( + kind="basic", username="username", password="password" + ) ), ) return build_endpoint_request @@ -3623,7 +3755,9 @@ def build_endpoint_request_async_tensorflow( optimize_costs=False, default_callback_url="https://example.com/path", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth( + kind="basic", username="username", password="password" + ) ), ) return build_endpoint_request @@ -3774,7 +3908,9 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An args=["test_arg_1", "test_arg_2"], callback_url="http://test_callback_url.xyz", callback_auth=CallbackAuth( - root=CallbackBasicAuth(kind="basic", username="test_username", password="test_password") + root=CallbackBasicAuth( + kind="basic", username="test_username", password="test_password" + ) ), return_pickled=True, ) @@ -3783,7 +3919,9 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An @pytest.fixture -def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: +def sync_endpoint_predict_request_1() -> ( + Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] +): request = SyncEndpointPredictV1Request( url="test_url", return_pickled=False, @@ -4611,7 +4749,9 @@ def llm_model_endpoint_sync_trt_llm( @pytest.fixture -def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: +def llm_model_endpoint_streaming( + test_api_key: str, model_bundle_5: ModelBundle +) -> ModelEndpoint: # model_bundle_5 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 4aeb4f97..5d580828 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -3,7 +3,9 @@ from unittest import mock import pytest -from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.batch_jobs import ( + CreateDockerImageBatchJobResourceRequests, +) from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamV1Request, @@ -20,7 +22,10 @@ CreateBatchCompletionsEngineRequest, CreateBatchCompletionsV2Request, ) -from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Response, + TaskStatus, +) from model_engine_server.core.auth.authentication_repository import User from model_engine_server.domain.entities import ( LLMInferenceFramework, @@ -66,7 +71,9 @@ validate_checkpoint_files, validate_checkpoint_path_uri, ) -from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( + CreateModelBundleV2UseCase, +) from ..conftest import mocked__get_recommended_hardware_config_map from .conftest import CreateLLMModelEndpointV1Request_gen @@ -132,7 +139,9 @@ async def test_create_model_endpoint_use_case_success( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) + response_1 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_async + ) assert response_1.endpoint_creation_task_id assert isinstance(response_1, CreateLLMModelEndpointV1Response) endpoint = ( @@ -156,7 +165,9 @@ async def test_create_model_endpoint_use_case_success( } } - response_2 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_sync) + response_2 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_sync + ) assert response_2.endpoint_creation_task_id assert isinstance(response_2, CreateLLMModelEndpointV1Response) endpoint = ( @@ -214,7 +225,10 @@ async def test_create_model_endpoint_use_case_success( bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( owner=user.team_id, name=create_llm_model_endpoint_request_llama_2.name ) - assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] + assert ( + "--max-total-tokens" in bundle.flavor.command[-1] + and "4096" in bundle.flavor.command[-1] + ) response_5 = await use_case.execute( user=user, request=create_llm_model_endpoint_request_llama_3_70b @@ -299,7 +313,9 @@ async def test_create_model_bundle_fails_if_no_checkpoint( docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + request = ( + create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + ) with pytest.raises(expected_error): await use_case.execute( @@ -360,14 +376,18 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( llm_artifact_gateway=fake_llm_artifact_gateway, ) - request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + request = ( + create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + ) request.inference_framework = inference_framework request.inference_framework_image_tag = inference_framework_image_tag user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) if valid: await use_case.execute(user=user, request=request) else: - llm_bundle_use_case.docker_repository = fake_docker_repository_image_never_exists + llm_bundle_use_case.docker_repository = ( + fake_docker_repository_image_never_exists + ) with pytest.raises(DockerImageNotFoundException): await use_case.execute(user=user, request=request) @@ -592,7 +612,11 @@ def test_load_model_weights_sub_commands( trust_remote_code = True subcommands = llm_bundle_use_case.load_model_weights_sub_commands( - framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, ) expected_result = [ @@ -633,7 +657,11 @@ def test_load_model_weights_sub_commands( trust_remote_code = True subcommands = llm_bundle_use_case.load_model_weights_sub_commands( - framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, ) expected_result = [ @@ -663,7 +691,11 @@ def test_load_model_weights_sub_commands( trust_remote_code = True subcommands = llm_bundle_use_case.load_model_weights_sub_commands( - framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, ) expected_result = [ @@ -696,7 +728,9 @@ def test_load_model_files_sub_commands_trt_llm_gcs( ) checkpoint_path = "gs://fake-bucket/fake-checkpoint" - subcommands = llm_bundle_use_case.load_model_files_sub_commands_trt_llm(checkpoint_path) + subcommands = llm_bundle_use_case.load_model_files_sub_commands_trt_llm( + checkpoint_path + ) expected_result = [ "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" @@ -818,7 +852,9 @@ async def test_get_llm_model_endpoint_use_case_raises_not_found( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectNotFoundException): - await use_case.execute(user=user, model_endpoint_name="invalid_model_endpoint_name") + await use_case.execute( + user=user, model_endpoint_name="invalid_model_endpoint_name" + ) @pytest.mark.asyncio @@ -883,7 +919,9 @@ async def test_update_model_endpoint_use_case_success( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + await create_use_case.execute( + user=user, request=create_llm_model_endpoint_request_streaming + ) endpoint = ( await fake_model_endpoint_service.list_model_endpoints( owner=None, @@ -919,7 +957,10 @@ async def test_update_model_endpoint_use_case_success( "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } - assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.resource_state.memory + == update_llm_model_endpoint_request.memory + ) assert ( endpoint.infra_state.deployment_state.min_workers == update_llm_model_endpoint_request.min_workers @@ -955,7 +996,10 @@ async def test_update_model_endpoint_use_case_success( "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } - assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.resource_state.memory + == update_llm_model_endpoint_request.memory + ) assert ( endpoint.infra_state.deployment_state.min_workers == update_llm_model_endpoint_request_only_workers.min_workers @@ -1009,7 +1053,9 @@ async def test_update_model_endpoint_use_case_failure( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + await create_use_case.execute( + user=user, request=create_llm_model_endpoint_request_streaming + ) endpoint = ( await fake_model_endpoint_service.list_model_endpoints( owner=None, @@ -1138,11 +1184,12 @@ async def test_completion_sync_text_generation_inference_use_case_success( llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_sync_request: CompletionSyncV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) + fake_llm_model_endpoint_service.add_model_endpoint( + llm_model_endpoint_text_generation_inference + ) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """ + result={"result": """ { "generated_text": " Deep Learning is a new type of machine learning", "details": { @@ -1210,8 +1257,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( ] } } -""" - }, +"""}, traceback=None, status_code=200, ) @@ -1372,7 +1418,9 @@ async def test_completion_sync_use_case_predict_failed_lightllm( llm_model_endpoint_sync_lightllm: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_lightllm[0]) + fake_llm_model_endpoint_service.add_model_endpoint( + llm_model_endpoint_sync_lightllm[0] + ) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, @@ -1405,7 +1453,9 @@ async def test_completion_sync_use_case_predict_failed_trt_llm( completion_sync_request: CompletionSyncV1Request, ): completion_sync_request.return_token_log_probs = False # not yet supported - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_trt_llm[0]) + fake_llm_model_endpoint_service.add_model_endpoint( + llm_model_endpoint_sync_trt_llm[0] + ) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, @@ -1440,14 +1490,12 @@ async def test_completion_sync_use_case_predict_failed_with_errors( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """ + result={"result": """ { "error": "Request failed during generation: Server error: transport error", "error_type": "generation" } -""" - }, +"""}, traceback="failed to predict", status_code=500, ) @@ -1498,7 +1546,9 @@ async def test_validate_and_update_completion_params(): return_token_log_probs=True, ) - validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + validate_and_update_completion_params( + LLMInferenceFramework.VLLM, completion_sync_request + ) validate_and_update_completion_params( LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request @@ -1516,7 +1566,9 @@ async def test_validate_and_update_completion_params(): completion_sync_request.guided_choice = [""] completion_sync_request.guided_grammar = "" with pytest.raises(ObjectHasInvalidValueException): - validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + validate_and_update_completion_params( + LLMInferenceFramework.VLLM, completion_sync_request + ) completion_sync_request.guided_regex = None completion_sync_request.guided_choice = None @@ -1742,7 +1794,9 @@ async def test_completion_stream_text_generation_inference_use_case_success( llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_stream_request: CompletionStreamV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) + fake_llm_model_endpoint_service.add_model_endpoint( + llm_model_endpoint_text_generation_inference + ) fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, @@ -1771,7 +1825,9 @@ async def test_completion_stream_text_generation_inference_use_case_success( ), SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": "."}, "generated_text": "I am a newbie."}}, + result={ + "result": {"token": {"text": "."}, "generated_text": "I am a newbie."} + }, traceback=None, ), ] @@ -2031,7 +2087,9 @@ async def test_get_fine_tune_events_success( llm_fine_tuning_service=fake_llm_fine_tuning_service, ) response_2 = await use_case.execute(user=user, fine_tune_id=response.id) - assert len(response_2.events) == len(fake_llm_fine_tuning_events_repository.all_events_list) + assert len(response_2.events) == len( + fake_llm_fine_tuning_events_repository.all_events_list + ) @pytest.mark.asyncio @@ -2099,8 +2157,10 @@ async def test_delete_model_success( response = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name ) - remaining_endpoint_model_service = await fake_model_endpoint_service.get_model_endpoint( - llm_model_endpoint_sync[0].record.id + remaining_endpoint_model_service = ( + await fake_model_endpoint_service.get_model_endpoint( + llm_model_endpoint_sync[0].record.id + ) ) assert remaining_endpoint_model_service is None assert response.deleted is True @@ -2249,7 +2309,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 102400, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "" + ) assert hardware.cpus == 160 assert hardware.gpus == 8 assert hardware.memory == "800Gi" @@ -2377,7 +2439,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32064, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "" + ) assert hardware.cpus == 5 assert hardware.gpus == 1 assert hardware.memory == "20Gi" @@ -2437,7 +2501,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 100352, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "" + ) print(hardware) assert hardware.cpus == 5 assert hardware.gpus == 1 @@ -2487,7 +2553,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32064, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "" + ) assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2616,7 +2684,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB assert hardware.nodes_per_worker == 1 - hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True + ) assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2653,7 +2723,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB assert hardware.nodes_per_worker == 1 - hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True + ) assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2836,7 +2908,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.41.0.dev0", "vocab_size": 128256, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "" + ) assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" @@ -2869,7 +2943,9 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "use_sliding_window": False, "vocab_size": 152064, } - hardware = await _infer_hardware(fake_llm_artifact_gateway, "qwen2-72b-instruct", "") + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "qwen2-72b-instruct", "" + ) assert hardware.cpus == 80 assert hardware.gpus == 4 assert hardware.memory == "320Gi" @@ -2951,7 +3027,9 @@ async def test_create_batch_completions_v1( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) result = await use_case.execute(user, create_batch_completions_v1_request) - job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job(result.job_id) + job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job( + result.job_id + ) assert job.num_workers == create_batch_completions_v1_request.data_parallelism bundle = list(fake_docker_image_batch_job_bundle_repository.db.values())[0] @@ -3107,7 +3185,9 @@ def test_merge_metadata(): def test_validate_chat_template(): assert validate_chat_template(None, LLMInferenceFramework.DEEPSPEED) is None good_chat_template = CHAT_TEMPLATE_MAX_LENGTH * "_" - assert validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None + assert ( + validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None + ) bad_chat_template = (CHAT_TEMPLATE_MAX_LENGTH + 1) * "_" with pytest.raises(ObjectHasInvalidValueException): From 6c6df3eb20727b08c338b8f783342af989460aee Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 17:58:16 -0400 Subject: [PATCH 05/12] fix: black formatting with line-length=100 --- .../use_cases/llm_model_endpoint_use_cases.py | 455 +++++------------- model-engine/tests/unit/conftest.py | 272 +++-------- .../tests/unit/domain/test_llm_use_cases.py | 129 ++--- 3 files changed, 236 insertions(+), 620 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index f4e51940..3b61965f 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -152,9 +152,7 @@ def _get_s3_endpoint_flag() -> str: """Get S3 endpoint flag for s5cmd, determined at command construction time.""" - s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv( - "S3_ENDPOINT_URL" - ) + s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv("S3_ENDPOINT_URL") if s3_endpoint: return f"--endpoint-url {s3_endpoint}" return "" @@ -200,26 +198,20 @@ def _get_s3_endpoint_flag() -> str: } -NUM_DOWNSTREAM_REQUEST_RETRIES = ( - 80 # has to be high enough so that the retries take the 5 minutes -) +NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes DEFAULT_BATCH_COMPLETIONS_NODES_PER_WORKER = 1 SERVICE_NAME = "model-engine" -LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = ( - f"{SERVICE_NAME}-inference-framework-latest-config" -) +LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" RECOMMENDED_HARDWARE_CONFIG_MAP_NAME = f"{SERVICE_NAME}-recommended-hardware-config" SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") if SERVICE_IDENTIFIER: SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" -def count_tokens( - input: str, model_name: str, tokenizer_repository: TokenizerRepository -) -> int: +def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: """ Count the number of tokens in the input string. """ @@ -299,9 +291,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response( return response -def validate_model_name( - _model_name: str, _inference_framework: LLMInferenceFramework -) -> None: +def validate_model_name(_model_name: str, _inference_framework: LLMInferenceFramework) -> None: # TODO: replace this logic to check if the model architecture is supported instead pass @@ -325,10 +315,7 @@ def validate_num_shards( def validate_quantization( quantize: Optional[Quantization], inference_framework: LLMInferenceFramework ) -> None: - if ( - quantize is not None - and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework] - ): + if quantize is not None and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework]: raise ObjectHasInvalidValueException( f"Quantization {quantize} is not supported for inference framework {inference_framework}. Supported quantization types are {_SUPPORTED_QUANTIZATIONS[inference_framework]}." ) @@ -365,9 +352,7 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None: ) -def get_checkpoint_path( - model_name: str, checkpoint_path_override: Optional[str] -) -> str: +def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: checkpoint_path = None models_info = SUPPORTED_MODELS_INFO.get(model_name, None) if checkpoint_path_override: @@ -376,9 +361,7 @@ def get_checkpoint_path( checkpoint_path = get_models_s3_uri(models_info.s3_repo, "") # pragma: no cover if not checkpoint_path: - raise InvalidRequestException( - f"No checkpoint path found for model {model_name}" - ) + raise InvalidRequestException(f"No checkpoint path found for model {model_name}") validate_checkpoint_path_uri(checkpoint_path) return checkpoint_path @@ -389,9 +372,7 @@ def validate_checkpoint_files(checkpoint_files: List[str]) -> None: model_files = [f for f in checkpoint_files if "model" in f] num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) if num_safetensors == 0: - raise ObjectHasInvalidValueException( - "No safetensors found in the checkpoint path." - ) + raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") def encode_template(chat_template: str) -> str: @@ -531,9 +512,7 @@ async def execute( ) case LLMInferenceFramework.SGLANG: # pragma: no cover if not hmi_config.sglang_repository: - raise ObjectHasInvalidValueException( - "SGLang repository is not set." - ) + raise ObjectHasInvalidValueException("SGLang repository is not set.") additional_sglang_args = ( SGLangEndpointAdditionalArgs.model_validate(additional_args) @@ -558,9 +537,7 @@ async def execute( model_bundle = await self.model_bundle_repository.get_model_bundle(bundle_id) if model_bundle is None: - raise ObjectNotFoundException( - f"Model bundle {bundle_id} was not found after creation." - ) + raise ObjectNotFoundException(f"Model bundle {bundle_id} was not found after creation.") return model_bundle async def create_text_generation_inference_bundle( @@ -650,10 +627,7 @@ def load_model_weights_sub_commands( final_weights_folder, trust_remote_code, ) - elif ( - checkpoint_path.startswith("azure://") - or "blob.core.windows.net" in checkpoint_path - ): + elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path: return self.load_model_weights_sub_commands_abs( framework, framework_image_tag, @@ -689,9 +663,7 @@ def load_model_weights_sub_commands_s3( framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE and framework_image_tag != "0.9.3-launch_s3" ): - subcommands.append( - f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}" - ) + subcommands.append(f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}") else: s5cmd = "./s5cmd" @@ -1026,11 +998,7 @@ def _create_vllm_bundle_command( exclude_none=True ) ), - **( - additional_args.model_dump(exclude_none=True) - if additional_args - else {} - ), + **(additional_args.model_dump(exclude_none=True) if additional_args else {}), } ) @@ -1084,9 +1052,7 @@ def _create_vllm_bundle_command( vllm_args.disable_log_requests = True # Use wrapper if startup metrics enabled, otherwise use vllm_server directly - server_module = ( - "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server" - ) + server_module = "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server" vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "::"' for field in VLLMEndpointAdditionalArgs.model_fields.keys(): config_value = getattr(vllm_args, field, None) @@ -1446,13 +1412,9 @@ async def execute( validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) validate_model_name(request.model_name, request.inference_framework) - validate_num_shards( - request.num_shards, request.inference_framework, request.gpus - ) + validate_num_shards(request.num_shards, request.inference_framework, request.gpus) validate_quantization(request.quantize, request.inference_framework) - validate_chat_template( - request.chat_template_override, request.inference_framework - ) + validate_chat_template(request.chat_template_override, request.inference_framework) if request.inference_framework in [ LLMInferenceFramework.TEXT_GENERATION_INFERENCE, @@ -1595,10 +1557,8 @@ async def execute( Returns: A response object that contains the model endpoints. """ - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=name, order_by=order_by - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=name, order_by=order_by ) return ListLLMModelEndpointsV1Response( model_endpoints=[ @@ -1617,9 +1577,7 @@ def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() - async def execute( - self, user: User, model_endpoint_name: str - ) -> GetLLMModelEndpointV1Response: + async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndpointV1Response: """ Runs the use case to get the LLM endpoint with the given name. @@ -1688,9 +1646,7 @@ async def execute( ) if not model_endpoint: raise ObjectNotFoundException - if not self.authz_module.check_access_write_owned_entity( - user, model_endpoint.record - ): + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): raise ObjectNotAuthorizedException endpoint_record = model_endpoint.record @@ -1718,15 +1674,11 @@ async def execute( or request.checkpoint_path or request.chat_template_override ): - llm_metadata = (model_endpoint.record.metadata or {}).get( - LLM_METADATA_KEY, {} - ) + llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {}) inference_framework = llm_metadata["inference_framework"] if request.inference_framework_image_tag == "latest": - inference_framework_image_tag = await _get_latest_tag( - inference_framework - ) + inference_framework_image_tag = await _get_latest_tag(inference_framework) else: inference_framework_image_tag = ( request.inference_framework_image_tag @@ -1737,9 +1689,7 @@ async def execute( source = request.source or llm_metadata["source"] num_shards = request.num_shards or llm_metadata["num_shards"] quantize = request.quantize or llm_metadata.get("quantize") - checkpoint_path = request.checkpoint_path or llm_metadata.get( - "checkpoint_path" - ) + checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") validate_model_name(model_name, inference_framework) validate_num_shards( @@ -1863,9 +1813,7 @@ def __init__( self.llm_model_endpoint_service = llm_model_endpoint_service self.authz_module = LiveAuthorizationModule() - async def execute( - self, user: User, model_endpoint_name: str - ) -> DeleteLLMEndpointResponse: + async def execute(self, user: User, model_endpoint_name: str) -> DeleteLLMEndpointResponse: """ Runs the use case to delete the LLM endpoint owned by the user with the given name. @@ -1880,21 +1828,15 @@ async def execute( ObjectNotFoundException: If a model endpoint with the given name could not be found. ObjectNotAuthorizedException: If the owner does not own the model endpoint. """ - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.user_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.user_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) != 1: raise ObjectNotFoundException model_endpoint = model_endpoints[0] - if not self.authz_module.check_access_write_owned_entity( - user, model_endpoint.record - ): + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): raise ObjectNotAuthorizedException - await self.model_endpoint_service.delete_model_endpoint( - model_endpoint.record.id - ) + await self.model_endpoint_service.delete_model_endpoint(model_endpoint.record.id) return DeleteLLMEndpointResponse(deleted=True) @@ -2005,9 +1947,7 @@ def validate_and_update_completion_params( or request.guided_json is not None or request.guided_grammar is not None ) and not inference_framework == LLMInferenceFramework.VLLM: - raise ObjectHasInvalidValueException( - "Guided decoding is only supported in vllm." - ) + raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") return request @@ -2035,9 +1975,7 @@ def model_output_to_completion_output( prompt: str, with_token_probs: Optional[bool], ) -> CompletionOutput: - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: completion_token_count = len(model_output["token_probs"]["tokens"]) tokens = None @@ -2053,10 +1991,7 @@ def model_output_to_completion_output( num_completion_tokens=completion_token_count, tokens=tokens, ) - elif ( - model_content.inference_framework - == LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ): + elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: try: tokens = None if with_token_probs: @@ -2071,13 +2006,9 @@ def model_output_to_completion_output( tokens=tokens, ) except Exception: - logger.exception( - f"Error parsing text-generation-inference output {model_output}." - ) + logger.exception(f"Error parsing text-generation-inference output {model_output}.") if model_output.get("error_type") == "validation": - raise InvalidRequestException( - model_output.get("error") - ) # trigger a 400 + raise InvalidRequestException(model_output.get("error")) # trigger a 400 else: raise UpstreamServiceError( status_code=500, content=bytes(model_output["error"], "utf-8") @@ -2146,18 +2077,14 @@ def model_output_to_completion_output( f"Invalid endpoint {model_content.name} has no base model" ) if not prompt: - raise InvalidRequestException( - "Prompt must be provided for TensorRT-LLM models." - ) + raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") num_prompt_tokens = count_tokens( prompt, model_content.model_name, self.tokenizer_repository ) if "token_ids" in model_output: # TensorRT 23.10 has this field, TensorRT 24.03 does not # For backwards compatibility with pre-2024/05/02 - num_completion_tokens = ( - len(model_output["token_ids"]) - num_prompt_tokens - ) + num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens # Output is " prompt output" text = model_output["text_output"][(len(prompt) + 4) :] elif "output_log_probs" in model_output: @@ -2204,10 +2131,8 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: @@ -2237,18 +2162,14 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = ( - self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() - ) + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -2291,10 +2212,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status == TaskStatus.SUCCESS - and predict_result.result is not None - ): + if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: return CompletionSyncV1Response( request_id=request_id, output=self.model_output_to_completion_output( @@ -2314,8 +2232,7 @@ async def execute( ), ) elif ( - endpoint_content.inference_framework - == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE ): tgi_args: Any = { "inputs": request.prompt, @@ -2346,10 +2263,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -2386,9 +2300,7 @@ async def execute( if request.return_token_log_probs: vllm_args["logprobs"] = 1 if request.include_stop_str_in_output is not None: - vllm_args["include_stop_str_in_output"] = ( - request.include_stop_str_in_output - ) + vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output if request.guided_choice is not None: vllm_args["guided_choice"] = request.guided_choice if request.guided_regex is not None: @@ -2412,10 +2324,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -2467,10 +2376,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -2515,10 +2421,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -2591,16 +2494,12 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: - raise ObjectNotFoundException( - f"Model endpoint {model_endpoint_name} not found." - ) + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -2633,9 +2532,7 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) validated_request = validate_and_update_completion_params( model_content.inference_framework, request ) @@ -2672,10 +2569,7 @@ async def execute( model_content.model_name, self.tokenizer_repository, ) - elif ( - model_content.inference_framework - == LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ): + elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: args = { "inputs": request.prompt, "parameters": { @@ -2818,9 +2712,7 @@ async def _response_chunk_generator( raise UpstreamServiceError( status_code=500, content=( - res.traceback.encode("utf-8") - if res.traceback is not None - else b"" + res.traceback.encode("utf-8") if res.traceback is not None else b"" ), ) # Otherwise, yield empty response chunk for unsuccessful or empty results @@ -2879,9 +2771,7 @@ async def _response_chunk_generator( output=CompletionStreamOutput( text=result["result"]["token"]["text"], finished=finished, - num_prompt_tokens=( - num_prompt_tokens if finished else None - ), + num_prompt_tokens=(num_prompt_tokens if finished else None), num_completion_tokens=num_completion_tokens, token=token, ), @@ -2913,9 +2803,7 @@ async def _response_chunk_generator( num_completion_tokens = usage.get("completion_tokens", 0) if request.return_token_log_probs and choice.get("logprobs"): logprobs = choice["logprobs"] - if logprobs.get("tokens") and logprobs.get( - "token_logprobs" - ): + if logprobs.get("tokens") and logprobs.get("token_logprobs"): # Get the last token from the logprobs idx = len(logprobs["tokens"]) - 1 token = TokenOutput( @@ -2928,9 +2816,7 @@ async def _response_chunk_generator( finished = vllm_output["finished"] num_prompt_tokens = vllm_output["count_prompt_tokens"] num_completion_tokens = vllm_output["count_output_tokens"] - if request.return_token_log_probs and vllm_output.get( - "log_probs" - ): + if request.return_token_log_probs and vllm_output.get("log_probs"): token = TokenOutput( token=vllm_output["text"], log_prob=list(vllm_output["log_probs"].values())[0], @@ -2946,9 +2832,7 @@ async def _response_chunk_generator( ), ) # LIGHTLLM - elif ( - model_content.inference_framework == LLMInferenceFramework.LIGHTLLM - ): + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: token = None num_completion_tokens += 1 if request.return_token_log_probs: @@ -2968,10 +2852,7 @@ async def _response_chunk_generator( ), ) # TENSORRT_LLM - elif ( - model_content.inference_framework - == LLMInferenceFramework.TENSORRT_LLM - ): + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: num_completion_tokens += 1 yield CompletionStreamV1Response( request_id=request_id, @@ -2990,10 +2871,7 @@ async def _response_chunk_generator( def validate_endpoint_supports_openai_completion( endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response ): # pragma: no cover - if ( - endpoint_content.inference_framework - not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS - ): + if endpoint_content.inference_framework not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS: raise EndpointUnsupportedInferenceTypeException( f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support openai compatible completion." ) @@ -3049,10 +2927,8 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: @@ -3090,18 +2966,14 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = ( - self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() - ) + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3129,10 +3001,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -3175,16 +3044,12 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: - raise ObjectNotFoundException( - f"Model endpoint {model_endpoint_name} not found." - ) + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -3225,9 +3090,7 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3288,11 +3151,7 @@ async def _response_chunk_generator( if not res.status == TaskStatus.SUCCESS or res.result is None: raise UpstreamServiceError( status_code=500, - content=( - res.traceback.encode("utf-8") - if res.traceback is not None - else b"" - ), + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), ) else: result = res.result["result"] @@ -3312,16 +3171,12 @@ def validate_endpoint_supports_chat_completion( ) if not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike): - raise EndpointUnsupportedRequestException( - "Endpoint does not support chat completion" - ) + raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") flavor = endpoint.record.current_model_bundle.flavor all_routes = flavor.extra_routes + flavor.routes if OPENAI_CHAT_COMPLETION_PATH not in all_routes: - raise EndpointUnsupportedRequestException( - "Endpoint does not support chat completion" - ) + raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") class ChatCompletionSyncV2UseCase: @@ -3362,10 +3217,8 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: @@ -3403,18 +3256,14 @@ async def execute( f"Endpoint {model_endpoint_name} does not serve sync requests." ) - inference_gateway = ( - self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() - ) + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() autoscaling_metrics_gateway = ( self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() ) await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( endpoint_id=model_endpoint.record.id ) - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3442,10 +3291,7 @@ async def execute( endpoint_name=model_endpoint.record.name, ) - if ( - predict_result.status != TaskStatus.SUCCESS - or predict_result.result is None - ): + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: raise UpstreamServiceError( status_code=500, content=( @@ -3488,16 +3334,12 @@ async def execute( request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) add_trace_request_id(request_id) - model_endpoints = ( - await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None ) if len(model_endpoints) == 0: - raise ObjectNotFoundException( - f"Model endpoint {model_endpoint_name} not found." - ) + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") if len(model_endpoints) > 1: raise ObjectHasInvalidValueException( @@ -3538,9 +3380,7 @@ async def execute( endpoint_id=model_endpoint.record.id ) - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint - ) + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) manually_resolve_dns = ( model_endpoint.infra_state is not None @@ -3600,11 +3440,7 @@ async def _response_chunk_generator( if not res.status == TaskStatus.SUCCESS or res.result is None: raise UpstreamServiceError( status_code=500, - content=( - res.traceback.encode("utf-8") - if res.traceback is not None - else b"" - ), + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), ) else: result = res.result["result"] @@ -3626,9 +3462,7 @@ def __init__( self.model_endpoint_service = model_endpoint_service self.llm_artifact_gateway = llm_artifact_gateway - async def execute( - self, user: User, request: ModelDownloadRequest - ) -> ModelDownloadResponse: + async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownloadResponse: model_endpoints = await self.model_endpoint_service.list_model_endpoints( owner=user.team_id, name=request.model_name, order_by=None ) @@ -3646,9 +3480,7 @@ async def execute( for model_file in model_files: # don't want to make s3 bucket full keys public, so trim to just keep file name public_file_name = model_file.rsplit("/", 1)[-1] - urls[public_file_name] = self.filesystem_gateway.generate_signed_url( - model_file - ) + urls[public_file_name] = self.filesystem_gateway.generate_signed_url(model_file) return ModelDownloadResponse(urls=urls) @@ -3674,9 +3506,7 @@ async def _fill_hardware_info( raise ObjectHasInvalidValueException( "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing." ) - checkpoint_path = get_checkpoint_path( - request.model_name, request.checkpoint_path - ) + checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) hardware_info = await _infer_hardware( llm_artifact_gateway, request.model_name, checkpoint_path ) @@ -3747,18 +3577,14 @@ async def _infer_hardware( model_param_count_b = get_model_param_count_b(model_name) model_weights_size = dtype_size * model_param_count_b * 1_000_000_000 - min_memory_gb = math.ceil( - (min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9 - ) + min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) logger.info( f"Memory calculation result: {min_memory_gb=} for {model_name} context_size: {max_position_embeddings}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" ) config_map = await _get_recommended_hardware_config_map() - by_model_name = { - item["name"]: item for item in yaml.safe_load(config_map["byModelName"]) - } + by_model_name = {item["name"]: item for item in yaml.safe_load(config_map["byModelName"])} by_gpu_memory_gb = yaml.safe_load(config_map["byGpuMemoryGb"]) if model_name in by_model_name: cpus = by_model_name[model_name]["cpus"] @@ -3779,9 +3605,7 @@ async def _infer_hardware( nodes_per_worker = recs["nodes_per_worker"] break else: - raise ObjectHasInvalidValueException( - f"Unable to infer hardware for {model_name}." - ) + raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") return CreateDockerImageBatchJobResourceRequests( cpus=cpus, @@ -3843,33 +3667,37 @@ async def create_batch_job_bundle( ) -> DockerImageBatchJobBundle: assert hardware.gpu_type is not None - bundle_name = f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" + bundle_name = ( + f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" + ) image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) config_file_path = "/opt/config.json" - batch_bundle = await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( - name=bundle_name, - created_by=user.user_id, - owner=user.team_id, - image_repository=hmi_config.batch_inference_vllm_repository, - image_tag=image_tag, - command=[ - "dumb-init", - "--", - "/bin/bash", - "-c", - "ddtrace-run python vllm_batch.py", - ], - env={"CONFIG_FILE": config_file_path}, - mount_location=config_file_path, - cpus=str(hardware.cpus), - memory=str(hardware.memory), - storage=str(hardware.storage), - gpus=hardware.gpus, - gpu_type=hardware.gpu_type, - public=False, + batch_bundle = ( + await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( + name=bundle_name, + created_by=user.user_id, + owner=user.team_id, + image_repository=hmi_config.batch_inference_vllm_repository, + image_tag=image_tag, + command=[ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ], + env={"CONFIG_FILE": config_file_path}, + mount_location=config_file_path, + cpus=str(hardware.cpus), + memory=str(hardware.memory), + storage=str(hardware.storage), + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + public=False, + ) ) return batch_bundle @@ -3897,10 +3725,7 @@ async def execute( engine_request = CreateBatchCompletionsEngineRequest.from_api_v1(request) engine_request.model_cfg.num_shards = hardware.gpus - if ( - engine_request.tool_config - and engine_request.tool_config.name != "code_evaluator" - ): + if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": raise ObjectHasInvalidValueException( "Only code_evaluator tool is supported for batch completions." ) @@ -3909,14 +3734,10 @@ async def execute( engine_request.model_cfg.model ) - engine_request.max_gpu_memory_utilization = ( - additional_engine_args.gpu_memory_utilization - ) + engine_request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization engine_request.attention_backend = additional_engine_args.attention_backend - batch_bundle = await self.create_batch_job_bundle( - user, engine_request, hardware - ) + batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) validate_resource_requests( bundle=batch_bundle, @@ -3930,25 +3751,21 @@ async def execute( if ( engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1 ): # pragma: no cover - raise ObjectHasInvalidValueException( - "max_runtime_sec must be a positive integer." - ) + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") - job_id = ( - await self.docker_image_batch_job_gateway.create_docker_image_batch_job( - created_by=user.user_id, - owner=user.team_id, - job_config=engine_request.model_dump(by_alias=True), - env=batch_bundle.env, - command=batch_bundle.command, - repo=batch_bundle.image_repository, - tag=batch_bundle.image_tag, - resource_requests=hardware, - labels=engine_request.labels, - mount_location=batch_bundle.mount_location, - override_job_max_runtime_s=engine_request.max_runtime_sec, - num_workers=engine_request.data_parallelism, - ) + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=engine_request.model_dump(by_alias=True), + env=batch_bundle.env, + command=batch_bundle.command, + repo=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + resource_requests=hardware, + labels=engine_request.labels, + mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=engine_request.max_runtime_sec, + num_workers=engine_request.data_parallelism, ) return CreateBatchCompletionsV1Response(job_id=job_id) @@ -4018,9 +3835,7 @@ async def execute( ) if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: - raise ObjectHasInvalidValueException( - "max_runtime_sec must be a positive integer." - ) + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") # Right now we only support VLLM for batch inference. Refactor this if we support more inference frameworks. image_repo = hmi_config.batch_inference_vllm_repository @@ -4065,9 +3880,7 @@ async def execute( ) if not job: - raise ObjectNotFoundException( - f"Batch completion {batch_completion_id} not found." - ) + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") return GetBatchCompletionV2Response(job=job) @@ -4088,9 +3901,7 @@ async def execute( request=request, ) if not result: - raise ObjectNotFoundException( - f"Batch completion {batch_completion_id} not found." - ) + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") return UpdateBatchCompletionsV2Response( **result.model_dump(by_alias=True, exclude_none=True), diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 4e674d2c..4f6b0301 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -221,9 +221,7 @@ def __init__(self, contents: Optional[Dict[str, ModelBundle]] = None): self.db = contents self.unique_owner_name_versions = set() for model_bundle in self.db.values(): - self.unique_owner_name_versions.add( - (model_bundle.owner, model_bundle.name) - ) + self.unique_owner_name_versions.add((model_bundle.owner, model_bundle.name)) else: self.db = {} self.unique_owner_name_versions = set() @@ -272,9 +270,7 @@ async def list_model_bundles( self, owner: str, name: Optional[str], order_by: Optional[ModelBundleOrderBy] ) -> Sequence[ModelBundle]: model_bundles = [ - mb - for mb in self.db.values() - if mb.owner == owner and (not name or mb.name == name) + mb for mb in self.db.values() if mb.owner == owner and (not name or mb.name == name) ] if order_by == ModelBundleOrderBy.NEWEST: @@ -284,12 +280,8 @@ async def list_model_bundles( return model_bundles - async def get_latest_model_bundle_by_name( - self, owner: str, name: str - ) -> Optional[ModelBundle]: - model_bundles = await self.list_model_bundles( - owner, name, ModelBundleOrderBy.NEWEST - ) + async def get_latest_model_bundle_by_name(self, owner: str, name: str) -> Optional[ModelBundle]: + model_bundles = await self.list_model_bundles(owner, name, ModelBundleOrderBy.NEWEST) if not model_bundles: return None return model_bundles[0] @@ -334,9 +326,7 @@ async def create_batch_job_record( model_bundle_id=model_bundle_id, ) orm_batch_job.created_at = datetime.now() - model_bundle = await self.model_bundle_repository.get_model_bundle( - model_bundle_id - ) + model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) assert model_bundle is not None batch_job = _translate_fake_batch_job_orm_to_batch_job_record( orm_batch_job, model_bundle=model_bundle @@ -375,18 +365,14 @@ async def update_batch_job_record( async def get_batch_job_record(self, batch_job_id: str) -> Optional[BatchJobRecord]: return self.db.get(batch_job_id) - async def list_batch_job_records( - self, owner: Optional[str] - ) -> List[BatchJobRecord]: + async def list_batch_job_records(self, owner: Optional[str]) -> List[BatchJobRecord]: def filter_fn(m: BatchJobRecord) -> bool: return not owner or m.owner == owner batch_jobs = list(filter(filter_fn, self.db.values())) return batch_jobs - async def unset_model_endpoint_id( - self, batch_job_id: str - ) -> Optional[BatchJobRecord]: + async def unset_model_endpoint_id(self, batch_job_id: str) -> Optional[BatchJobRecord]: batch_job_record = await self.get_batch_job_record(batch_job_id) if batch_job_record: batch_job_record.model_endpoint_id = None @@ -406,9 +392,7 @@ def __init__( self.db = contents self.unique_owner_name_versions = set() for model_endpoint in self.db.values(): - self.unique_owner_name_versions.add( - (model_endpoint.owner, model_endpoint.name) - ) + self.unique_owner_name_versions.add((model_endpoint.owner, model_endpoint.name)) else: self.db = {} self.unique_owner_name_versions = set() @@ -497,9 +481,7 @@ async def create_model_endpoint_record( ) orm_model_endpoint.created_at = datetime.now() orm_model_endpoint.last_updated_at = datetime.now() - model_bundle = await self.model_bundle_repository.get_model_bundle( - model_bundle_id - ) + model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) assert model_bundle is not None model_endpoint = _translate_fake_model_endpoint_orm_to_model_endpoint_record( orm_model_endpoint, current_model_bundle=model_bundle @@ -630,14 +612,10 @@ def _get_new_id(self): self.next_id += 1 return str(self.next_id) - def add_docker_image_batch_job_bundle( - self, batch_bundle: DockerImageBatchJobBundle - ): + def add_docker_image_batch_job_bundle(self, batch_bundle: DockerImageBatchJobBundle): new_id = batch_bundle.id if new_id in {bun.id for bun in self.db.values()}: - raise ValueError( - f"Error in test set up, batch bundle with {new_id} already present" - ) + raise ValueError(f"Error in test set up, batch bundle with {new_id} already present") self.db[new_id] = batch_bundle async def create_docker_image_batch_job_bundle( @@ -799,9 +777,7 @@ async def write_job_template_for_model( class FakeLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): def __init__(self): self.initialized_events = [] - self.all_events_list = [ - LLMFineTuneEvent(timestamp=1, message="message", level="info") - ] + self.all_events_list = [LLMFineTuneEvent(timestamp=1, message="message", level="info")] async def get_fine_tune_events(self, user_id: str, model_endpoint_name: str): if (user_id, model_endpoint_name) in self.initialized_events: @@ -832,9 +808,7 @@ def __init__(self): "llama-3-70b": ["model-fake.safetensors"], "llama-3-1-405b-instruct": ["model-fake.safetensors"], } - self.urls = { - "filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz" - } + self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} self.model_config = { "_name_or_path": "meta-llama/Llama-2-7b-hf", "architectures": ["LlamaForCausalLM"], @@ -917,9 +891,7 @@ def list_files(self, path: str, **kwargs) -> List[str]: if path in self.s3_bucket: return self.s3_bucket[path] - def download_files( - self, path: str, target_path: str, overwrite=False, **kwargs - ) -> List[str]: + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: path = self._strip_cloud_prefix(path) if path in self.s3_bucket: return self.s3_bucket[path] @@ -1122,9 +1094,7 @@ class FakeModelEndpointInfraGateway(ModelEndpointInfraGateway): def __init__( self, contents: Optional[Dict[str, ModelEndpointInfraState]] = None, - model_endpoint_record_repository: Optional[ - ModelEndpointRecordRepository - ] = None, + model_endpoint_record_repository: Optional[ModelEndpointRecordRepository] = None, ): self.db = contents if contents else {} self.in_flight_infra = {} @@ -1221,9 +1191,9 @@ def update_model_endpoint_infra_in_place( if kwargs["per_worker"] is not None: model_endpoint_infra.deployment_state.per_worker = kwargs["per_worker"] if kwargs["concurrent_requests_per_worker"] is not None: - model_endpoint_infra.deployment_state.concurrent_requests_per_worker = ( - kwargs["concurrent_requests_per_worker"] - ) + model_endpoint_infra.deployment_state.concurrent_requests_per_worker = kwargs[ + "concurrent_requests_per_worker" + ] if kwargs["cpus"] is not None: model_endpoint_infra.resource_state.cpus = kwargs["cpus"] if kwargs["gpus"] is not None: @@ -1279,9 +1249,7 @@ async def update_model_endpoint_infra( assert model_endpoint_infra is not None model_endpoint_infra = model_endpoint_infra.copy() self.update_model_endpoint_infra_in_place(**locals()) - self.in_flight_infra[model_endpoint_infra.deployment_name] = ( - model_endpoint_infra - ) + self.in_flight_infra[model_endpoint_infra.deployment_name] = model_endpoint_infra return "test_creation_task_id" async def get_model_endpoint_infra( @@ -1307,9 +1275,7 @@ async def promote_in_flight_infra(self, owner: str, model_endpoint_name: str): model_endpoint_records[0].status = ModelEndpointStatus.READY del self.in_flight_infra[deployment_name] - async def delete_model_endpoint_infra( - self, model_endpoint_record: ModelEndpointRecord - ) -> bool: + async def delete_model_endpoint_infra(self, model_endpoint_record: ModelEndpointRecord) -> bool: deployment_name = self._get_deployment_name( model_endpoint_record.created_by, model_endpoint_record.name ) @@ -1332,9 +1298,7 @@ def __init__(self): self.db: Dict[str, ModelEndpointInfraState] = {} # type: ignore def add_resource(self, endpoint_id: str, infra_state: ModelEndpointInfraState): - infra_state.labels.update( - {"user_id": "user_id", "endpoint_name": "endpoint_name"} - ) + infra_state.labels.update({"user_id": "user_id", "endpoint_name": "endpoint_name"}) self.db[endpoint_id] = infra_state async def create_queue( @@ -1353,9 +1317,7 @@ async def create_or_update_resources( build_endpoint_request = request.build_endpoint_request endpoint_id = build_endpoint_request.model_endpoint_record.id model_endpoint_record = build_endpoint_request.model_endpoint_record - q = await self.create_queue( - model_endpoint_record, build_endpoint_request.labels - ) + q = await self.create_queue(model_endpoint_record, build_endpoint_request.labels) infra_state = ModelEndpointInfraState( deployment_name=build_endpoint_request.deployment_name, aws_role=build_endpoint_request.aws_role, @@ -1392,9 +1354,7 @@ async def create_or_update_resources( ) # self.db[build_endpoint_request.deployment_name] = infra_state self.db[endpoint_id] = infra_state - return EndpointResourceGatewayCreateOrUpdateResourcesResponse( - destination=q.queue_name - ) + return EndpointResourceGatewayCreateOrUpdateResourcesResponse(destination=q.queue_name) async def get_resources( self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType @@ -1463,19 +1423,13 @@ async def create_docker_image_batch_job( return job_id - async def get_docker_image_batch_job( - self, batch_job_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[DockerImageBatchJob]: return self.db.get(batch_job_id) - async def list_docker_image_batch_jobs( - self, owner: str - ) -> List[DockerImageBatchJob]: + async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatchJob]: return [job for job in self.db.values() if job["owner"] == owner] - async def update_docker_image_batch_job( - self, batch_job_id: str, cancel: bool - ) -> bool: + async def update_docker_image_batch_job(self, batch_job_id: str, cancel: bool) -> bool: if batch_job_id not in self.db: return False @@ -1591,9 +1545,7 @@ async def create_fine_tune( return job_id - async def get_fine_tune( - self, owner: str, fine_tune_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: di_batch_job = self.db.get(fine_tune_id) if di_batch_job is None or di_batch_job.owner != owner: return None @@ -1618,9 +1570,7 @@ async def get_fine_tune_model_name_from_id( return None -class FakeStreamingModelEndpointInferenceGateway( - StreamingModelEndpointInferenceGateway -): +class FakeStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): def __init__(self): self.responses = [ SyncEndpointPredictV1Response( @@ -1789,18 +1739,12 @@ def __init__( self, contents: Optional[Dict[str, ModelEndpoint]] = None, model_bundle_repository: Optional[ModelBundleRepository] = None, - async_model_endpoint_inference_gateway: Optional[ - AsyncModelEndpointInferenceGateway - ] = None, + async_model_endpoint_inference_gateway: Optional[AsyncModelEndpointInferenceGateway] = None, streaming_model_endpoint_inference_gateway: Optional[ StreamingModelEndpointInferenceGateway ] = None, - sync_model_endpoint_inference_gateway: Optional[ - SyncModelEndpointInferenceGateway - ] = None, - inference_autoscaling_metrics_gateway: Optional[ - InferenceAutoscalingMetricsGateway - ] = None, + sync_model_endpoint_inference_gateway: Optional[SyncModelEndpointInferenceGateway] = None, + inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway] = None, can_scale_http_endpoint_from_zero_flag: bool = True, ): if contents: @@ -1819,44 +1763,28 @@ def __init__( self.model_bundle_repository = model_bundle_repository if async_model_endpoint_inference_gateway is None: - async_model_endpoint_inference_gateway = ( - FakeAsyncModelEndpointInferenceGateway() - ) - self.async_model_endpoint_inference_gateway = ( - async_model_endpoint_inference_gateway - ) + async_model_endpoint_inference_gateway = FakeAsyncModelEndpointInferenceGateway() + self.async_model_endpoint_inference_gateway = async_model_endpoint_inference_gateway if streaming_model_endpoint_inference_gateway is None: streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - self.streaming_model_endpoint_inference_gateway = ( - streaming_model_endpoint_inference_gateway - ) + self.streaming_model_endpoint_inference_gateway = streaming_model_endpoint_inference_gateway if sync_model_endpoint_inference_gateway is None: - sync_model_endpoint_inference_gateway = ( - FakeSyncModelEndpointInferenceGateway() - ) - self.sync_model_endpoint_inference_gateway = ( - sync_model_endpoint_inference_gateway - ) + sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() + self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway if inference_autoscaling_metrics_gateway is None: - inference_autoscaling_metrics_gateway = ( - FakeInferenceAutoscalingMetricsGateway() - ) - self.inference_autoscaling_metrics_gateway = ( - inference_autoscaling_metrics_gateway - ) + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway self.model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway() ) - self.can_scale_http_endpoint_from_zero_flag = ( - can_scale_http_endpoint_from_zero_flag - ) + self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag def get_async_model_endpoint_inference_gateway( self, @@ -1920,9 +1848,7 @@ async def create_model_endpoint( endpoint_name=name, endpoint_type=endpoint_type, ) - current_model_bundle = await self.model_bundle_repository.get_model_bundle( - model_bundle_id - ) + current_model_bundle = await self.model_bundle_repository.get_model_bundle(model_bundle_id) assert current_model_bundle is not None model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -2013,9 +1939,7 @@ async def update_model_endpoint( task_expires_seconds: Optional[int] = None, public_inference: Optional[bool] = None, ) -> ModelEndpointRecord: - model_endpoint = await self.get_model_endpoint( - model_endpoint_id=model_endpoint_id - ) + model_endpoint = await self.get_model_endpoint(model_endpoint_id=model_endpoint_id) if model_endpoint is None: raise ObjectNotFoundException current_model_bundle = None @@ -2043,15 +1967,11 @@ async def update_model_endpoint( ) return model_endpoint.record - async def get_model_endpoint( - self, model_endpoint_id: str - ) -> Optional[ModelEndpoint]: + async def get_model_endpoint(self, model_endpoint_id: str) -> Optional[ModelEndpoint]: return self.db.get(model_endpoint_id) async def get_model_endpoints_schema(self, owner: str) -> ModelEndpointsSchema: - endpoints = await self.list_model_endpoints( - owner=owner, name=None, order_by=None - ) + endpoints = await self.list_model_endpoints(owner=owner, name=None, order_by=None) records = [endpoint.record for endpoint in endpoints] return self.model_endpoints_schema_gateway.get_model_endpoints_schema( model_endpoint_records=records @@ -2068,9 +1988,7 @@ async def get_model_endpoint_record( def _filter_by_name_owner( record: ModelEndpointRecord, owner: Optional[str], name: Optional[str] ): - return (not owner or record.owner == owner) and ( - not name or record.name == name - ) + return (not owner or record.owner == owner) and (not name or record.name == name) async def list_model_endpoints( self, @@ -2194,9 +2112,7 @@ def fake_docker_repository_image_never_exists() -> FakeDockerRepository: @pytest.fixture -def fake_docker_repository_image_never_exists_and_builds_dont_work() -> ( - FakeDockerRepository -): +def fake_docker_repository_image_never_exists_and_builds_dont_work() -> FakeDockerRepository: repo = FakeDockerRepository(image_always_exists=False, raises_error=True) return repo @@ -2226,9 +2142,7 @@ def fake_model_endpoint_record_repository() -> FakeModelEndpointRecordRepository @pytest.fixture -def fake_docker_image_batch_job_bundle_repository() -> ( - FakeDockerImageBatchJobBundleRepository -): +def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBundleRepository: repo = FakeDockerImageBatchJobBundleRepository() return repo @@ -2317,33 +2231,25 @@ def fake_model_primitive_gateway() -> FakeModelPrimitiveGateway: @pytest.fixture -def fake_async_model_endpoint_inference_gateway() -> ( - FakeAsyncModelEndpointInferenceGateway -): +def fake_async_model_endpoint_inference_gateway() -> FakeAsyncModelEndpointInferenceGateway: gateway = FakeAsyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_streaming_model_endpoint_inference_gateway() -> ( - FakeStreamingModelEndpointInferenceGateway -): +def fake_streaming_model_endpoint_inference_gateway() -> FakeStreamingModelEndpointInferenceGateway: gateway = FakeStreamingModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_sync_model_endpoint_inference_gateway() -> ( - FakeSyncModelEndpointInferenceGateway -): +def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferenceGateway: gateway = FakeSyncModelEndpointInferenceGateway() return gateway @pytest.fixture -def fake_inference_autoscaling_metrics_gateway() -> ( - FakeInferenceAutoscalingMetricsGateway -): +def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: gateway = FakeInferenceAutoscalingMetricsGateway() return gateway @@ -2440,18 +2346,14 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: model_endpoint_record_repository=fake_model_endpoint_record_repository, ) fake_model_endpoint_cache_repository = FakeModelEndpointCacheRepository() - async_model_endpoint_inference_gateway = ( - FakeAsyncModelEndpointInferenceGateway() - ) + async_model_endpoint_inference_gateway = FakeAsyncModelEndpointInferenceGateway() streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - sync_model_endpoint_inference_gateway = ( - FakeSyncModelEndpointInferenceGateway(fake_sync_inference_content) - ) - inference_autoscaling_metrics_gateway = ( - FakeInferenceAutoscalingMetricsGateway() + sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway( + fake_sync_inference_content ) + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway(), ) @@ -2479,10 +2381,8 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: ), ), ) - fake_docker_image_batch_job_bundle_repository = ( - FakeDockerImageBatchJobBundleRepository( - contents=fake_docker_image_batch_job_bundle_repository_contents - ) + fake_docker_image_batch_job_bundle_repository = FakeDockerImageBatchJobBundleRepository( + contents=fake_docker_image_batch_job_bundle_repository_contents ) fake_trigger_repository = FakeTriggerRepository( contents=fake_trigger_repository_contents @@ -2503,9 +2403,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_llm_fine_tuning_service_contents ) fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() - fake_file_storage_gateway = FakeFileStorageGateway( - fake_file_storage_gateway_contents - ) + fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) fake_tokenizer_repository = FakeTokenizerRepository() fake_streaming_storage_gateway = FakeStreamingStorageGateway() @@ -3085,9 +2983,7 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd @pytest.fixture -def model_endpoint_public( - test_api_key: str, model_bundle_1: ModelBundle -) -> ModelEndpoint: +def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_1", @@ -3154,9 +3050,7 @@ def model_endpoint_public( @pytest.fixture -def model_endpoint_public_sync( - test_api_key: str, model_bundle_1: ModelBundle -) -> ModelEndpoint: +def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_1", @@ -3223,9 +3117,7 @@ def model_endpoint_public_sync( @pytest.fixture -def model_endpoint_runnable( - test_api_key: str, model_bundle_4: ModelBundle -) -> ModelEndpoint: +def model_endpoint_runnable(test_api_key: str, model_bundle_4: ModelBundle) -> ModelEndpoint: # model_bundle_4 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -3283,9 +3175,7 @@ def model_endpoint_runnable( @pytest.fixture -def model_endpoint_streaming( - test_api_key: str, model_bundle_5: ModelBundle -) -> ModelEndpoint: +def model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: # model_bundle_5 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( @@ -3343,9 +3233,7 @@ def model_endpoint_streaming( @pytest.fixture -def model_endpoint_multinode( - test_api_key: str, model_bundle_1: ModelBundle -) -> ModelEndpoint: +def model_endpoint_multinode(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEndpoint: model_endpoint = ModelEndpoint( record=ModelEndpointRecord( id="test_model_endpoint_id_multinode", @@ -3411,9 +3299,7 @@ def model_endpoint_multinode( @pytest.fixture -def batch_job_1( - model_bundle_1: ModelBundle, model_endpoint_1: ModelEndpoint -) -> BatchJob: +def batch_job_1(model_bundle_1: ModelBundle, model_endpoint_1: ModelEndpoint) -> BatchJob: batch_job = BatchJob( record=BatchJobRecord( id="test_batch_job_id_1", @@ -3568,9 +3454,7 @@ def build_endpoint_request_async_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="username", password="password" - ) + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3615,9 +3499,7 @@ def build_endpoint_request_streaming_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="username", password="password" - ) + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3662,9 +3544,7 @@ def build_endpoint_request_sync_runnable_image( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="username", password="password" - ) + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3709,9 +3589,7 @@ def build_endpoint_request_sync_pytorch( broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="username", password="password" - ) + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3755,9 +3633,7 @@ def build_endpoint_request_async_tensorflow( optimize_costs=False, default_callback_url="https://example.com/path", default_callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="username", password="password" - ) + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -3908,9 +3784,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An args=["test_arg_1", "test_arg_2"], callback_url="http://test_callback_url.xyz", callback_auth=CallbackAuth( - root=CallbackBasicAuth( - kind="basic", username="test_username", password="test_password" - ) + root=CallbackBasicAuth(kind="basic", username="test_username", password="test_password") ), return_pickled=True, ) @@ -3919,9 +3793,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An @pytest.fixture -def sync_endpoint_predict_request_1() -> ( - Tuple[SyncEndpointPredictV1Request, Dict[str, Any]] -): +def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: request = SyncEndpointPredictV1Request( url="test_url", return_pickled=False, @@ -4749,9 +4621,7 @@ def llm_model_endpoint_sync_trt_llm( @pytest.fixture -def llm_model_endpoint_streaming( - test_api_key: str, model_bundle_5: ModelBundle -) -> ModelEndpoint: +def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> ModelEndpoint: # model_bundle_5 is a runnable bundle model_endpoint = ModelEndpoint( record=ModelEndpointRecord( diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 5d580828..326c97b8 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -139,9 +139,7 @@ async def test_create_model_endpoint_use_case_success( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, request=create_llm_model_endpoint_request_async - ) + response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) assert response_1.endpoint_creation_task_id assert isinstance(response_1, CreateLLMModelEndpointV1Response) endpoint = ( @@ -165,9 +163,7 @@ async def test_create_model_endpoint_use_case_success( } } - response_2 = await use_case.execute( - user=user, request=create_llm_model_endpoint_request_sync - ) + response_2 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_sync) assert response_2.endpoint_creation_task_id assert isinstance(response_2, CreateLLMModelEndpointV1Response) endpoint = ( @@ -225,10 +221,7 @@ async def test_create_model_endpoint_use_case_success( bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( owner=user.team_id, name=create_llm_model_endpoint_request_llama_2.name ) - assert ( - "--max-total-tokens" in bundle.flavor.command[-1] - and "4096" in bundle.flavor.command[-1] - ) + assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] response_5 = await use_case.execute( user=user, request=create_llm_model_endpoint_request_llama_3_70b @@ -313,9 +306,7 @@ async def test_create_model_bundle_fails_if_no_checkpoint( docker_repository=fake_docker_repository_image_always_exists, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - request = ( - create_llm_model_endpoint_text_generation_inference_request_streaming.copy() - ) + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() with pytest.raises(expected_error): await use_case.execute( @@ -376,18 +367,14 @@ async def test_create_model_bundle_inference_framework_image_tag_validation( llm_artifact_gateway=fake_llm_artifact_gateway, ) - request = ( - create_llm_model_endpoint_text_generation_inference_request_streaming.copy() - ) + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() request.inference_framework = inference_framework request.inference_framework_image_tag = inference_framework_image_tag user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) if valid: await use_case.execute(user=user, request=request) else: - llm_bundle_use_case.docker_repository = ( - fake_docker_repository_image_never_exists - ) + llm_bundle_use_case.docker_repository = fake_docker_repository_image_never_exists with pytest.raises(DockerImageNotFoundException): await use_case.execute(user=user, request=request) @@ -728,9 +715,7 @@ def test_load_model_files_sub_commands_trt_llm_gcs( ) checkpoint_path = "gs://fake-bucket/fake-checkpoint" - subcommands = llm_bundle_use_case.load_model_files_sub_commands_trt_llm( - checkpoint_path - ) + subcommands = llm_bundle_use_case.load_model_files_sub_commands_trt_llm(checkpoint_path) expected_result = [ "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" @@ -852,9 +837,7 @@ async def test_get_llm_model_endpoint_use_case_raises_not_found( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) with pytest.raises(ObjectNotFoundException): - await use_case.execute( - user=user, model_endpoint_name="invalid_model_endpoint_name" - ) + await use_case.execute(user=user, model_endpoint_name="invalid_model_endpoint_name") @pytest.mark.asyncio @@ -919,9 +902,7 @@ async def test_update_model_endpoint_use_case_success( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - await create_use_case.execute( - user=user, request=create_llm_model_endpoint_request_streaming - ) + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) endpoint = ( await fake_model_endpoint_service.list_model_endpoints( owner=None, @@ -957,10 +938,7 @@ async def test_update_model_endpoint_use_case_success( "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } - assert ( - endpoint.infra_state.resource_state.memory - == update_llm_model_endpoint_request.memory - ) + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory assert ( endpoint.infra_state.deployment_state.min_workers == update_llm_model_endpoint_request.min_workers @@ -996,10 +974,7 @@ async def test_update_model_endpoint_use_case_success( "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, } } - assert ( - endpoint.infra_state.resource_state.memory - == update_llm_model_endpoint_request.memory - ) + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory assert ( endpoint.infra_state.deployment_state.min_workers == update_llm_model_endpoint_request_only_workers.min_workers @@ -1053,9 +1028,7 @@ async def test_update_model_endpoint_use_case_failure( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - await create_use_case.execute( - user=user, request=create_llm_model_endpoint_request_streaming - ) + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) endpoint = ( await fake_model_endpoint_service.list_model_endpoints( owner=None, @@ -1184,9 +1157,7 @@ async def test_completion_sync_text_generation_inference_use_case_success( llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_sync_request: CompletionSyncV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint( - llm_model_endpoint_text_generation_inference - ) + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, result={"result": """ @@ -1418,9 +1389,7 @@ async def test_completion_sync_use_case_predict_failed_lightllm( llm_model_endpoint_sync_lightllm: Tuple[ModelEndpoint, Any], completion_sync_request: CompletionSyncV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint( - llm_model_endpoint_sync_lightllm[0] - ) + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_lightllm[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, @@ -1453,9 +1422,7 @@ async def test_completion_sync_use_case_predict_failed_trt_llm( completion_sync_request: CompletionSyncV1Request, ): completion_sync_request.return_token_log_probs = False # not yet supported - fake_llm_model_endpoint_service.add_model_endpoint( - llm_model_endpoint_sync_trt_llm[0] - ) + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_trt_llm[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( SyncEndpointPredictV1Response( status=TaskStatus.FAILURE, @@ -1546,9 +1513,7 @@ async def test_validate_and_update_completion_params(): return_token_log_probs=True, ) - validate_and_update_completion_params( - LLMInferenceFramework.VLLM, completion_sync_request - ) + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) validate_and_update_completion_params( LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request @@ -1566,9 +1531,7 @@ async def test_validate_and_update_completion_params(): completion_sync_request.guided_choice = [""] completion_sync_request.guided_grammar = "" with pytest.raises(ObjectHasInvalidValueException): - validate_and_update_completion_params( - LLMInferenceFramework.VLLM, completion_sync_request - ) + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) completion_sync_request.guided_regex = None completion_sync_request.guided_choice = None @@ -1794,9 +1757,7 @@ async def test_completion_stream_text_generation_inference_use_case_success( llm_model_endpoint_text_generation_inference: ModelEndpoint, completion_stream_request: CompletionStreamV1Request, ): - fake_llm_model_endpoint_service.add_model_endpoint( - llm_model_endpoint_text_generation_inference - ) + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, @@ -1825,9 +1786,7 @@ async def test_completion_stream_text_generation_inference_use_case_success( ), SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": {"token": {"text": "."}, "generated_text": "I am a newbie."} - }, + result={"result": {"token": {"text": "."}, "generated_text": "I am a newbie."}}, traceback=None, ), ] @@ -2087,9 +2046,7 @@ async def test_get_fine_tune_events_success( llm_fine_tuning_service=fake_llm_fine_tuning_service, ) response_2 = await use_case.execute(user=user, fine_tune_id=response.id) - assert len(response_2.events) == len( - fake_llm_fine_tuning_events_repository.all_events_list - ) + assert len(response_2.events) == len(fake_llm_fine_tuning_events_repository.all_events_list) @pytest.mark.asyncio @@ -2157,10 +2114,8 @@ async def test_delete_model_success( response = await use_case.execute( user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name ) - remaining_endpoint_model_service = ( - await fake_model_endpoint_service.get_model_endpoint( - llm_model_endpoint_sync[0].record.id - ) + remaining_endpoint_model_service = await fake_model_endpoint_service.get_model_endpoint( + llm_model_endpoint_sync[0].record.id ) assert remaining_endpoint_model_service is None assert response.deleted is True @@ -2309,9 +2264,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 102400, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "") assert hardware.cpus == 160 assert hardware.gpus == 8 assert hardware.memory == "800Gi" @@ -2439,9 +2392,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32064, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "") assert hardware.cpus == 5 assert hardware.gpus == 1 assert hardware.memory == "20Gi" @@ -2501,9 +2452,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 100352, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "") print(hardware) assert hardware.cpus == 5 assert hardware.gpus == 1 @@ -2553,9 +2502,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "vocab_size": 32064, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "") assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2684,9 +2631,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB assert hardware.nodes_per_worker == 1 - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2723,9 +2668,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB assert hardware.nodes_per_worker == 1 - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) assert hardware.cpus == 10 assert hardware.gpus == 1 assert hardware.memory == "40Gi" @@ -2908,9 +2851,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "transformers_version": "4.41.0.dev0", "vocab_size": 128256, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") assert hardware.cpus == 40 assert hardware.gpus == 2 assert hardware.memory == "160Gi" @@ -2943,9 +2884,7 @@ async def test_infer_hardware(fake_llm_artifact_gateway): "use_sliding_window": False, "vocab_size": 152064, } - hardware = await _infer_hardware( - fake_llm_artifact_gateway, "qwen2-72b-instruct", "" - ) + hardware = await _infer_hardware(fake_llm_artifact_gateway, "qwen2-72b-instruct", "") assert hardware.cpus == 80 assert hardware.gpus == 4 assert hardware.memory == "320Gi" @@ -3027,9 +2966,7 @@ async def test_create_batch_completions_v1( user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) result = await use_case.execute(user, create_batch_completions_v1_request) - job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job( - result.job_id - ) + job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job(result.job_id) assert job.num_workers == create_batch_completions_v1_request.data_parallelism bundle = list(fake_docker_image_batch_job_bundle_repository.db.values())[0] @@ -3185,9 +3122,7 @@ def test_merge_metadata(): def test_validate_chat_template(): assert validate_chat_template(None, LLMInferenceFramework.DEEPSPEED) is None good_chat_template = CHAT_TEMPLATE_MAX_LENGTH * "_" - assert ( - validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None - ) + assert validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None bad_chat_template = (CHAT_TEMPLATE_MAX_LENGTH + 1) * "_" with pytest.raises(ObjectHasInvalidValueException): From eab9b601c6e2079d5a2ef7e7a932c8fe45d971e6 Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 20:22:43 -0400 Subject: [PATCH 06/12] fix(gcp): use gcloud storage rsync instead of cp for --include support --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 4 ++-- model-engine/tests/unit/domain/test_llm_use_cases.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 3b61965f..9682098f 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -746,8 +746,8 @@ def load_model_weights_sub_commands_gcs( file_selection_str += ' --include="*.py"' subcommands.append( - f"/opt/google-cloud-sdk/bin/gcloud storage cp -r" - f" {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + f"/opt/google-cloud-sdk/bin/gcloud storage rsync -r" + f" {file_selection_str} {checkpoint_path} {final_weights_folder}" ) return subcommands diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 326c97b8..fd236005 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -672,7 +672,10 @@ def test_load_model_weights_sub_commands( "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" " | tar -xz -C /opt" " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", - '/opt/google-cloud-sdk/bin/gcloud storage cp -r --include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*" gs://fake-bucket/fake-checkpoint/* test_folder', + "/opt/google-cloud-sdk/bin/gcloud storage rsync -r" + ' --include="*.model" --include="*.model.v*" --include="*.json"' + ' --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' + " gs://fake-bucket/fake-checkpoint test_folder", ] assert expected_result == subcommands @@ -689,7 +692,10 @@ def test_load_model_weights_sub_commands( "curl -sSL https://dl.google.com/dl/cloudsdk/channels/rapid/google-cloud-sdk.tar.gz" " | tar -xz -C /opt" " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", - '/opt/google-cloud-sdk/bin/gcloud storage cp -r --include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*" --include="*.py" gs://fake-bucket/fake-checkpoint/* test_folder', + "/opt/google-cloud-sdk/bin/gcloud storage rsync -r" + ' --include="*.model" --include="*.model.v*" --include="*.json"' + ' --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' + ' --include="*.py" gs://fake-bucket/fake-checkpoint test_folder', ] assert expected_result == subcommands From 0b5bde340b44d043e1bf172344b2c212af4c6415 Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 20:28:24 -0400 Subject: [PATCH 07/12] lint --- .../model_engine_server/core/docker/docker_image.py | 8 ++------ model-engine/tests/unit/api/test_llms.py | 6 ++---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index 8d68f8c8..70f5f0e5 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -144,14 +144,10 @@ def build( ) if test_command: - logger.info( - textwrap.dedent( - f""" + logger.info(textwrap.dedent(f""" Testing with 'docker run' on the built image. ARGS: {test_command} - (NOTE: Expecting the test command to terminate. """ - ) - ) + (NOTE: Expecting the test command to terminate. """)) home_dir = str(pathlib.Path.home()) output = docker_client.containers.run( # pylint:disable=no-member image=image, diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 163466c2..0194e7a2 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -108,13 +108,11 @@ def test_completion_sync_success( fake_docker_image_batch_job_bundle_repository_contents={}, fake_sync_inference_content=SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={ - "result": """{ + result={"result": """{ "text": "output", "count_prompt_tokens": 1, "count_output_tokens": 1 - }""" - }, + }"""}, traceback=None, status_code=200, ), From f3025a6563d55f226ef5483a7d0c75980df4a1d4 Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 20:57:28 -0400 Subject: [PATCH 08/12] fix: black 24.8.0 formatting --- .../model_engine_server/core/docker/docker_image.py | 8 ++++++-- model-engine/tests/unit/api/test_llms.py | 6 ++++-- model-engine/tests/unit/domain/test_llm_use_cases.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/model-engine/model_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py index 70f5f0e5..8d68f8c8 100644 --- a/model-engine/model_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -144,10 +144,14 @@ def build( ) if test_command: - logger.info(textwrap.dedent(f""" + logger.info( + textwrap.dedent( + f""" Testing with 'docker run' on the built image. ARGS: {test_command} - (NOTE: Expecting the test command to terminate. """)) + (NOTE: Expecting the test command to terminate. """ + ) + ) home_dir = str(pathlib.Path.home()) output = docker_client.containers.run( # pylint:disable=no-member image=image, diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py index 0194e7a2..163466c2 100644 --- a/model-engine/tests/unit/api/test_llms.py +++ b/model-engine/tests/unit/api/test_llms.py @@ -108,11 +108,13 @@ def test_completion_sync_success( fake_docker_image_batch_job_bundle_repository_contents={}, fake_sync_inference_content=SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """{ + result={ + "result": """{ "text": "output", "count_prompt_tokens": 1, "count_output_tokens": 1 - }"""}, + }""" + }, traceback=None, status_code=200, ), diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index fd236005..81b6c472 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -1166,7 +1166,8 @@ async def test_completion_sync_text_generation_inference_use_case_success( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """ + result={ + "result": """ { "generated_text": " Deep Learning is a new type of machine learning", "details": { @@ -1234,7 +1235,8 @@ async def test_completion_sync_text_generation_inference_use_case_success( ] } } -"""}, +""" + }, traceback=None, status_code=200, ) @@ -1463,12 +1465,14 @@ async def test_completion_sync_use_case_predict_failed_with_errors( fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( status=TaskStatus.SUCCESS, - result={"result": """ + result={ + "result": """ { "error": "Request failed during generation: Server error: transport error", "error_type": "generation" } -"""}, +""" + }, traceback="failed to predict", status_code=500, ) From 8bc986bb61487de9a4a5259ad6be8b9524469e2e Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 21:01:55 -0400 Subject: [PATCH 09/12] fix: isort formatting --- .../use_cases/llm_model_endpoint_use_cases.py | 19 ++++--------------- model-engine/tests/unit/conftest.py | 18 ++++-------------- .../tests/unit/domain/test_llm_use_cases.py | 13 +++---------- 3 files changed, 11 insertions(+), 39 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 9682098f..e201baee 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -16,9 +16,7 @@ import yaml from model_engine_server.common.config import hmi_config -from model_engine_server.common.dtos.batch_jobs import ( - CreateDockerImageBatchJobResourceRequests, -) +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( ChatCompletionV2Request, ChatCompletionV2StreamSuccessChunk, @@ -57,16 +55,10 @@ CompletionV2SyncResponse, ) from model_engine_server.common.dtos.llms.sglang import SGLangEndpointAdditionalArgs -from model_engine_server.common.dtos.llms.vllm import ( - VLLMEndpointAdditionalArgs, - VLLMModelConfig, -) +from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs, VLLMModelConfig from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from model_engine_server.common.dtos.tasks import ( - SyncEndpointPredictV1Request, - TaskStatus, -) +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus from model_engine_server.common.resource_limits import validate_resource_requests from model_engine_server.core.auth.authentication_repository import User from model_engine_server.core.config import infra_config @@ -120,10 +112,7 @@ ModelBundleRepository, TokenizerRepository, ) -from model_engine_server.domain.services import ( - LLMModelEndpointService, - ModelEndpointService, -) +from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService from model_engine_server.domain.services.llm_batch_completions_service import ( LLMBatchCompletionsService, ) diff --git a/model-engine/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py index 4f6b0301..903d36bc 100644 --- a/model-engine/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -21,13 +21,8 @@ import pytest from model_engine_server.api.dependencies import ExternalInterfaces from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from model_engine_server.common.dtos.batch_jobs import ( - CreateDockerImageBatchJobResourceRequests, -) -from model_engine_server.common.dtos.docker_repository import ( - BuildImageRequest, - BuildImageResponse, -) +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy from model_engine_server.common.dtos.model_endpoints import ( @@ -37,9 +32,7 @@ ModelEndpointOrderBy, StorageSpecificationType, ) -from model_engine_server.common.dtos.resource_manager import ( - CreateOrUpdateResourcesRequest, -) +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, @@ -159,10 +152,7 @@ translate_kwargs_to_model_bundle_orm, translate_model_bundle_orm_to_model_bundle, ) -from model_engine_server.infra.services import ( - LiveBatchJobService, - LiveModelEndpointService, -) +from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService from model_engine_server.infra.services.fake_llm_batch_completions_service import ( FakeLLMBatchCompletionsService, ) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 81b6c472..26dc4ceb 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -3,9 +3,7 @@ from unittest import mock import pytest -from model_engine_server.common.dtos.batch_jobs import ( - CreateDockerImageBatchJobResourceRequests, -) +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests from model_engine_server.common.dtos.llms import ( CompletionOutput, CompletionStreamV1Request, @@ -22,10 +20,7 @@ CreateBatchCompletionsEngineRequest, CreateBatchCompletionsV2Request, ) -from model_engine_server.common.dtos.tasks import ( - SyncEndpointPredictV1Response, - TaskStatus, -) +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus from model_engine_server.core.auth.authentication_repository import User from model_engine_server.domain.entities import ( LLMInferenceFramework, @@ -71,9 +66,7 @@ validate_checkpoint_files, validate_checkpoint_path_uri, ) -from model_engine_server.domain.use_cases.model_bundle_use_cases import ( - CreateModelBundleV2UseCase, -) +from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase from ..conftest import mocked__get_recommended_hardware_config_map from .conftest import CreateLLMModelEndpointV1Request_gen From d9f8d76f78014696e8402d35ccb143f5979e296a Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 21:40:26 -0400 Subject: [PATCH 10/12] fix(gcp): remove unsupported --include flags from gcloud storage rsync gcloud storage rsync only supports --exclude (Python regex), not --include. The --include flags caused rsync to fail on GCP model weight downloads. Co-Authored-By: Claude Opus 4.6 --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index e201baee..7151f233 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -730,9 +730,9 @@ def load_model_weights_sub_commands_gcs( " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null" ) - file_selection_str = '--include="*.model" --include="*.model.v*" --include="*.json" --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' - if trust_remote_code: - file_selection_str += ' --include="*.py"' + # gcloud storage rsync only supports --exclude (Python regex), not --include. + # Exclude optimizer files; all other files in the checkpoint path are synced. + file_selection_str = '--exclude="optimizer.*"' subcommands.append( f"/opt/google-cloud-sdk/bin/gcloud storage rsync -r" From ed690627ec6c78b9e34ecb1fe6dc4b4e517f3c7c Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Thu, 2 Apr 2026 21:50:20 -0400 Subject: [PATCH 11/12] fix(gcp): exclude .py files from gcloud rsync when trust_remote_code is false --- .../domain/use_cases/llm_model_endpoint_use_cases.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py index 7151f233..bf15c10e 100644 --- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -731,8 +731,10 @@ def load_model_weights_sub_commands_gcs( ) # gcloud storage rsync only supports --exclude (Python regex), not --include. - # Exclude optimizer files; all other files in the checkpoint path are synced. - file_selection_str = '--exclude="optimizer.*"' + excludes = ['--exclude="optimizer.*"'] + if not trust_remote_code: + excludes.append('--exclude=".*\\.py$"') + file_selection_str = " ".join(excludes) subcommands.append( f"/opt/google-cloud-sdk/bin/gcloud storage rsync -r" From afcb6dee7d288c85389a45d6f19e0f1b0cecae05 Mon Sep 17 00:00:00 2001 From: Arnav Chopra Date: Sun, 5 Apr 2026 18:47:28 -0400 Subject: [PATCH 12/12] fix(tests): update GCS unit tests to match --exclude-only rsync flags Co-Authored-By: Claude Opus 4.6 --- model-engine/tests/unit/domain/test_llm_use_cases.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py index 26dc4ceb..e254aac1 100644 --- a/model-engine/tests/unit/domain/test_llm_use_cases.py +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -666,8 +666,7 @@ def test_load_model_weights_sub_commands( " | tar -xz -C /opt" " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", "/opt/google-cloud-sdk/bin/gcloud storage rsync -r" - ' --include="*.model" --include="*.model.v*" --include="*.json"' - ' --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' + ' --exclude="optimizer.*" --exclude=".*\\.py$"' " gs://fake-bucket/fake-checkpoint test_folder", ] assert expected_result == subcommands @@ -686,9 +685,8 @@ def test_load_model_weights_sub_commands( " | tar -xz -C /opt" " && /opt/google-cloud-sdk/bin/gcloud config set disable_usage_reporting true 2>/dev/null", "/opt/google-cloud-sdk/bin/gcloud storage rsync -r" - ' --include="*.model" --include="*.model.v*" --include="*.json"' - ' --include="*.safetensors" --include="*.txt" --exclude="optimizer*"' - ' --include="*.py" gs://fake-bucket/fake-checkpoint test_folder', + ' --exclude="optimizer.*"' + " gs://fake-bucket/fake-checkpoint test_folder", ] assert expected_result == subcommands