From 6c6b0a6340fc07af319a34474954a27fb61b2478 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:18:36 -0400 Subject: [PATCH 1/4] fix: Enhancements Needed for Secure Tar Extraction (5560) --- .../src/sagemaker/core/common_utils.py | 13 +- .../src/sagemaker/core/utils/__init__.py | 4 + .../unit/test_common_utils_tar_safety.py | 243 ++++++++++++++++++ 3 files changed, 254 insertions(+), 6 deletions(-) create mode 100644 sagemaker-core/tests/unit/test_common_utils_tar_safety.py diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index b8d9ca6866..2cfd58044c 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -1688,7 +1688,8 @@ def _is_bad_path(path, base): bool: True if the path is not rooted under the base directory, False otherwise. """ # joinpath will ignore base if path is absolute - return not _get_resolved_path(joinpath(base, path)).startswith(base) + resolved = _get_resolved_path(joinpath(base, path)) + return os.path.commonpath([resolved, base]) != base def _is_bad_link(info, base): @@ -1708,19 +1709,18 @@ def _is_bad_link(info, base): return _is_bad_path(info.linkname, base=tip) -def _get_safe_members(members): +def _get_safe_members(members, base): """A generator that yields members that are safe to extract. It filters out bad paths and bad links. Args: - members (list): A list of members to check. + members (list): A list of TarInfo members to check. + base (str): The base directory for extraction. Yields: tarfile.TarInfo: The tar file info. """ - base = _get_resolved_path("") - for file_info in members: if _is_bad_path(file_info.name, base): logger.error("%s is blocked (illegal path)", file_info.name) @@ -1783,7 +1783,8 @@ def custom_extractall_tarfile(tar, extract_path): if hasattr(tarfile, "data_filter"): tar.extractall(path=extract_path, filter="data") else: - tar.extractall(path=extract_path, members=_get_safe_members(tar)) + base = _get_resolved_path(extract_path) + tar.extractall(path=extract_path, members=_get_safe_members(tar.getmembers(), base)) # Re-validate extracted paths to catch symlink race conditions _validate_extracted_paths(extract_path) diff --git a/sagemaker-core/src/sagemaker/core/utils/__init__.py b/sagemaker-core/src/sagemaker/core/utils/__init__.py index 9947387537..be483efcd6 100644 --- a/sagemaker-core/src/sagemaker/core/utils/__init__.py +++ b/sagemaker-core/src/sagemaker/core/utils/__init__.py @@ -36,6 +36,10 @@ "sagemaker_timestamp", "sagemaker_short_timestamp", "get_config_value", + "_get_resolved_path", + "_is_bad_path", + "_is_bad_link", + "_get_safe_members", ] diff --git a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py new file mode 100644 index 0000000000..738f3c40a4 --- /dev/null +++ b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py @@ -0,0 +1,243 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Unit tests for tar extraction safety functions in common_utils.""" +from __future__ import absolute_import + +import os +import pytest +import tarfile +import tempfile +from unittest.mock import Mock, patch, MagicMock + +from sagemaker.core.common_utils import ( + _get_resolved_path, + _is_bad_path, + _is_bad_link, + _get_safe_members, + custom_extractall_tarfile, +) + + +def test_get_resolved_path_returns_normalized_absolute_path(): + """Test _get_resolved_path returns normalized absolute path.""" + path = "./test/path" + result = _get_resolved_path(path) + assert os.path.isabs(result) + assert result == os.path.normpath(os.path.realpath(os.path.abspath(path))) + + +def test_get_resolved_path_with_absolute_path(): + """Test _get_resolved_path with absolute path.""" + path = "/absolute/test/path" + result = _get_resolved_path(path) + assert result == os.path.normpath(os.path.realpath(os.path.abspath(path))) + + +def test_is_bad_path_returns_false_for_safe_relative_path(): + """Test _is_bad_path returns False for safe relative paths.""" + base = _get_resolved_path("/tmp/extract") + safe_path = "safe/path/file.txt" + assert _is_bad_path(safe_path, base) is False + + +def test_is_bad_path_returns_true_for_absolute_escape_path(): + """Test _is_bad_path returns True for absolute paths that escape base.""" + base = _get_resolved_path("/tmp/safe") + unsafe_path = "/etc/passwd" + assert _is_bad_path(unsafe_path, base) is True + + +def test_is_bad_path_returns_true_for_parent_traversal(): + """Test _is_bad_path detects parent directory traversal.""" + base = _get_resolved_path("/tmp/safe/extract") + traversal_path = "../../../etc/passwd" + assert _is_bad_path(traversal_path, base) is True + + +def test_is_bad_path_with_similar_prefix_does_not_false_positive(): + """Test that /tmp/x2 is correctly identified as bad when base is /tmp/x. + + This verifies the commonpath fix: startswith would incorrectly allow + /tmp/x2 when base is /tmp/x, but commonpath correctly rejects it. + """ + base = _get_resolved_path("/tmp/x") + # A path like /tmp/x2/file should NOT be under /tmp/x + # With startswith, "/tmp/x2".startswith("/tmp/x") would be True (bug) + # With commonpath, commonpath(["/tmp/x2", "/tmp/x"]) == "/tmp" != "/tmp/x" (correct) + escape_path = "/tmp/x2/file" + result = _is_bad_path(escape_path, base) + assert result is True + + +def test_is_bad_link_returns_false_for_safe_symlink(): + """Test _is_bad_link returns False for safe links.""" + base = _get_resolved_path("/tmp/extract") + + mock_info = Mock() + mock_info.name = "safe/link" + mock_info.linkname = "safe/target" + + assert _is_bad_link(mock_info, base) is False + + +def test_is_bad_link_returns_true_for_escape_symlink(): + """Test _is_bad_link returns True for links that escape base.""" + base = _get_resolved_path("/tmp/safe") + + mock_info = Mock() + mock_info.name = "link" + mock_info.linkname = "/etc/passwd" + + result = _is_bad_link(mock_info, base) + assert result is True + + +def test_get_safe_members_yields_all_safe_members(): + """Test _get_safe_members yields all safe members.""" + base = _get_resolved_path("/tmp/extract") + + mock_member1 = Mock() + mock_member1.name = "safe/file1.txt" + mock_member1.issym = Mock(return_value=False) + mock_member1.islnk = Mock(return_value=False) + + mock_member2 = Mock() + mock_member2.name = "safe/file2.txt" + mock_member2.issym = Mock(return_value=False) + mock_member2.islnk = Mock(return_value=False) + + members = [mock_member1, mock_member2] + safe_members = list(_get_safe_members(members, base)) + + assert len(safe_members) == 2 + assert mock_member1 in safe_members + assert mock_member2 in safe_members + + +def test_get_safe_members_filters_bad_path_member(): + """Test _get_safe_members filters out members with bad paths.""" + base = _get_resolved_path("/tmp/extract") + + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) + + mock_member_bad = Mock() + mock_member_bad.name = "/etc/passwd" + mock_member_bad.issym = Mock(return_value=False) + mock_member_bad.islnk = Mock(return_value=False) + + members = [mock_member_safe, mock_member_bad] + safe_members = list(_get_safe_members(members, base)) + + assert len(safe_members) == 1 + assert mock_member_safe in safe_members + + +def test_get_safe_members_filters_bad_symlink_member(): + """Test _get_safe_members filters out bad symlinks.""" + base = _get_resolved_path("/tmp/extract") + + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) + + mock_member_symlink = Mock() + mock_member_symlink.name = "bad/symlink" + mock_member_symlink.issym = Mock(return_value=True) + mock_member_symlink.islnk = Mock(return_value=False) + mock_member_symlink.linkname = "/etc/passwd" + + members = [mock_member_safe, mock_member_symlink] + safe_members = list(_get_safe_members(members, base)) + + assert len(safe_members) == 1 + assert mock_member_safe in safe_members + + +def test_get_safe_members_filters_bad_hardlink_member(): + """Test _get_safe_members filters out bad hardlinks.""" + base = _get_resolved_path("/tmp/extract") + + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) + + mock_member_hardlink = Mock() + mock_member_hardlink.name = "bad/hardlink" + mock_member_hardlink.issym = Mock(return_value=False) + mock_member_hardlink.islnk = Mock(return_value=True) + mock_member_hardlink.linkname = "/etc/passwd" + + members = [mock_member_safe, mock_member_hardlink] + safe_members = list(_get_safe_members(members, base)) + + assert len(safe_members) == 1 + assert mock_member_safe in safe_members + + +def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): + """Test custom_extractall_tarfile uses data_filter when available.""" + mock_tar = Mock() + mock_tar.extractall = Mock() + extract_path = "/tmp/extract" + + with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile: + mock_tarfile.data_filter = "data" + + custom_extractall_tarfile(mock_tar, extract_path) + + mock_tar.extractall.assert_called_once_with(path=extract_path, filter="data") + + +def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): + """Test custom_extractall_tarfile uses safe members with getmembers() and resolved extract_path.""" + mock_member = Mock() + mock_member.name = "safe/file.txt" + mock_member.issym = Mock(return_value=False) + mock_member.islnk = Mock(return_value=False) + + mock_tar = Mock() + mock_tar.extractall = Mock() + mock_tar.getmembers = Mock(return_value=[mock_member]) + extract_path = "/tmp/extract" + + with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile: + # Remove data_filter attribute to simulate Python < 3.12 + if hasattr(mock_tarfile, 'data_filter'): + delattr(mock_tarfile, 'data_filter') + + with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe: + mock_safe.return_value = [mock_member] + + with patch('sagemaker.core.common_utils._validate_extracted_paths'): + custom_extractall_tarfile(mock_tar, extract_path) + + # Verify getmembers() was called (not iterating over tar directly) + mock_tar.getmembers.assert_called_once() + + # Verify _get_safe_members was called with the members list and resolved base + mock_safe.assert_called_once() + call_args = mock_safe.call_args + assert call_args[0][0] == [mock_member] # members list + # base should be resolved extract_path, not cwd + expected_base = _get_resolved_path(extract_path) + assert call_args[0][1] == expected_base + + mock_tar.extractall.assert_called_once() + call_kwargs = mock_tar.extractall.call_args[1] + assert call_kwargs['path'] == extract_path + assert 'members' in call_kwargs From b8c6e492230a980ef0fc60069349e9c362e29da7 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:32:26 -0400 Subject: [PATCH 2/4] fix: address review comments (iteration #1) --- .../src/sagemaker/core/common_utils.py | 5 +- .../src/sagemaker/core/utils/__init__.py | 9 +- .../unit/test_common_utils_tar_safety.py | 292 ++++++++++-------- 3 files changed, 178 insertions(+), 128 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 2cfd58044c..7e99314173 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -1689,7 +1689,10 @@ def _is_bad_path(path, base): """ # joinpath will ignore base if path is absolute resolved = _get_resolved_path(joinpath(base, path)) - return os.path.commonpath([resolved, base]) != base + try: + return os.path.commonpath([resolved, base]) != base + except ValueError: + return True # If we can't determine safety, treat as bad path def _is_bad_link(info, base): diff --git a/sagemaker-core/src/sagemaker/core/utils/__init__.py b/sagemaker-core/src/sagemaker/core/utils/__init__.py index be483efcd6..1f8cf1d374 100644 --- a/sagemaker-core/src/sagemaker/core/utils/__init__.py +++ b/sagemaker-core/src/sagemaker/core/utils/__init__.py @@ -19,6 +19,9 @@ """ from __future__ import absolute_import +# Public API surface: only non-private functions are exported via __all__. +# Private helpers (_get_resolved_path, _is_bad_path, _is_bad_link, _get_safe_members) +# are still importable directly but are not part of the public API. __all__ = [ "_save_model", "download_file_from_url", @@ -36,6 +39,10 @@ "sagemaker_timestamp", "sagemaker_short_timestamp", "get_config_value", +] + +# Internal helpers that are importable but not part of the public API +_INTERNAL_NAMES = [ "_get_resolved_path", "_is_bad_path", "_is_bad_link", @@ -45,7 +52,7 @@ def __getattr__(name): """Lazy import to avoid circular dependencies.""" - if name in __all__: + if name in __all__ or name in _INTERNAL_NAMES: from sagemaker.core import common_utils return getattr(common_utils, name) diff --git a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py index 738f3c40a4..db4364a790 100644 --- a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py +++ b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py @@ -11,12 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit tests for tar extraction safety functions in common_utils.""" -from __future__ import absolute_import +from __future__ import annotations import os -import pytest -import tarfile import tempfile +import tarfile + +import pytest from unittest.mock import Mock, patch, MagicMock from sagemaker.core.common_utils import ( @@ -38,173 +39,198 @@ def test_get_resolved_path_returns_normalized_absolute_path(): def test_get_resolved_path_with_absolute_path(): """Test _get_resolved_path with absolute path.""" - path = "/absolute/test/path" - result = _get_resolved_path(path) - assert result == os.path.normpath(os.path.realpath(os.path.abspath(path))) + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "absolute", "test", "path") + result = _get_resolved_path(path) + assert result == os.path.normpath(os.path.realpath(os.path.abspath(path))) def test_is_bad_path_returns_false_for_safe_relative_path(): """Test _is_bad_path returns False for safe relative paths.""" - base = _get_resolved_path("/tmp/extract") - safe_path = "safe/path/file.txt" - assert _is_bad_path(safe_path, base) is False + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) + safe_path = "safe/path/file.txt" + assert _is_bad_path(safe_path, base) is False def test_is_bad_path_returns_true_for_absolute_escape_path(): """Test _is_bad_path returns True for absolute paths that escape base.""" - base = _get_resolved_path("/tmp/safe") - unsafe_path = "/etc/passwd" - assert _is_bad_path(unsafe_path, base) is True + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "safe")) + unsafe_path = "/etc/passwd" + assert _is_bad_path(unsafe_path, base) is True def test_is_bad_path_returns_true_for_parent_traversal(): """Test _is_bad_path detects parent directory traversal.""" - base = _get_resolved_path("/tmp/safe/extract") - traversal_path = "../../../etc/passwd" - assert _is_bad_path(traversal_path, base) is True + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "safe", "extract")) + traversal_path = "../../../etc/passwd" + assert _is_bad_path(traversal_path, base) is True def test_is_bad_path_with_similar_prefix_does_not_false_positive(): - """Test that /tmp/x2 is correctly identified as bad when base is /tmp/x. - + """Test that base/x2 is correctly identified as bad when base is base/x. + This verifies the commonpath fix: startswith would incorrectly allow - /tmp/x2 when base is /tmp/x, but commonpath correctly rejects it. + base/x2 when base is base/x, but commonpath correctly rejects it. """ - base = _get_resolved_path("/tmp/x") - # A path like /tmp/x2/file should NOT be under /tmp/x - # With startswith, "/tmp/x2".startswith("/tmp/x") would be True (bug) - # With commonpath, commonpath(["/tmp/x2", "/tmp/x"]) == "/tmp" != "/tmp/x" (correct) - escape_path = "/tmp/x2/file" - result = _is_bad_path(escape_path, base) - assert result is True + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "x")) + # A path like tmpdir/x2/file should NOT be under tmpdir/x + # With startswith, "tmpdir/x2".startswith("tmpdir/x") would be True (bug) + # With commonpath, commonpath(["tmpdir/x2", "tmpdir/x"]) == "tmpdir" != "tmpdir/x" (correct) + escape_path = os.path.join(tmpdir, "x2", "file") + result = _is_bad_path(escape_path, base) + assert result is True def test_is_bad_link_returns_false_for_safe_symlink(): """Test _is_bad_link returns False for safe links.""" - base = _get_resolved_path("/tmp/extract") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) - mock_info = Mock() - mock_info.name = "safe/link" - mock_info.linkname = "safe/target" + mock_info = Mock() + mock_info.name = "safe/link" + mock_info.linkname = "safe/target" - assert _is_bad_link(mock_info, base) is False + assert _is_bad_link(mock_info, base) is False def test_is_bad_link_returns_true_for_escape_symlink(): """Test _is_bad_link returns True for links that escape base.""" - base = _get_resolved_path("/tmp/safe") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "safe")) - mock_info = Mock() - mock_info.name = "link" - mock_info.linkname = "/etc/passwd" + mock_info = Mock() + mock_info.name = "link" + mock_info.linkname = "/etc/passwd" - result = _is_bad_link(mock_info, base) - assert result is True + result = _is_bad_link(mock_info, base) + assert result is True def test_get_safe_members_yields_all_safe_members(): """Test _get_safe_members yields all safe members.""" - base = _get_resolved_path("/tmp/extract") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) - mock_member1 = Mock() - mock_member1.name = "safe/file1.txt" - mock_member1.issym = Mock(return_value=False) - mock_member1.islnk = Mock(return_value=False) + mock_member1 = Mock() + mock_member1.name = "safe/file1.txt" + mock_member1.issym = Mock(return_value=False) + mock_member1.islnk = Mock(return_value=False) - mock_member2 = Mock() - mock_member2.name = "safe/file2.txt" - mock_member2.issym = Mock(return_value=False) - mock_member2.islnk = Mock(return_value=False) + mock_member2 = Mock() + mock_member2.name = "safe/file2.txt" + mock_member2.issym = Mock(return_value=False) + mock_member2.islnk = Mock(return_value=False) - members = [mock_member1, mock_member2] - safe_members = list(_get_safe_members(members, base)) + members = [mock_member1, mock_member2] + safe_members = list(_get_safe_members(members, base)) - assert len(safe_members) == 2 - assert mock_member1 in safe_members - assert mock_member2 in safe_members + assert len(safe_members) == 2 + assert mock_member1 in safe_members + assert mock_member2 in safe_members def test_get_safe_members_filters_bad_path_member(): """Test _get_safe_members filters out members with bad paths.""" - base = _get_resolved_path("/tmp/extract") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) - mock_member_safe = Mock() - mock_member_safe.name = "safe/file.txt" - mock_member_safe.issym = Mock(return_value=False) - mock_member_safe.islnk = Mock(return_value=False) + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) - mock_member_bad = Mock() - mock_member_bad.name = "/etc/passwd" - mock_member_bad.issym = Mock(return_value=False) - mock_member_bad.islnk = Mock(return_value=False) + mock_member_bad = Mock() + mock_member_bad.name = "/etc/passwd" + mock_member_bad.issym = Mock(return_value=False) + mock_member_bad.islnk = Mock(return_value=False) - members = [mock_member_safe, mock_member_bad] - safe_members = list(_get_safe_members(members, base)) + members = [mock_member_safe, mock_member_bad] + safe_members = list(_get_safe_members(members, base)) - assert len(safe_members) == 1 - assert mock_member_safe in safe_members + assert len(safe_members) == 1 + assert mock_member_safe in safe_members def test_get_safe_members_filters_bad_symlink_member(): """Test _get_safe_members filters out bad symlinks.""" - base = _get_resolved_path("/tmp/extract") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) - mock_member_safe = Mock() - mock_member_safe.name = "safe/file.txt" - mock_member_safe.issym = Mock(return_value=False) - mock_member_safe.islnk = Mock(return_value=False) + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) - mock_member_symlink = Mock() - mock_member_symlink.name = "bad/symlink" - mock_member_symlink.issym = Mock(return_value=True) - mock_member_symlink.islnk = Mock(return_value=False) - mock_member_symlink.linkname = "/etc/passwd" + mock_member_symlink = Mock() + mock_member_symlink.name = "bad/symlink" + mock_member_symlink.issym = Mock(return_value=True) + mock_member_symlink.islnk = Mock(return_value=False) + mock_member_symlink.linkname = "/etc/passwd" - members = [mock_member_safe, mock_member_symlink] - safe_members = list(_get_safe_members(members, base)) + members = [mock_member_safe, mock_member_symlink] + safe_members = list(_get_safe_members(members, base)) - assert len(safe_members) == 1 - assert mock_member_safe in safe_members + assert len(safe_members) == 1 + assert mock_member_safe in safe_members def test_get_safe_members_filters_bad_hardlink_member(): """Test _get_safe_members filters out bad hardlinks.""" - base = _get_resolved_path("/tmp/extract") + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) - mock_member_safe = Mock() - mock_member_safe.name = "safe/file.txt" - mock_member_safe.issym = Mock(return_value=False) - mock_member_safe.islnk = Mock(return_value=False) + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) - mock_member_hardlink = Mock() - mock_member_hardlink.name = "bad/hardlink" - mock_member_hardlink.issym = Mock(return_value=False) - mock_member_hardlink.islnk = Mock(return_value=True) - mock_member_hardlink.linkname = "/etc/passwd" + mock_member_hardlink = Mock() + mock_member_hardlink.name = "bad/hardlink" + mock_member_hardlink.issym = Mock(return_value=False) + mock_member_hardlink.islnk = Mock(return_value=True) + mock_member_hardlink.linkname = "/etc/passwd" - members = [mock_member_safe, mock_member_hardlink] - safe_members = list(_get_safe_members(members, base)) + members = [mock_member_safe, mock_member_hardlink] + safe_members = list(_get_safe_members(members, base)) - assert len(safe_members) == 1 - assert mock_member_safe in safe_members + assert len(safe_members) == 1 + assert mock_member_safe in safe_members def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): - """Test custom_extractall_tarfile uses data_filter when available.""" + """Test custom_extractall_tarfile uses data_filter when available. + + We set mock_tarfile.data_filter explicitly to ensure hasattr returns True. + The MagicMock would auto-create the attribute anyway, but we set it + explicitly for clarity. The key assertion is that filter="data" is passed. + """ mock_tar = Mock() mock_tar.extractall = Mock() - extract_path = "/tmp/extract" - with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile: - mock_tarfile.data_filter = "data" + with tempfile.TemporaryDirectory() as tmpdir: + extract_path = os.path.join(tmpdir, "extract") + + with patch("sagemaker.core.common_utils.tarfile") as mock_tarfile: + # Explicitly set data_filter to ensure the hasattr check passes + mock_tarfile.data_filter = True - custom_extractall_tarfile(mock_tar, extract_path) + custom_extractall_tarfile(mock_tar, extract_path) - mock_tar.extractall.assert_called_once_with(path=extract_path, filter="data") + mock_tar.extractall.assert_called_once_with(path=extract_path, filter="data") def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): - """Test custom_extractall_tarfile uses safe members with getmembers() and resolved extract_path.""" + """Test custom_extractall_tarfile uses safe members when data_filter is unavailable. + + Verifies that: + 1. tar.getmembers() is called (not iterating over tar directly) + 2. _get_safe_members is called with the members list and resolved extract_path as base + 3. _validate_extracted_paths is called after extraction + """ mock_member = Mock() mock_member.name = "safe/file.txt" mock_member.issym = Mock(return_value=False) @@ -213,31 +239,45 @@ def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): mock_tar = Mock() mock_tar.extractall = Mock() mock_tar.getmembers = Mock(return_value=[mock_member]) - extract_path = "/tmp/extract" - - with patch('sagemaker.core.common_utils.tarfile') as mock_tarfile: - # Remove data_filter attribute to simulate Python < 3.12 - if hasattr(mock_tarfile, 'data_filter'): - delattr(mock_tarfile, 'data_filter') - - with patch('sagemaker.core.common_utils._get_safe_members') as mock_safe: - mock_safe.return_value = [mock_member] - - with patch('sagemaker.core.common_utils._validate_extracted_paths'): - custom_extractall_tarfile(mock_tar, extract_path) - - # Verify getmembers() was called (not iterating over tar directly) - mock_tar.getmembers.assert_called_once() - - # Verify _get_safe_members was called with the members list and resolved base - mock_safe.assert_called_once() - call_args = mock_safe.call_args - assert call_args[0][0] == [mock_member] # members list - # base should be resolved extract_path, not cwd - expected_base = _get_resolved_path(extract_path) - assert call_args[0][1] == expected_base - - mock_tar.extractall.assert_called_once() - call_kwargs = mock_tar.extractall.call_args[1] - assert call_kwargs['path'] == extract_path - assert 'members' in call_kwargs + + with tempfile.TemporaryDirectory() as tmpdir: + extract_path = os.path.join(tmpdir, "extract") + + # Use spec= to restrict the mock to only have 'TarFile' attribute, + # so hasattr(mock_tarfile, 'data_filter') returns False + with patch( + "sagemaker.core.common_utils.tarfile", spec=["TarFile"] + ): + with patch("sagemaker.core.common_utils._get_safe_members") as mock_safe: + mock_safe.return_value = [mock_member] + + with patch("sagemaker.core.common_utils._validate_extracted_paths"): + custom_extractall_tarfile(mock_tar, extract_path) + + # Verify getmembers() was called (not iterating over tar directly) + mock_tar.getmembers.assert_called_once() + + # Verify _get_safe_members was called with the members list and resolved base + mock_safe.assert_called_once() + call_args = mock_safe.call_args + assert call_args[0][0] == [mock_member] # members list + # base should be resolved extract_path, not cwd + expected_base = _get_resolved_path(extract_path) + assert call_args[0][1] == expected_base + + mock_tar.extractall.assert_called_once() + call_kwargs = mock_tar.extractall.call_args[1] + assert call_kwargs["path"] == extract_path + assert "members" in call_kwargs + + +def test_is_bad_path_handles_value_error_gracefully(): + """Test that _is_bad_path returns True when os.path.commonpath raises ValueError. + + This can happen on Windows with paths on different drives, or with mixed + absolute/relative paths. + """ + with patch("sagemaker.core.common_utils.os.path.commonpath", side_effect=ValueError): + # Should return True (treat as bad path) when commonpath raises ValueError + result = _is_bad_path("some/path", "/some/base") + assert result is True From f4eee66edbc28373acc4e37008d60422c8e6f751 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 14:56:15 -0400 Subject: [PATCH 3/4] fix: address review comments (iteration #2) --- .../src/sagemaker/core/utils/__init__.py | 16 +++++----------- .../tests/unit/test_common_utils_tar_safety.py | 17 ++++++++++------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/utils/__init__.py b/sagemaker-core/src/sagemaker/core/utils/__init__.py index 1f8cf1d374..ad9f173004 100644 --- a/sagemaker-core/src/sagemaker/core/utils/__init__.py +++ b/sagemaker-core/src/sagemaker/core/utils/__init__.py @@ -19,9 +19,11 @@ """ from __future__ import absolute_import -# Public API surface: only non-private functions are exported via __all__. +# Public API surface. +# Note: _save_model is underscore-prefixed but was already in __all__ (pre-existing). +# custom_extractall_tarfile is the main public entry point for safe tar extraction. # Private helpers (_get_resolved_path, _is_bad_path, _is_bad_link, _get_safe_members) -# are still importable directly but are not part of the public API. +# are importable directly from sagemaker.core.common_utils but are not re-exported here. __all__ = [ "_save_model", "download_file_from_url", @@ -41,18 +43,10 @@ "get_config_value", ] -# Internal helpers that are importable but not part of the public API -_INTERNAL_NAMES = [ - "_get_resolved_path", - "_is_bad_path", - "_is_bad_link", - "_get_safe_members", -] - def __getattr__(name): """Lazy import to avoid circular dependencies.""" - if name in __all__ or name in _INTERNAL_NAMES: + if name in __all__: from sagemaker.core import common_utils return getattr(common_utils, name) diff --git a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py index db4364a790..3bccd0648f 100644 --- a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py +++ b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py @@ -15,10 +15,8 @@ import os import tempfile -import tarfile -import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch from sagemaker.core.common_utils import ( _get_resolved_path, @@ -204,9 +202,10 @@ def test_get_safe_members_filters_bad_hardlink_member(): def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): """Test custom_extractall_tarfile uses data_filter when available. - We set mock_tarfile.data_filter explicitly to ensure hasattr returns True. - The MagicMock would auto-create the attribute anyway, but we set it - explicitly for clarity. The key assertion is that filter="data" is passed. + We patch the module-level `tarfile` import in common_utils (not the `tar` parameter). + Setting mock_tarfile.data_filter explicitly ensures hasattr(tarfile, 'data_filter') + returns True inside custom_extractall_tarfile. The key assertion is that + filter="data" is passed to tar.extractall. """ mock_tar = Mock() mock_tar.extractall = Mock() @@ -214,8 +213,9 @@ def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): with tempfile.TemporaryDirectory() as tmpdir: extract_path = os.path.join(tmpdir, "extract") + # Patch the module-level tarfile import in common_utils with patch("sagemaker.core.common_utils.tarfile") as mock_tarfile: - # Explicitly set data_filter to ensure the hasattr check passes + # Explicitly set data_filter so hasattr check passes mock_tarfile.data_filter = True custom_extractall_tarfile(mock_tar, extract_path) @@ -226,6 +226,9 @@ def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): """Test custom_extractall_tarfile uses safe members when data_filter is unavailable. + Uses spec=['TarFile'] to restrict the mock so that + hasattr(mock_tarfile, 'data_filter') returns False, forcing the fallback path. + Verifies that: 1. tar.getmembers() is called (not iterating over tar directly) 2. _get_safe_members is called with the members list and resolved extract_path as base From ad3e58b6884a6f3a68c3ccbee94b88b4985a1eac Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:01:59 -0400 Subject: [PATCH 4/4] fix: address review comments (iteration #3) --- .../src/sagemaker/core/common_utils.py | 8 +- .../src/sagemaker/core/utils/__init__.py | 9 +- .../unit/test_common_utils_tar_safety.py | 141 +++++++++++++----- 3 files changed, 114 insertions(+), 44 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 7e99314173..2f3a8ea92e 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -1708,22 +1708,26 @@ def _is_bad_link(info, base): bool: True if the link is not rooted under the base directory, False otherwise. """ # Links are interpreted relative to the directory containing the link + # Wrap with _get_resolved_path to ensure consistent normalization for commonpath comparison tip = _get_resolved_path(joinpath(base, dirname(info.name))) return _is_bad_path(info.linkname, base=tip) -def _get_safe_members(members, base): +def _get_safe_members(members, base=None): """A generator that yields members that are safe to extract. It filters out bad paths and bad links. Args: members (list): A list of TarInfo members to check. - base (str): The base directory for extraction. + base (str): The base directory for extraction. If None, defaults to the + current working directory (for backward compatibility). Yields: tarfile.TarInfo: The tar file info. """ + if base is None: + base = _get_resolved_path("") for file_info in members: if _is_bad_path(file_info.name, base): logger.error("%s is blocked (illegal path)", file_info.name) diff --git a/sagemaker-core/src/sagemaker/core/utils/__init__.py b/sagemaker-core/src/sagemaker/core/utils/__init__.py index ad9f173004..184550100d 100644 --- a/sagemaker-core/src/sagemaker/core/utils/__init__.py +++ b/sagemaker-core/src/sagemaker/core/utils/__init__.py @@ -16,14 +16,19 @@ for backward compatibility and convenience. Note: Uses lazy imports via __getattr__ to avoid circular import issues. + +Private tar extraction safety helpers (_get_resolved_path, _is_bad_path, +_is_bad_link, _get_safe_members, _validate_extracted_paths) are internal +implementation details and are NOT re-exported from this package. They can +be imported directly from sagemaker.core.common_utils if needed. + +custom_extractall_tarfile is the public entry point for safe tar extraction. """ from __future__ import absolute_import # Public API surface. # Note: _save_model is underscore-prefixed but was already in __all__ (pre-existing). # custom_extractall_tarfile is the main public entry point for safe tar extraction. -# Private helpers (_get_resolved_path, _is_bad_path, _is_bad_link, _get_safe_members) -# are importable directly from sagemaker.core.common_utils but are not re-exported here. __all__ = [ "_save_model", "download_file_from_url", diff --git a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py index 3bccd0648f..fd88701e7d 100644 --- a/sagemaker-core/tests/unit/test_common_utils_tar_safety.py +++ b/sagemaker-core/tests/unit/test_common_utils_tar_safety.py @@ -67,22 +67,49 @@ def test_is_bad_path_returns_true_for_parent_traversal(): assert _is_bad_path(traversal_path, base) is True -def test_is_bad_path_with_similar_prefix_does_not_false_positive(): +def test_is_bad_path_with_similar_prefix_absolute_escape(): """Test that base/x2 is correctly identified as bad when base is base/x. - This verifies the commonpath fix: startswith would incorrectly allow - base/x2 when base is base/x, but commonpath correctly rejects it. + This verifies the commonpath fix for absolute path escape: + With startswith, "tmpdir/x2".startswith("tmpdir/x") would be True (bug). + With commonpath, commonpath(["tmpdir/x2", "tmpdir/x"]) == "tmpdir" != "tmpdir/x" (correct). """ with tempfile.TemporaryDirectory() as tmpdir: base = _get_resolved_path(os.path.join(tmpdir, "x")) - # A path like tmpdir/x2/file should NOT be under tmpdir/x - # With startswith, "tmpdir/x2".startswith("tmpdir/x") would be True (bug) - # With commonpath, commonpath(["tmpdir/x2", "tmpdir/x"]) == "tmpdir" != "tmpdir/x" (correct) + # An absolute path like tmpdir/x2/file should NOT be under tmpdir/x + # joinpath ignores base for absolute paths, so this tests absolute escape escape_path = os.path.join(tmpdir, "x2", "file") result = _is_bad_path(escape_path, base) assert result is True +def test_is_bad_path_with_similar_prefix_relative_traversal(): + """Test the commonpath fix with a relative path that triggers the prefix issue. + + Using ../x2/file as a relative path from base tmpdir/x should resolve to + tmpdir/x2/file which is outside tmpdir/x. This more directly tests the + commonpath fix for relative traversal. + """ + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "x")) + # ../x2/file relative to base tmpdir/x resolves to tmpdir/x2/file + escape_path = "../x2/file" + result = _is_bad_path(escape_path, base) + assert result is True + + +def test_is_bad_path_handles_value_error_gracefully(): + """Test that _is_bad_path returns True when os.path.commonpath raises ValueError. + + This can happen on Windows with paths on different drives, or with mixed + absolute/relative paths. + """ + with patch("sagemaker.core.common_utils.os.path.commonpath", side_effect=ValueError): + # Should return True (treat as bad path) when commonpath raises ValueError + result = _is_bad_path("some/path", "/some/base") + assert result is True + + def test_is_bad_link_returns_false_for_safe_symlink(): """Test _is_bad_link returns False for safe links.""" with tempfile.TemporaryDirectory() as tmpdir: @@ -153,6 +180,32 @@ def test_get_safe_members_filters_bad_path_member(): assert mock_member_safe in safe_members +def test_get_safe_members_filters_directory_traversal_member(): + """Test _get_safe_members filters out members with directory traversal. + + This tests the common tar slip attack vector where archive entries + contain ../../ paths to escape the extraction directory. + """ + with tempfile.TemporaryDirectory() as tmpdir: + base = _get_resolved_path(os.path.join(tmpdir, "extract")) + + mock_member_safe = Mock() + mock_member_safe.name = "safe/file.txt" + mock_member_safe.issym = Mock(return_value=False) + mock_member_safe.islnk = Mock(return_value=False) + + mock_member_traversal = Mock() + mock_member_traversal.name = "../../etc/passwd" + mock_member_traversal.issym = Mock(return_value=False) + mock_member_traversal.islnk = Mock(return_value=False) + + members = [mock_member_safe, mock_member_traversal] + safe_members = list(_get_safe_members(members, base)) + + assert len(safe_members) == 1 + assert mock_member_safe in safe_members + + def test_get_safe_members_filters_bad_symlink_member(): """Test _get_safe_members filters out bad symlinks.""" with tempfile.TemporaryDirectory() as tmpdir: @@ -199,6 +252,22 @@ def test_get_safe_members_filters_bad_hardlink_member(): assert mock_member_safe in safe_members +def test_get_safe_members_backward_compatible_without_base(): + """Test _get_safe_members works without base parameter for backward compatibility. + + When base is not provided, it defaults to the current working directory. + """ + mock_member = Mock() + mock_member.name = "safe/file.txt" + mock_member.issym = Mock(return_value=False) + mock_member.islnk = Mock(return_value=False) + + # Call without base parameter - should not raise TypeError + safe_members = list(_get_safe_members([mock_member])) + assert len(safe_members) == 1 + assert mock_member in safe_members + + def test_custom_extractall_tarfile_with_data_filter_uses_filter_param(): """Test custom_extractall_tarfile uses data_filter when available. @@ -231,7 +300,7 @@ def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): Verifies that: 1. tar.getmembers() is called (not iterating over tar directly) - 2. _get_safe_members is called with the members list and resolved extract_path as base + 2. _get_safe_members is called with members list and resolved extract_path as base 3. _validate_extracted_paths is called after extraction """ mock_member = Mock() @@ -248,39 +317,31 @@ def test_custom_extractall_tarfile_without_data_filter_uses_safe_members(): # Use spec= to restrict the mock to only have 'TarFile' attribute, # so hasattr(mock_tarfile, 'data_filter') returns False - with patch( - "sagemaker.core.common_utils.tarfile", spec=["TarFile"] - ): - with patch("sagemaker.core.common_utils._get_safe_members") as mock_safe: - mock_safe.return_value = [mock_member] - - with patch("sagemaker.core.common_utils._validate_extracted_paths"): - custom_extractall_tarfile(mock_tar, extract_path) - - # Verify getmembers() was called (not iterating over tar directly) - mock_tar.getmembers.assert_called_once() - - # Verify _get_safe_members was called with the members list and resolved base - mock_safe.assert_called_once() - call_args = mock_safe.call_args - assert call_args[0][0] == [mock_member] # members list - # base should be resolved extract_path, not cwd - expected_base = _get_resolved_path(extract_path) - assert call_args[0][1] == expected_base + with patch("sagemaker.core.common_utils.tarfile", spec=["TarFile"]), \ + patch("sagemaker.core.common_utils._get_safe_members") as mock_safe, \ + patch("sagemaker.core.common_utils._validate_extracted_paths") as mock_validate: - mock_tar.extractall.assert_called_once() - call_kwargs = mock_tar.extractall.call_args[1] - assert call_kwargs["path"] == extract_path - assert "members" in call_kwargs + mock_safe.return_value = [mock_member] + custom_extractall_tarfile(mock_tar, extract_path) -def test_is_bad_path_handles_value_error_gracefully(): - """Test that _is_bad_path returns True when os.path.commonpath raises ValueError. - - This can happen on Windows with paths on different drives, or with mixed - absolute/relative paths. - """ - with patch("sagemaker.core.common_utils.os.path.commonpath", side_effect=ValueError): - # Should return True (treat as bad path) when commonpath raises ValueError - result = _is_bad_path("some/path", "/some/base") - assert result is True + # Verify getmembers() was called (not iterating over tar directly) + mock_tar.getmembers.assert_called_once() + + # Verify _get_safe_members was called with: + # arg 0: members list (from tar.getmembers()) + # arg 1: base (resolved extract_path, not cwd) + mock_safe.assert_called_once() + call_args = mock_safe.call_args + assert call_args[0][0] == [mock_member] # members list + expected_base = _get_resolved_path(extract_path) + assert call_args[0][1] == expected_base # base is resolved extract_path + + # Verify extractall was called with path and members + mock_tar.extractall.assert_called_once() + call_kwargs = mock_tar.extractall.call_args[1] + assert call_kwargs["path"] == extract_path + assert "members" in call_kwargs + + # Verify _validate_extracted_paths is called after extraction + mock_validate.assert_called_once_with(extract_path)