diff --git a/sagemaker-serve/tests/integ/test_model_customization_deployment.py b/sagemaker-serve/tests/integ/test_model_customization_deployment.py index b38ca249c7..21eef47558 100644 --- a/sagemaker-serve/tests/integ/test_model_customization_deployment.py +++ b/sagemaker-serve/tests/integ/test_model_customization_deployment.py @@ -14,13 +14,20 @@ from __future__ import absolute_import import boto3 +import json +import logging +import os +import time import pytest import random +logger = logging.getLogger(__name__) + from sagemaker.core.helper.session_helper import Session # This test relies on resources in a specific region AWS_REGION = "us-west-2" +os.environ.setdefault("AWS_DEFAULT_REGION", AWS_REGION) @pytest.fixture(scope="module") @@ -135,6 +142,38 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION) assert adapter_ic is not None + # Invoke verification + time.sleep(10) # brief buffer for IC readiness + + invoke_ic_name = adapter_name if peft_type == "LORA" else f"{endpoint_name}-inference-component" + + test_payload = { + "inputs": "What is machine learning?", + "parameters": {"max_new_tokens": 32}, + } + + invoke_response = endpoint.invoke( + body=json.dumps(test_payload), + content_type="application/json", + accept="application/json", + inference_component_name=invoke_ic_name, + ) + + response_body = json.loads(invoke_response.body.read()) + + # Validate response structure + assert response_body is not None, f"Empty response from invoke on {invoke_ic_name}" + if isinstance(response_body, list): + assert len(response_body) > 0 + assert "generated_text" in response_body[0] or "generation" in response_body[0] + elif isinstance(response_body, dict): + assert ( + "generated_text" in response_body + or "generation" in response_body + or "outputs" in response_body + ) + + def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session): """Test fetching endpoint names for base model.""" from sagemaker.core.resources import TrainingJob @@ -316,7 +355,7 @@ def setup_config(self, training_job_name): from sagemaker.core.helper.session_helper import get_execution_role return { "training_job_name": training_job_name, - "region": "us-west-2", + "region": AWS_REGION, "bucket": "models-sdk-testing-pdx", "role_arn": get_execution_role() } @@ -381,7 +420,7 @@ def deployed_model_arn(self, training_job, bedrock_client, s3_client, setup_conf break time.sleep(30) - model_arn = response['importedModelName'] + model_arn = response['importedModelArn'] return model_arn except Exception as e: @@ -504,6 +543,50 @@ def test_bedrock_job_created(self, deployed_model_arn): """Test that Bedrock import job was created successfully.""" assert deployed_model_arn is not None + @pytest.mark.slow + def test_bedrock_model_invoke(self, deployed_model_arn, bedrock_runtime): + """Test invoking the imported Bedrock model to ensure it works end-to-end. + + Retries on failure since models can take several minutes + to become ready after import. + """ + max_retries = 5 + base_delay = 10 + + for attempt in range(max_retries): + try: + response = bedrock_runtime.invoke_model( + modelId=deployed_model_arn, + body=json.dumps({ + "prompt": "What is the capital of France?", + "max_gen_len": 100, + "temperature": 0.7, + "top_p": 0.9 + }) + ) + + result = json.loads(response['body'].read().decode()) + + # Validate response structure + assert "generation" in result, "Response missing 'generation' field" + assert isinstance(result["generation"], str), "'generation' should be a string" + assert len(result["generation"]) > 0, "'generation' should not be empty" + return # Success + + except Exception as e: + if attempt < max_retries - 1: + logger.info( + f"Invoke failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {base_delay}s..." + ) + time.sleep(base_delay) + else: + pytest.fail( + f"Invoke failed after {max_retries} attempts. " + f"Last error: {e}" + ) + + def test_zzz_cleanup_deployed_model(self, bedrock_client): """Cleanup deployed model and import jobs (runs last due to zzz prefix).""" if hasattr(self, 'model_arn_for_cleanup'):