diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index d53981ca43..4bf718b050 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -21,6 +21,7 @@ from sagemaker.core.helper.session_helper import Session from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.train.common_utils.finetune_utils import _resolve_mlflow_resource_arn # Module-level logger _logger = logging.getLogger(__name__) @@ -163,8 +164,6 @@ def _resolve_dataset(cls, v): @validator('mlflow_resource_arn', pre=True, always=True) def _resolve_mlflow_arn(cls, v, values): """Resolve MLflow resource ARN using default experience logic if not provided.""" - from ..common_utils.finetune_utils import _resolve_mlflow_resource_arn - # Get sagemaker_session from values sagemaker_session = values.get('sagemaker_session') if sagemaker_session is None: @@ -709,6 +708,10 @@ def _get_base_template_context( Returns: dict: Base template context dictionary """ + # Resolve MLflow ARN if not already resolved (e.g. session was None at construction time) + if not self.mlflow_resource_arn and self.sagemaker_session: + self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session) + # Generate default mlflow_experiment_name if not provided # This is required by AWS when ModelPackageGroupArn is not provided in training jobs mlflow_experiment_name = self.mlflow_experiment_name diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index 55ded87bd7..5676add83a 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -925,6 +925,37 @@ def test_get_base_template_context(self, mock_resolve, mock_session, mock_model_ assert context['dataset_artifact_arn'] == DEFAULT_ARTIFACT_ARN assert 'action_arn_prefix' in context + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") + @patch("sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn") + def test_get_base_template_context_deferred_mlflow_resolution(self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info): + """Test that mlflow_resource_arn is resolved in _get_base_template_context when session was None at construction.""" + mock_resolve.return_value = mock_model_info + # Validator returns None because session was None at construction time + mock_resolve_mlflow.return_value = None + + evaluator = BaseEvaluator( + model=DEFAULT_MODEL, + s3_output_path=DEFAULT_S3_OUTPUT, + model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + sagemaker_session=mock_session, + ) + # Simulate the case where ARN was not resolved at construction (session was None) + evaluator.mlflow_resource_arn = None + + resolved_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/deferred" + mock_resolve_mlflow.return_value = resolved_arn + + context = evaluator._get_base_template_context( + role_arn=DEFAULT_ROLE_ARN, + region=DEFAULT_REGION, + account_id="123456789012", + model_package_group_arn=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + resolved_model_artifact_arn=DEFAULT_ARTIFACT_ARN, + ) + + assert context['mlflow_resource_arn'] == resolved_arn + mock_resolve_mlflow.assert_called_with(mock_session) + class TestResolveModelArtifacts: """Tests for model artifacts resolution."""