Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions sagemaker-serve/tests/integ/test_model_customization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@
from __future__ import absolute_import

import boto3
import json
import logging
import os
import time
import pytest
import random

logger = logging.getLogger(__name__)

from sagemaker.core.helper.session_helper import Session

# This test relies on resources in a specific region.
# setdefault() only fills in AWS_DEFAULT_REGION when the variable is not
# already present, so a region configured by the caller's environment wins.
AWS_REGION = "us-west-2"
os.environ.setdefault("AWS_DEFAULT_REGION", AWS_REGION)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -135,6 +142,38 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION)
assert adapter_ic is not None

# Invoke verification
time.sleep(10) # brief buffer for IC readiness

invoke_ic_name = adapter_name if peft_type == "LORA" else f"{endpoint_name}-inference-component"

test_payload = {
"inputs": "What is machine learning?",
"parameters": {"max_new_tokens": 32},
}

invoke_response = endpoint.invoke(
body=json.dumps(test_payload),
content_type="application/json",
accept="application/json",
inference_component_name=invoke_ic_name,
)

response_body = json.loads(invoke_response.body.read())

# Validate response structure
assert response_body is not None, f"Empty response from invoke on {invoke_ic_name}"
if isinstance(response_body, list):
assert len(response_body) > 0
assert "generated_text" in response_body[0] or "generation" in response_body[0]
elif isinstance(response_body, dict):
assert (
"generated_text" in response_body
or "generation" in response_body
or "outputs" in response_body
)


def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session):
"""Test fetching endpoint names for base model."""
from sagemaker.core.resources import TrainingJob
Expand Down Expand Up @@ -316,7 +355,7 @@ def setup_config(self, training_job_name):
from sagemaker.core.helper.session_helper import get_execution_role
return {
"training_job_name": training_job_name,
"region": "us-west-2",
"region": AWS_REGION,
"bucket": "models-sdk-testing-pdx",
"role_arn": get_execution_role()
}
Expand Down Expand Up @@ -381,7 +420,7 @@ def deployed_model_arn(self, training_job, bedrock_client, s3_client, setup_conf
break
time.sleep(30)

model_arn = response['importedModelName']
model_arn = response['importedModelArn']
return model_arn

except Exception as e:
Expand Down Expand Up @@ -504,6 +543,50 @@ def test_bedrock_job_created(self, deployed_model_arn):
"""Test that Bedrock import job was created successfully."""
assert deployed_model_arn is not None

@pytest.mark.slow
def test_bedrock_model_invoke(self, deployed_model_arn, bedrock_runtime):
    """Invoke the imported Bedrock model and validate the response end-to-end.

    The call is retried with exponential backoff because an imported model
    can take several minutes to become ready after the import finishes.
    With ``max_retries = 5`` and ``base_delay = 10`` the waits between
    attempts are 10s, 20s, 40s and 80s (about 2.5 minutes in total).

    Args:
        deployed_model_arn: Identifier passed as ``modelId`` to the invoke
            call (provided by the ``deployed_model_arn`` fixture).
        bedrock_runtime: Object exposing ``invoke_model(modelId=..., body=...)``
            -- presumably a boto3 ``bedrock-runtime`` client; confirm against
            the fixture definition.

    The test passes as soon as one attempt returns a JSON body whose
    ``generation`` field is a non-empty string. Any exception raised by an
    attempt (including a failed response assertion) triggers a retry; after
    the last attempt the test is failed via ``pytest.fail`` with the last
    error message.
    """
    max_retries = 5
    base_delay = 10  # seconds; doubled after every failed attempt

    # The request payload never changes between attempts, so build it once.
    request_body = json.dumps({
        "prompt": "What is the capital of France?",
        "max_gen_len": 100,
        "temperature": 0.7,
        "top_p": 0.9,
    })

    for attempt in range(max_retries):
        try:
            response = bedrock_runtime.invoke_model(
                modelId=deployed_model_arn,
                body=request_body,
            )

            result = json.loads(response['body'].read().decode())

            # Validate response structure
            assert "generation" in result, "Response missing 'generation' field"
            assert isinstance(result["generation"], str), "'generation' should be a string"
            assert len(result["generation"]) > 0, "'generation' should not be empty"
            return  # Success

        except Exception as e:
            if attempt == max_retries - 1:
                # Out of attempts: surface the last error as the test failure.
                pytest.fail(
                    f"Invoke failed after {max_retries} attempts. "
                    f"Last error: {e}"
                )

            # Exponential backoff: 10s, 20s, 40s, 80s.
            delay = base_delay * 2 ** attempt
            logger.info(
                "Invoke failed (attempt %d/%d): %s. Retrying in %ds...",
                attempt + 1,
                max_retries,
                e,
                delay,
            )
            time.sleep(delay)


def test_zzz_cleanup_deployed_model(self, bedrock_client):
"""Cleanup deployed model and import jobs (runs last due to zzz prefix)."""
if hasattr(self, 'model_arn_for_cleanup'):
Expand Down
Loading