Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sagemaker.core.helper.session_helper import Session

from sagemaker.train.base_trainer import BaseTrainer
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing blank line between this import and the `# Module-level logger` comment. PEP 8 and the existing file style expect a blank line separating the import block from module-level code.

from sagemaker.train.common_utils.finetune_utils import _resolve_mlflow_resource_arn

# Module-level logger

from sagemaker.train.common_utils.finetune_utils import _resolve_mlflow_resource_arn
# Module-level logger
_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -163,8 +164,6 @@ def _resolve_dataset(cls, v):
@validator('mlflow_resource_arn', pre=True, always=True)
def _resolve_mlflow_arn(cls, v, values):
"""Resolve MLflow resource ARN using default experience logic if not provided."""
from ..common_utils.finetune_utils import _resolve_mlflow_resource_arn

# Get sagemaker_session from values
sagemaker_session = values.get('sagemaker_session')
if sagemaker_session is None:
Expand Down Expand Up @@ -709,6 +708,10 @@ def _get_base_template_context(
Returns:
dict: Base template context dictionary
"""
# Resolve MLflow ARN if not already resolved (e.g. session was None at construction time)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor consideration: _resolve_mlflow_resource_arn may raise exceptions (e.g., if the SageMaker API call fails). In the validator, failures are silently swallowed because the validator just returns None. Here, an exception would propagate up and fail evaluate(). Is that the desired behavior? If not, consider wrapping this in a try/except with a warning log, consistent with how the validator handles failures:

if not self.mlflow_resource_arn and self.sagemaker_session:
    try:
        self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session)
    except Exception:
        _logger.warning("Failed to resolve MLflow resource ARN during deferred resolution.")

if not self.mlflow_resource_arn and self.sagemaker_session:
self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session)

# Generate default mlflow_experiment_name if not provided
# This is required by AWS when ModelPackageGroupArn is not provided in training jobs
mlflow_experiment_name = self.mlflow_experiment_name
Expand Down
31 changes: 31 additions & 0 deletions sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,37 @@ def test_get_base_template_context(self, mock_resolve, mock_session, mock_model_
assert context['dataset_artifact_arn'] == DEFAULT_ARTIFACT_ARN
assert 'action_arn_prefix' in context

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is well over 100 characters (the project's line length limit). Please wrap it:

    @patch(
        "sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn"
    )
    def test_get_base_template_context_deferred_mlflow_resolution(
        self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info
    ):

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
@patch("sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn")
def test_get_base_template_context_deferred_mlflow_resolution(self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info):
"""Test that mlflow_resource_arn is resolved in _get_base_template_context when session was None at construction."""
mock_resolve.return_value = mock_model_info
# Validator returns None because session was None at construction time
mock_resolve_mlflow.return_value = None

evaluator = BaseEvaluator(
model=DEFAULT_MODEL,
s3_output_path=DEFAULT_S3_OUTPUT,
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
sagemaker_session=mock_session,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting `evaluator.mlflow_resource_arn = None` directly after construction to simulate the deferred case works, but it's a bit fragile — it relies on Pydantic allowing direct attribute mutation. A comment explaining why this is necessary (or configuring the mock to return `None` only during construction) would improve clarity. That said, `mock_resolve_mlflow.return_value = None` is already set before construction, so the validator should already leave the field as `None`. Is this explicit assignment redundant? If so, removing it would simplify the test.

)
# Simulate the case where ARN was not resolved at construction (session was None)
evaluator.mlflow_resource_arn = None

resolved_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/deferred"
mock_resolve_mlflow.return_value = resolved_arn

context = evaluator._get_base_template_context(
role_arn=DEFAULT_ROLE_ARN,
region=DEFAULT_REGION,
account_id="123456789012",
model_package_group_arn=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
resolved_model_artifact_arn=DEFAULT_ARTIFACT_ARN,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using `mock_resolve_mlflow.assert_called_once_with(mock_session)` instead of `assert_called_with` to verify it was called exactly once during `_get_base_template_context`. Since the mock is also invoked by the validator during construction, you may want to reset the mock after construction (`mock_resolve_mlflow.reset_mock()`) before calling `_get_base_template_context`, and then check it with `assert_called_once_with`. This makes the test more precise about which call path triggered the resolution.

)

assert context['mlflow_resource_arn'] == resolved_arn
mock_resolve_mlflow.assert_called_with(mock_session)


class TestResolveModelArtifacts:
"""Tests for model artifacts resolution."""
Expand Down
Loading