-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix(eval): resolve mlflow_resource_arn in _get_base_template_context #5758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| from sagemaker.core.helper.session_helper import Session | ||
|
|
||
| from sagemaker.train.base_trainer import BaseTrainer | ||
| from sagemaker.train.common_utils.finetune_utils import _resolve_mlflow_resource_arn | ||
| # Module-level logger | ||
| _logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -163,8 +164,6 @@ def _resolve_dataset(cls, v): | |
| @validator('mlflow_resource_arn', pre=True, always=True) | ||
| def _resolve_mlflow_arn(cls, v, values): | ||
| """Resolve MLflow resource ARN using default experience logic if not provided.""" | ||
| from ..common_utils.finetune_utils import _resolve_mlflow_resource_arn | ||
|
|
||
| # Get sagemaker_session from values | ||
| sagemaker_session = values.get('sagemaker_session') | ||
| if sagemaker_session is None: | ||
|
|
@@ -709,6 +708,10 @@ def _get_base_template_context( | |
| Returns: | ||
| dict: Base template context dictionary | ||
| """ | ||
| # Resolve MLflow ARN if not already resolved (e.g. session was None at construction time) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor consideration: if not self.mlflow_resource_arn and self.sagemaker_session:
try:
self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session)
except Exception:
_logger.warning("Failed to resolve MLflow resource ARN during deferred resolution.") |
||
| if not self.mlflow_resource_arn and self.sagemaker_session: | ||
| self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session) | ||
|
|
||
| # Generate default mlflow_experiment_name if not provided | ||
| # This is required by AWS when ModelPackageGroupArn is not provided in training jobs | ||
| mlflow_experiment_name = self.mlflow_experiment_name | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -925,6 +925,37 @@ def test_get_base_template_context(self, mock_resolve, mock_session, mock_model_ | |
| assert context['dataset_artifact_arn'] == DEFAULT_ARTIFACT_ARN | ||
| assert 'action_arn_prefix' in context | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line is well over 100 characters (the project's line length limit). Please wrap it: @patch(
"sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn"
)
def test_get_base_template_context_deferred_mlflow_resolution(
self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info
): |
||
| @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") | ||
| @patch("sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn") | ||
| def test_get_base_template_context_deferred_mlflow_resolution(self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info): | ||
| """Test that mlflow_resource_arn is resolved in _get_base_template_context when session was None at construction.""" | ||
| mock_resolve.return_value = mock_model_info | ||
| # Validator returns None because session was None at construction time | ||
| mock_resolve_mlflow.return_value = None | ||
|
|
||
| evaluator = BaseEvaluator( | ||
| model=DEFAULT_MODEL, | ||
| s3_output_path=DEFAULT_S3_OUTPUT, | ||
| model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, | ||
| sagemaker_session=mock_session, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Setting |
||
| ) | ||
| # Simulate the case where ARN was not resolved at construction (session was None) | ||
| evaluator.mlflow_resource_arn = None | ||
|
|
||
| resolved_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/deferred" | ||
| mock_resolve_mlflow.return_value = resolved_arn | ||
|
|
||
| context = evaluator._get_base_template_context( | ||
| role_arn=DEFAULT_ROLE_ARN, | ||
| region=DEFAULT_REGION, | ||
| account_id="123456789012", | ||
| model_package_group_arn=DEFAULT_MODEL_PACKAGE_GROUP_ARN, | ||
| resolved_model_artifact_arn=DEFAULT_ARTIFACT_ARN, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using |
||
| ) | ||
|
|
||
| assert context['mlflow_resource_arn'] == resolved_arn | ||
| mock_resolve_mlflow.assert_called_with(mock_session) | ||
|
|
||
|
|
||
| class TestResolveModelArtifacts: | ||
| """Tests for model artifacts resolution.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing blank line between this import and the
# Module-level loggercomment. PEP 8 and the existing file style expect a blank line separating import groups from module-level code.