Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions sagemaker-core/src/sagemaker/core/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
def _get_compose_cmd_prefix():
"""Gets the Docker Compose command.

The method initially looks for 'docker compose' v2
The method initially looks for 'docker compose' v2+
executable, if not found looks for 'docker-compose' executable.

Returns:
Expand All @@ -162,10 +162,12 @@ def _get_compose_cmd_prefix():
"Proceeding to check for 'docker-compose' CLI."
)

if output and "v2" in output.strip():
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix
if output:
match = re.search(r"v(\d+)", output.strip())
if match and int(match.group(1)) >= 2:
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix

if shutil.which("docker-compose") is not None:
logger.info("'Docker Compose' found using Docker Compose CLI.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
def _get_compose_cmd_prefix(self) -> List[str]:
"""Gets the Docker Compose command.

The method initially looks for 'docker compose' v2
The method initially looks for 'docker compose' v2+
executable, if not found looks for 'docker-compose' executable.

Returns:
Expand All @@ -617,10 +617,12 @@ def _get_compose_cmd_prefix(self) -> List[str]:
"Proceeding to check for 'docker-compose' CLI."
)

if output and "v2" in output.strip():
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix
if output:
match = re.search(r"v(\d+)", output.strip())
if match and int(match.group(1)) >= 2:
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix

if shutil.which("docker-compose") is not None:
logger.info("'Docker Compose' found using Docker Compose CLI.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,82 @@ def test_get_compose_cmd_prefix_not_found(
with pytest.raises(ImportError, match="Docker Compose is not installed"):
container._get_compose_cmd_prefix()

@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
def test_get_compose_cmd_prefix_docker_compose_v5(
self, mock_check_output, mock_session, basic_channel
):
"""Test _get_compose_cmd_prefix accepts Docker Compose v5"""
container = _LocalContainer(
training_job_name="test-job",
instance_type="local",
instance_count=1,
image="test-image:latest",
container_root="/tmp/test",
input_data_config=[basic_channel],
environment={},
hyper_parameters={},
container_entrypoint=[],
container_arguments=[],
sagemaker_session=mock_session,
)

mock_check_output.return_value = "Docker Compose version v5.1.1"

result = container._get_compose_cmd_prefix()

assert result == ["docker", "compose"]

@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
def test_get_compose_cmd_prefix_docker_compose_v3(
self, mock_check_output, mock_session, basic_channel
):
"""Test _get_compose_cmd_prefix accepts Docker Compose v3"""
container = _LocalContainer(
training_job_name="test-job",
instance_type="local",
instance_count=1,
image="test-image:latest",
container_root="/tmp/test",
input_data_config=[basic_channel],
environment={},
hyper_parameters={},
container_entrypoint=[],
container_arguments=[],
sagemaker_session=mock_session,
)

mock_check_output.return_value = "Docker Compose version v3.0.0"

result = container._get_compose_cmd_prefix()

assert result == ["docker", "compose"]

@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
@patch("sagemaker.core.modules.local_core.local_container.shutil.which")
def test_get_compose_cmd_prefix_docker_compose_v1_rejected(
self, mock_which, mock_check_output, mock_session, basic_channel
):
"""Test _get_compose_cmd_prefix rejects Docker Compose v1"""
container = _LocalContainer(
training_job_name="test-job",
instance_type="local",
instance_count=1,
image="test-image:latest",
container_root="/tmp/test",
input_data_config=[basic_channel],
environment={},
hyper_parameters={},
container_entrypoint=[],
container_arguments=[],
sagemaker_session=mock_session,
)

mock_check_output.return_value = "docker-compose version v1.29.2"
mock_which.return_value = None

with pytest.raises(ImportError, match="Docker Compose is not installed"):
container._get_compose_cmd_prefix()

def test_init_with_container_entrypoint(self, mock_session, basic_channel):
"""Test initialization with container entrypoint"""
container = _LocalContainer(
Expand Down
12 changes: 7 additions & 5 deletions sagemaker-train/src/sagemaker/train/local/local_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
def _get_compose_cmd_prefix(self) -> List[str]:
"""Gets the Docker Compose command.

The method initially looks for 'docker compose' v2
The method initially looks for 'docker compose' v2+
executable, if not found looks for 'docker-compose' executable.

Returns:
Expand All @@ -625,10 +625,12 @@ def _get_compose_cmd_prefix(self) -> List[str]:
"Proceeding to check for 'docker-compose' CLI."
)

if output and "v2" in output.strip():
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix
if output:
match = re.search(r"v(\d+)", output.strip())
if match and int(match.group(1)) >= 2:
logger.info("'Docker Compose' found using Docker CLI.")
compose_cmd_prefix.extend(["docker", "compose"])
return compose_cmd_prefix

if shutil.which("docker-compose") is not None:
logger.info("'Docker Compose' found using Docker Compose CLI.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Integration tests for Docker Compose version detection fix (issue #5739).

These tests verify that _get_compose_cmd_prefix correctly accepts Docker Compose
versions >= 2 (including v3, v4, v5, etc.) rather than only accepting v2.

The tests run against the real Docker Compose installation on the machine — no mocking.
Requires: Docker with Compose plugin installed (any version >= 2).
"""
from __future__ import absolute_import

import re
import subprocess
import tempfile

import pytest

from sagemaker.core.local.image import _SageMakerContainer
from sagemaker.core.modules.local_core.local_container import (
_LocalContainer as CoreModulesLocalContainer,
)
from sagemaker.core.shapes import Channel, DataSource, S3DataSource
from sagemaker.train.local.local_container import (
_LocalContainer as TrainLocalContainer,
)


def _get_installed_compose_major_version():
"""Return the major version int of the installed Docker Compose, or None."""
try:
output = subprocess.check_output(
["docker", "compose", "version"],
stderr=subprocess.DEVNULL,
encoding="UTF-8",
)
match = re.search(r"v(\d+)", output.strip())
if match:
return int(match.group(1))
except (subprocess.CalledProcessError, FileNotFoundError):
pass
return None


# Skip the entire module if Docker Compose >= 2 is not available
_compose_major = _get_installed_compose_major_version()
pytestmark = pytest.mark.skipif(
_compose_major is None or _compose_major < 2,
reason=f"Docker Compose >= 2 required (found: v{_compose_major})",
)


def _make_basic_channel():
"""Create a minimal Channel for constructing _LocalContainer instances."""
data_source = DataSource(
s3_data_source=S3DataSource(
s3_uri="s3://bucket/data",
s3_data_type="S3Prefix",
s3_data_distribution_type="FullyReplicated",
)
)
return Channel(channel_name="training", data_source=data_source)


def _make_local_container(container_cls):
"""Construct a _LocalContainer with minimal valid args.

sagemaker_session is None since _get_compose_cmd_prefix doesn't use it,
and the Pydantic model rejects Mock objects.
"""
container_root = tempfile.mkdtemp(prefix="sagemaker-integ-compose-")
return container_cls(
training_job_name="integ-test-compose-detection",
instance_type="local",
instance_count=1,
image="test-image:latest",
container_root=container_root,
input_data_config=[_make_basic_channel()],
environment={},
hyper_parameters={},
container_entrypoint=[],
container_arguments=[],
sagemaker_session=None,
)


@pytest.fixture
def _core_modules_container():
return _make_local_container(CoreModulesLocalContainer)


@pytest.fixture
def _train_container():
return _make_local_container(TrainLocalContainer)


class TestDockerComposeVersionDetection:
"""Integration tests for _get_compose_cmd_prefix across all three code locations.

Validates the fix for https://github.com/aws/sagemaker-python-sdk/issues/5739
where Docker Compose v3+ was incorrectly rejected.
"""

def test_sagemaker_core_image_accepts_installed_compose(self):
"""sagemaker-core local/image.py _SageMakerContainer._get_compose_cmd_prefix
should accept the installed Docker Compose version."""
result = _SageMakerContainer._get_compose_cmd_prefix()

assert result == ["docker", "compose"], (
f"Expected ['docker', 'compose'] but got {result}. "
f"Installed Docker Compose is v{_compose_major}."
)

def test_sagemaker_core_modules_local_container_accepts_installed_compose(
self, _core_modules_container
):
"""sagemaker-core modules/local_core/local_container.py
_LocalContainer._get_compose_cmd_prefix should accept the installed version."""
result = _core_modules_container._get_compose_cmd_prefix()

assert result == ["docker", "compose"], (
f"Expected ['docker', 'compose'] but got {result}. "
f"Installed Docker Compose is v{_compose_major}."
)

def test_sagemaker_train_local_container_accepts_installed_compose(
self, _train_container
):
"""sagemaker-train local/local_container.py
_LocalContainer._get_compose_cmd_prefix should accept the installed version."""
result = _train_container._get_compose_cmd_prefix()

assert result == ["docker", "compose"], (
f"Expected ['docker', 'compose'] but got {result}. "
f"Installed Docker Compose is v{_compose_major}."
)

def test_returned_command_is_functional(self):
"""The command returned by _get_compose_cmd_prefix should actually work."""
cmd = _SageMakerContainer._get_compose_cmd_prefix()

# Run the returned command with "version" to prove it's functional
result = subprocess.run(
cmd + ["version"],
capture_output=True,
text=True,
timeout=10,
)
assert result.returncode == 0, (
f"Command {cmd + ['version']} failed: {result.stderr}"
)
assert "version" in result.stdout.lower(), (
f"Unexpected output from {cmd + ['version']}: {result.stdout}"
)

@pytest.mark.skipif(
_compose_major is not None and _compose_major < 3,
reason="This test specifically validates v3+ acceptance (installed is v2)",
)
def test_v3_plus_specifically_accepted(self):
"""When Docker Compose v3+ is installed, it must be accepted — not rejected.

This is the core regression test for issue #5739.
"""
result = _SageMakerContainer._get_compose_cmd_prefix()
assert result == ["docker", "compose"], (
f"Docker Compose v{_compose_major} was rejected. "
"This is the exact bug described in issue #5739."
)
Loading
Loading