diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 44d0949..3c470df 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -3,6 +3,8 @@ ## PR Checklist - [ ] Merged latest master - [ ] Updated version number in `pyproject.toml`. +- [ ] Added tests for new features or bug fixes. +- [ ] Passed all tests. - [ ] Update README.md if needed. ## Breaking Changes diff --git a/.github/tasks.md b/.github/tasks.md index 40ae55c..256eeae 100644 --- a/.github/tasks.md +++ b/.github/tasks.md @@ -1,2 +1 @@ ## Tasks -- [] diff --git a/README.md b/README.md index 1ca4b12..ec8f61b 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,25 @@ S3MPConfig.set_mirror_root("s3_mirror") S3MPConfig.assume_role("arn:aws:iam:::role/") ``` +For projects that span multiple S3 buckets or IAM roles, you can create `MirrorPath` objects with different `bucket_key` and `iam_role_arn` parameters. Each object uses the appropriate session and client under the hood, so you can interact with multiple buckets and roles within the same project. When no bucket or role is specified, the defaults from `S3MPConfig` are used: if `S3MPConfig.assume_role(...)` was called, that role becomes the default session; otherwise the ambient AWS credentials are used. + +```python +from S3MP.mirror_path import MirrorPath + +# MirrorPath using the default bucket and default session from S3MPConfig +default_mp = MirrorPath.from_s3_key("path/to/object.jpg") +# MirrorPath using a specific bucket and IAM role +custom_mp = MirrorPath.from_s3_key( + "path/to/object.jpg", + bucket_key="custom-bucket", + iam_role_arn="arn:aws:iam:::role/" +) +# MirrorPath using the bucket from an S3 URL and the default session +# The bucket key is parsed from the URL +url_mp = MirrorPath.from_s3_key("s3://custom-bucket/path/to/object.jpg") +``` + + ## Installation [uv](https://docs.astral.sh/uv/) is a fast, cross-platform Python package installer and resolver.
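A minimal sketch of how the per-role session cache and bucket/role propagation added in this diff are expected to behave. The bucket names and role ARN below are placeholders, and resolving `.session` for a role performs a live `sts:AssumeRole` call, so running this requires credentials permitted to assume that role:

```python
from S3MP.global_config import S3MPConfig
from S3MP.mirror_path import MirrorPath

S3MPConfig.set_default_bucket_key("my-bucket")  # placeholder default bucket

# Two MirrorPaths that share a role ARN share one cached S3Session, and thus
# one boto3 client/resource pair. Resolving .session triggers sts:AssumeRole.
role = "arn:aws:iam::123456789012:role/ExampleRole"  # placeholder ARN
a = MirrorPath.from_s3_key("data/a.txt", iam_role_arn=role)
b = MirrorPath.from_s3_key("data/b.txt", iam_role_arn=role)
assert a.session is b.session

# With no bucket or role given, the S3MPConfig defaults apply.
c = MirrorPath.from_s3_key("data/c.txt")
assert c.bucket_key == "my-bucket"

# Derived paths keep their bucket/role context.
child = MirrorPath.from_s3_key("folder/", bucket_key="other-bucket").get_child("file.txt")
assert child.bucket_key == "other-bucket"
```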
diff --git a/S3MP/__init__.py b/S3MP/__init__.py index 2424519..f5a31d2 100644 --- a/S3MP/__init__.py +++ b/S3MP/__init__.py @@ -1,5 +1,6 @@ """S3 MirrorPath package.""" from S3MP._version import __version__ +from S3MP.global_config import S3Session -__all__ = ["__version__"] +__all__ = ["__version__", "S3Session"] diff --git a/S3MP/async_utils.py b/S3MP/async_utils.py index 4983c8a..cbe09b8 100644 --- a/S3MP/async_utils.py +++ b/S3MP/async_utils.py @@ -12,8 +12,20 @@ async def async_upload_from_mirror(mirror_path: MirrorPath): """Asynchronously upload a file from a MirrorPath.""" session = aioboto3.Session() + if mirror_path.iam_role_arn: + async with session.client("sts") as sts_client: + assumed = await sts_client.assume_role( + RoleArn=mirror_path.iam_role_arn, + RoleSessionName="S3MPAsyncUploadSession", + ) + creds = assumed["Credentials"] + session = aioboto3.Session( + aws_access_key_id=creds["AccessKeyId"], + aws_secret_access_key=creds["SecretAccessKey"], + aws_session_token=creds["SessionToken"], + ) async with session.resource("s3") as s3_resource: - bucket = await s3_resource.Bucket(S3MPConfig.default_bucket_key) + bucket = await s3_resource.Bucket(mirror_path.bucket_key) await bucket.upload_file(str(mirror_path.local_path), mirror_path.s3_key) @@ -22,7 +34,7 @@ def upload_from_mirror_thread( ) -> Coroutine: """Upload from mirror on a separate thread.""" return asyncio.to_thread( - S3MPConfig.bucket.upload_file, + mirror_path.bucket.upload_file, str(mirror_path.local_path), mirror_path.s3_key, Callback=S3MPConfig.callback, diff --git a/S3MP/callbacks.py b/S3MP/callbacks.py index 39f53ec..82a00d4 100644 --- a/S3MP/callbacks.py +++ b/S3MP/callbacks.py @@ -31,22 +31,27 @@ def __init__( """ if transfer_objs is None: return + if not isinstance(transfer_objs, list): + transfer_objs = [transfer_objs] + + # Fall back to global defaults for non-MirrorPath objects if resource is None: resource = S3MPConfig.s3_resource if bucket_key is None: bucket_key = S3MPConfig.default_bucket_key - if not isinstance(transfer_objs, list): - transfer_objs = [transfer_objs] self._total_bytes = 0 for transfer_mapping in transfer_objs: if is_download: - s3_key = str( - transfer_mapping.s3_key - if isinstance(transfer_mapping, MirrorPath) - else transfer_mapping - ) - self._total_bytes += resource.Object(bucket_key, s3_key).content_length + if isinstance(transfer_mapping, MirrorPath): + mp_resource = transfer_mapping.session.s3_resource + mp_bucket_key = transfer_mapping.bucket_key + s3_key = transfer_mapping.s3_key + else: + mp_resource = resource + mp_bucket_key = bucket_key + s3_key = str(transfer_mapping) + self._total_bytes += mp_resource.Object(mp_bucket_key, s3_key).content_length else: local_path = ( transfer_mapping.local_path diff --git a/S3MP/global_config.py b/S3MP/global_config.py index e7c06c3..38c942d 100644 --- a/S3MP/global_config.py +++ b/S3MP/global_config.py @@ -1,5 +1,7 @@ """Set global values for S3MP module.""" +from __future__ import annotations + import tempfile from collections.abc import Callable from configparser import ConfigParser @@ -13,6 +15,58 @@ from S3MP.types import S3Bucket, S3Client, S3Resource, S3TransferConfig +@dataclass +class S3Session: + """Holds cached boto3 objects for a single IAM credential context.""" + + s3_client: S3Client + s3_resource: S3Resource + _bucket_map: dict[str, S3Bucket] | None = None + + @staticmethod + def from_role_arn(role_arn: str, boto3_config: Config | None = None) -> S3Session: + """Create a session by assuming an IAM role.""" + sts_client = 
boto3.client("sts") + assumed_role = sts_client.assume_role( + RoleArn=role_arn, RoleSessionName="S3MPAssumeRoleSession" + ) + credentials = assumed_role["Credentials"] + cfg = boto3_config or Config() + return S3Session( + s3_client=boto3.client( + "s3", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + config=cfg, + ), + s3_resource=boto3.resource( + "s3", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + config=cfg, + ), + ) + + @staticmethod + def no_role(boto3_config: Config | None = None) -> S3Session: + """Create a session using no IAM role.""" + cfg = boto3_config or Config() + return S3Session( + s3_client=boto3.client("s3", config=cfg), + s3_resource=boto3.resource("s3", config=cfg), + ) + + def get_bucket(self, bucket_key: str) -> S3Bucket: + """Get boto3 S3Bucket object.""" + if self._bucket_map is None: + self._bucket_map = {} + if bucket_key not in self._bucket_map: + self._bucket_map[bucket_key] = self.s3_resource.Bucket(bucket_key) + return self._bucket_map[bucket_key] + + def get_config_file_path() -> Path: """Get the location of the config file.""" root_module_folder = Path(__file__).parent.resolve() @@ -33,10 +87,8 @@ def __call__(cls, *args, **kwargs): class _S3MPConfigClass(metaclass=Singleton): """Singleton class for S3MP globals.""" - # Boto3 Objects - _s3_client: S3Client | None = None - _s3_resource: S3Resource | None = None - _bucket: S3Bucket | None = None + # Session registry: maps role ARN -> S3Session (None key = default session) + _session_map: dict[str | None, S3Session] | None = None _boto3_config: Config | None = None # Config Items @@ -50,32 +102,28 @@ class _S3MPConfigClass(metaclass=Singleton): callback: Callable | None = None use_async_global_thread_queue: bool = True - def assume_role(self, role_arn: str) -> None: - """Assume an IAM role and update the S3 client and resource with the new credentials.""" - sts_client = boto3.client("sts") - assumed_role = sts_client.assume_role( - RoleArn=role_arn, RoleSessionName="S3MPAssumeRoleSession" - ) - credentials = assumed_role["Credentials"] + def get_session(self, role_arn: str | None = None) -> S3Session: + """Get or create a cached S3Session for the given role ARN. - self._s3_client = boto3.client( - "s3", - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - config=self.boto3_config, - ) - self._s3_resource = boto3.resource( - "s3", - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - config=self.boto3_config, - ) - self._iam_role_arn = role_arn + Args: + role_arn: IAM role ARN. None uses the no-role session. 
+        """ + if self._session_map is None: + self._session_map = {} + + if role_arn not in self._session_map: + if role_arn is not None: + self._session_map[role_arn] = S3Session.from_role_arn(role_arn, self.boto3_config) + else: + self._session_map[None] = S3Session.no_role(self.boto3_config) - # Clear cached bucket - self._bucket = None + return self._session_map[role_arn] + + def assume_role(self, role_arn: str) -> None: + """Set the default IAM role for the global config.""" + self._iam_role_arn = role_arn + # Pre-cache the session for this role + self.get_session(role_arn) @@ -89,14 +137,10 @@ def default_bucket_key(self) -> str: def set_default_bucket_key(self, bucket_key: str) -> None: """Set default bucket key.""" self._default_bucket_key = bucket_key - # Clear cached bucket - self._bucket = None def clear_boto3_cache(self) -> None: - """Clear cached boto3 client and resource.""" - self._s3_client = None - self._s3_resource = None - self._bucket = None + """Clear cached boto3 sessions, config, and buckets.""" + self._session_map = {} self._boto3_config = None @property @@ -123,34 +167,24 @@ def boto3_config(self) -> Config: return self._boto3_config @property - def s3_client(self) -> S3Client: - """Get S3 client.""" - if not self._s3_client and self._iam_role_arn: - self.assume_role(self._iam_role_arn) - - if not self._s3_client: - self._s3_client = boto3.client("s3", config=self.boto3_config) + def default_session(self) -> S3Session: + """Get the default session (uses default role if set, otherwise default credentials).""" + return self.get_session(self._iam_role_arn) - return self._s3_client + @property + def s3_client(self) -> S3Client: + """Get S3 client from the default session.""" + return self.default_session.s3_client @property def s3_resource(self) -> S3Resource: - """Get S3 resource.""" - if not self._s3_resource and self._iam_role_arn: - self.assume_role(self._iam_role_arn) - - if not self._s3_resource: - self._s3_resource = boto3.resource("s3", config=self.boto3_config) - - return self._s3_resource + """Get S3 resource from the default session.""" + return self.default_session.s3_resource def get_bucket(self, bucket_key: str | None = None) -> S3Bucket: - """Get bucket.""" - if bucket_key: - return self.s3_resource.Bucket(bucket_key) - elif self._bucket is None: - self._bucket = self.s3_resource.Bucket(self.default_bucket_key) - return self._bucket + """Get boto3 S3Bucket object from the default session.""" + bucket_key = bucket_key or self.default_bucket_key + return self.default_session.get_bucket(bucket_key) @property def bucket(self) -> S3Bucket: diff --git a/S3MP/keys.py b/S3MP/keys.py index 1519cac..142f2ca 100644 --- a/S3MP/keys.py +++ b/S3MP/keys.py @@ -5,6 +5,8 @@ from enum import Enum from S3MP.prefix_queries import get_files_within_folder, get_folders_within_folder +from S3MP.types import S3Client +from S3MP.utils.local_file_utils import has_file_extension @dataclass @@ -13,10 +15,15 @@ class KeySegment: depth: int name: str | None = None - is_file: bool = False # Most things are folders. - incomplete_name: str | None = ( - None # Used when searching for part of a key segment (i.e. a file extension). - ) + # is_file: annotated as bool for mypy; auto-inferred in __post_init__ when not explicitly provided. + is_file: bool = None # type: ignore[assignment] + # incomplete_name: Used when searching for part of a key segment (e.g. a file extension). 
+ incomplete_name: str | None = None + + def __post_init__(self): + """Auto-infer is_file from name when not explicitly provided.""" + if self.is_file is None: + self.is_file = has_file_extension(self.name) if self.name else False def __call__(self, *args, **kwargs): """Set data via calling.""" @@ -36,6 +43,10 @@ def __call__(self, *args, **kwargs): if key in kwargs: setattr(self, key, kwargs[key]) + # Re-infer is_file if the name changed and is_file wasn't explicitly set + if "is_file" not in kwargs and (args or "name" in kwargs): + self.is_file = has_file_extension(self.name) if self.name else False + return self # For chaining def __copy__(self): @@ -118,12 +129,18 @@ def replace_key_segments_at_relative_depth(key: str, segments: list[KeySegment]) return "/".join(key_segments) -def unpack_s3_obj_generator(path: str, filter_name: str, is_file: bool): +def unpack_s3_obj_generator( + path: str, + filter_name: str, + is_file: bool, + bucket_key: str | None = None, + client: S3Client | None = None, +) -> list[str]: """Produce generator for S3 objects, and then unpack it. Used for multiprocessing.""" if is_file: - objs_at_depth = get_files_within_folder(path, filter_name) + objs_at_depth = get_files_within_folder(path, bucket_key, filter_name, client=client) else: - objs_at_depth = get_folders_within_folder(path, filter_name) + objs_at_depth = get_folders_within_folder(path, bucket_key, filter_name, client=client) return [f"{path}{obj}" for obj in objs_at_depth] @@ -140,7 +157,11 @@ def get_filter_name(segments: list[KeySegment], current_depth: int) -> str | Non async def dfs_matching_key_gen( - segments: list[KeySegment], path: str | None = None, current_depth: int | None = None + segments: list[KeySegment], + path: str | None = None, + current_depth: int | None = None, + bucket_key: str | None = None, + client: S3Client | None = None, ): """Generate all matching keys from a path, depth first.""" if current_depth is None: @@ -152,9 +173,11 @@ async def dfs_matching_key_gen( # Ensure path and filter_name are not None assert path is not None if filter_name is not None: - paths_at_depth = unpack_s3_obj_generator(path, filter_name, file_search_flag) + paths_at_depth = unpack_s3_obj_generator( + path, filter_name, file_search_flag, bucket_key, client + ) else: - paths_at_depth = unpack_s3_obj_generator(path, "", file_search_flag) + paths_at_depth = unpack_s3_obj_generator(path, "", file_search_flag, bucket_key, client) if current_depth == segments[-1].depth: for path in paths_at_depth: yield path @@ -163,12 +186,18 @@ async def dfs_matching_key_gen( return for path in paths_at_depth: - async for matching_key in dfs_matching_key_gen(segments, path, current_depth + 1): + async for matching_key in dfs_matching_key_gen( + segments, path, current_depth + 1, bucket_key, client + ): yield matching_key def sync_dfs_matching_key_gen( - segments: list[KeySegment], path: str | None = None, current_depth: int | None = None + segments: list[KeySegment], + path: str | None = None, + current_depth: int | None = None, + bucket_key: str | None = None, + client: S3Client | None = None, ): """Synchronous generation of all matching keys from a path, depth first.""" if current_depth is None: @@ -180,9 +209,11 @@ def sync_dfs_matching_key_gen( # Ensure path and filter_name are not None assert path is not None if filter_name is not None: - paths_at_depth = unpack_s3_obj_generator(path, filter_name, file_search_flag) + paths_at_depth = unpack_s3_obj_generator( + path, filter_name, file_search_flag, bucket_key, client + 
) else: - paths_at_depth = unpack_s3_obj_generator(path, "", file_search_flag) + paths_at_depth = unpack_s3_obj_generator(path, "", file_search_flag, bucket_key, client) if current_depth == segments[-1].depth: yield from paths_at_depth n_paths = len(paths_at_depth) @@ -190,10 +221,14 @@ def sync_dfs_matching_key_gen( return for path in paths_at_depth: - yield from sync_dfs_matching_key_gen(segments, path, current_depth + 1) + yield from sync_dfs_matching_key_gen(segments, path, current_depth + 1, bucket_key, client) -def get_matching_s3_keys(segments: list[KeySegment]) -> list[str]: +def get_matching_s3_keys( + segments: list[KeySegment], + bucket_key: str | None = None, + client: S3Client | None = None, +) -> list[str]: """ Get all S3 keys matching the given segments. @@ -215,7 +250,7 @@ def get_matching_s3_keys(segments: list[KeySegment]) -> list[str]: # Search for files at max depth file_search_flag = (current_depth == max_depth) and (segments[-1].is_file) new_paths = [ - unpack_s3_obj_generator(path, filter_name or "", file_search_flag) + unpack_s3_obj_generator(path, filter_name or "", file_search_flag, bucket_key, client) for path in current_paths ] current_paths = list(itertools.chain(*new_paths)) diff --git a/S3MP/mirror_path.py b/S3MP/mirror_path.py index 0e38dd2..42d5fa6 100644 --- a/S3MP/mirror_path.py +++ b/S3MP/mirror_path.py @@ -10,8 +10,9 @@ import psutil from tqdm import tqdm -from S3MP.global_config import S3MPConfig +from S3MP.global_config import S3MPConfig, S3Session from S3MP.keys import KeySegment, get_matching_s3_keys +from S3MP.types import S3Bucket, S3Client from S3MP.utils.local_file_utils import ( DEFAULT_LOAD_LEDGER, DEFAULT_SAVE_LEDGER, @@ -34,31 +35,54 @@ def __init__( self, key_segments: list[KeySegment], mirror_root: Path | None = None, + bucket_key: str | None = None, + iam_role_arn: str | None = None, ): """Init.""" # Solving issues before they happen self.key_segments: list[KeySegment] = [seg.__copy__() for seg in key_segments] self._local_path_override: Path | None = None - + self.bucket_key = bucket_key or S3MPConfig.default_bucket_key self.mirror_root = mirror_root or S3MPConfig.mirror_root + self.iam_role_arn = iam_role_arn or S3MPConfig._iam_role_arn @property def s3_key(self) -> str: - """Get s3 key.""" + """Get s3 key, without bucket prefix.""" ret_key = "/".join([str(s.name) for s in self.key_segments]) - # We'll infer folder/file based on extension - # HACK to catch case where the "extension" is actually a part of the folder name - # (eg, a folder named "v0.1.0"), we check if the extension is actually a number - name = self.key_segments[-1].name - if name is None: + if self.key_segments[-1].is_file: return ret_key - ext = name.split(".")[-1] - return ret_key if (ext is not name and not ext.isdigit()) else f"{ret_key}/" + return f"{ret_key}/" + + @property + def s3_url(self) -> str: + """Get full s3 url.""" + return f"s3://{self.bucket_key}/{self.s3_key}" + + @property + def session(self) -> S3Session: + """Get the S3Session for this MirrorPath.""" + return S3MPConfig.get_session(self.iam_role_arn) + + @property + def s3_client(self) -> S3Client: + """Get the S3 client for this MirrorPath.""" + return self.session.s3_client + + @property + def bucket(self) -> S3Bucket: + """Get the S3 bucket for this MirrorPath.""" + return self.session.get_bucket(self.bucket_key) @property def local_path(self) -> Path: """Get local path.""" - return self._local_path_override or Path(S3MPConfig.mirror_root) / self.s3_key + return self._local_path_override or 
Path(self.mirror_root) / self.bucket_key / self.s3_key + + @property + def is_file(self) -> bool: + """Check if this MirrorPath represents a file.""" + return self.key_segments[-1].is_file def override_local_path(self, local_path: Path): """Override local path.""" @@ -66,21 +90,76 @@ def override_local_path(self, local_path: Path): @staticmethod def from_s3_key(s3_key: str, **kwargs) -> MirrorPath: - """Create a MirrorPath from an s3 key.""" - s3_key = s3_key[:-1] if s3_key.endswith("/") else s3_key - key_segments = [KeySegment(idx, s) for idx, s in enumerate(s3_key.split("/"))] + """Create a MirrorPath from an s3 key or s3 url.""" + # Parse bucket name if key is in s3://bucket/key format + if s3_key.startswith("s3://"): + # Remove s3:// prefix and split bucket name from key + s3_key = s3_key[5:] + bucket_name, _, s3_key = s3_key.partition("/") + if "bucket_key" not in kwargs: + kwargs["bucket_key"] = bucket_name + + # Trailing "/" always means folder; otherwise KeySegment infers from the name + is_folder = s3_key.endswith("/") + s3_key = s3_key.rstrip("/") + + if not s3_key: + raise ValueError( + "Cannot create a MirrorPath from an empty S3 key. " + "Provide at least one path segment (e.g. 'folder/' or 'file.txt')." + ) + + key_segments = [] + seg_strings = s3_key.split("/") + for idx, s in enumerate(seg_strings): + if is_folder or idx < len(seg_strings) - 1: + # Intermediate segments and explicit folders are never files + key_segments.append(KeySegment(idx, s, is_file=False)) + else: + # Last segment: let KeySegment auto-infer from the name + key_segments.append(KeySegment(idx, s)) + return MirrorPath(key_segments, **kwargs) @staticmethod def from_local_path(local_path: Path, mirror_root: Path | None = None, **kwargs) -> MirrorPath: - """Create a MirrorPath from a local path.""" + """Create a MirrorPath from a local path. + + Assumes the local path follows the mirror layout: + mirror_root / bucket_key / s3_key + """ mirror_root = mirror_root or S3MPConfig.mirror_root - s3_key = local_path.relative_to(mirror_root).as_posix() + relative = local_path.relative_to(mirror_root) + parts = relative.parts + if len(parts) < 2: + raise ValueError( + f"Local path '{local_path}' is not a valid mirror path under '{mirror_root}'. " + "Expected layout: mirror_root / bucket_key / s3_key." 
+ ) + bucket_key = parts[0] + s3_key = "/".join(parts[1:]) + if "bucket_key" not in kwargs: + kwargs["bucket_key"] = bucket_key + if "mirror_root" not in kwargs: + kwargs["mirror_root"] = mirror_root return MirrorPath.from_s3_key(s3_key, **kwargs) + def set_bucket_key(self, bucket_key: str) -> None: + """Set bucket key for this MirrorPath.""" + self.bucket_key = bucket_key + + def set_iam_role_arn(self, iam_role_arn: str) -> None: + """Set IAM role ARN for this MirrorPath.""" + self.iam_role_arn = iam_role_arn + def __copy__(self): """Copy.""" - return MirrorPath(self.key_segments, **self.__dict__) + return MirrorPath( + self.key_segments, + mirror_root=self.mirror_root, + bucket_key=self.bucket_key, + iam_role_arn=self.iam_role_arn, + ) def __repr__(self): """Class representation.""" @@ -92,11 +171,11 @@ def exists_in_mirror(self) -> bool: def exists_on_s3(self) -> bool: """Check if file exists on S3.""" - return key_exists_on_s3(self.s3_key) + return key_exists_on_s3(self.s3_key, self.bucket, self.s3_client) def is_file_on_s3(self) -> bool: """Check if is a file on s3.""" - return key_is_file_on_s3(self.s3_key) + return key_is_file_on_s3(self.s3_key, self.bucket, self.s3_client) def is_file_and_exists_on_s3(self) -> bool: """Check if is a file and exists on s3.""" @@ -113,7 +192,7 @@ def download_to_mirror(self, overwrite: bool = False): if not overwrite and self.exists_in_mirror(): self.update_callback_on_skipped_transfer() return - download_key(self.s3_key, self.local_path) + download_key(self.s3_key, self.local_path, self.bucket, self.s3_client) def download_to_mirror_if_not_present(self): """Download to mirror if not present in mirror.""" @@ -124,7 +203,7 @@ def upload_from_mirror(self, overwrite: bool = False): if not overwrite and self.exists_on_s3(): self.update_callback_on_skipped_transfer() return - upload_to_key(self.s3_key, self.local_path) + upload_to_key(self.s3_key, self.local_path, self.bucket, self.s3_client) def upload_from_mirror_if_not_present(self): """Upload from mirror if not present on S3.""" @@ -132,7 +211,12 @@ def upload_from_mirror_if_not_present(self): def trim(self, max_depth) -> MirrorPath: """Trim key from s3 key.""" - return MirrorPath(self.key_segments[:max_depth]) + return MirrorPath( + self.key_segments[:max_depth], + mirror_root=self.mirror_root, + bucket_key=self.bucket_key, + iam_role_arn=self.iam_role_arn, + ) def get_key_segment(self, index: int) -> KeySegment: """Get key segment.""" @@ -143,9 +227,14 @@ def replace_key_segments(self, replace_segments: list[KeySegment]) -> MirrorPath new_segments = self.key_segments[:] for seg in replace_segments: while seg.depth >= len(new_segments): - new_segments.append(KeySegment(len(new_segments) - 1, "")) + new_segments.append(KeySegment(len(new_segments), "")) new_segments[seg.depth] = seg - return MirrorPath(new_segments) + return MirrorPath( + new_segments, + mirror_root=self.mirror_root, + bucket_key=self.bucket_key, + iam_role_arn=self.iam_role_arn, + ) def replace_key_segments_at_relative_depth( self, replace_segments: list[KeySegment] @@ -161,16 +250,34 @@ def get_sibling(self, sibling_name: str) -> MirrorPath: return self.replace_key_segments_at_relative_depth([KeySegment(0, sibling_name)]) def get_child(self, child_name: str) -> MirrorPath: - """Get a child of this file.""" + """Get a child of this folder.""" + if self.is_file: + raise ValueError( + f"Cannot get a child of file key '{self.s3_key}'. " + "get_child is only valid on folder MirrorPaths." 
+ ) return self.replace_key_segments_at_relative_depth([KeySegment(1, child_name)]) def get_children_on_s3(self) -> list[MirrorPath]: """Get all children on s3.""" - return [MirrorPath.from_s3_key(key) for key in s3_list_child_keys(self.s3_key)] + return [ + MirrorPath.from_s3_key( + key, + mirror_root=self.mirror_root, + bucket_key=self.bucket_key, + iam_role_arn=self.iam_role_arn, + ) + for key in s3_list_child_keys(self.s3_key, self.bucket, self.s3_client) + ] def get_parent(self) -> MirrorPath: """Get the parent of this file.""" - return MirrorPath(self.key_segments[:-1]) + return MirrorPath( + self.key_segments[:-1], + mirror_root=self.mirror_root, + bucket_key=self.bucket_key, + iam_role_arn=self.iam_role_arn, + ) def delete_local(self): """Delete local file.""" @@ -178,7 +285,7 @@ def delete_local(self): def delete_s3(self): """Delete s3 file.""" - delete_key_on_s3(self.s3_key) + delete_key_on_s3(self.s3_key, self.bucket, self.s3_client) def delete_all(self): """Delete all files.""" @@ -222,10 +329,9 @@ def save_local( def copy_to_mp_s3_only(self, dest_mp: MirrorPath): """Copy this file from S3 to a destination on S3.""" - bucket_key = S3MPConfig.default_bucket_key - S3MPConfig.s3_client.copy_object( - CopySource={"Bucket": bucket_key, "Key": self.s3_key}, - Bucket=bucket_key, + self.s3_client.copy_object( + CopySource={"Bucket": self.bucket_key, "Key": self.s3_key}, + Bucket=dest_mp.bucket_key, Key=dest_mp.s3_key, ) @@ -272,9 +378,17 @@ def compute_gsd(self, coords: tuple[float, float]) -> float: return ImageMetadata.parse_metadata(self).compute_gsd(coords) -def get_matching_s3_mirror_paths(segments: list[KeySegment]): +def get_matching_s3_mirror_paths( + segments: list[KeySegment], + bucket_key: str | None = None, + iam_role_arn: str | None = None, +): """Get matching S3 mirror paths.""" - return [MirrorPath.from_s3_key(key) for key in get_matching_s3_keys(segments)] + client = S3MPConfig.get_session(iam_role_arn).s3_client if iam_role_arn else None + return [ + MirrorPath.from_s3_key(key, bucket_key=bucket_key, iam_role_arn=iam_role_arn) + for key in get_matching_s3_keys(segments, bucket_key=bucket_key, client=client) + ] def multithread_download_mps_to_mirror(mps: list[MirrorPath], overwrite: bool = False): diff --git a/S3MP/multipart_uploads.py b/S3MP/multipart_uploads.py index ac5041e..a2fbc72 100644 --- a/S3MP/multipart_uploads.py +++ b/S3MP/multipart_uploads.py @@ -6,13 +6,12 @@ from S3MP.global_config import S3MPConfig from S3MP.mirror_path import MirrorPath from S3MP.transfer_configs import MB -from S3MP.types import S3Bucket # TODO prefix optimization def get_mpu(mirror_path: MirrorPath): """Check if a multipart upload has started.""" - bucket: S3Bucket = S3MPConfig.bucket + bucket = mirror_path.bucket mpus = bucket.multipart_uploads.all() for mpu in mpus: if mpu.key == mirror_path.s3_key: diff --git a/S3MP/prefix_queries.py b/S3MP/prefix_queries.py index 2378c56..0c6efa0 100644 --- a/S3MP/prefix_queries.py +++ b/S3MP/prefix_queries.py @@ -5,40 +5,59 @@ from collections.abc import Generator from S3MP.global_config import S3MPConfig +from S3MP.types import S3Client -def get_prefix_paginator(folder_key: str, bucket_key: str | None = None, delimiter: str = "/"): +def get_prefix_paginator( + folder_key: str, + bucket_key: str | None = None, + delimiter: str = "/", + client: S3Client | None = None, +): """Get a paginator for a specified prefix.""" if bucket_key is None: bucket_key = S3MPConfig.default_bucket_key if folder_key != "" and folder_key[-1] != "/": folder_key += "/" 
- s3_client = S3MPConfig.s3_client + s3_client = client or S3MPConfig.s3_client paginator = s3_client.get_paginator("list_objects_v2") return paginator.paginate(Bucket=bucket_key, Prefix=folder_key, Delimiter=delimiter) def get_files_within_folder( - folder_key: str, key_filter: str | None = None + folder_key: str, + bucket_key: str | None = None, + key_filter: str | None = None, + client: S3Client | None = None, ) -> Generator[str, None, None]: """Get files within a folder.""" - for page in get_prefix_paginator(folder_key): + if folder_key != "" and not folder_key.endswith("/"): + folder_key += "/" + for page in get_prefix_paginator(folder_key, bucket_key, client=client): if "Contents" in page: for obj in page["Contents"]: - obj = obj["Key"].replace(folder_key, "") - if key_filter and key_filter not in obj: + key = obj["Key"] + if key == folder_key: continue - yield obj + relative = key.removeprefix(folder_key) + if key_filter and key_filter not in relative: + continue + yield relative def get_folders_within_folder( - folder_key: str, key_filter: str | None = None + folder_key: str, + bucket_key: str | None = None, + key_filter: str | None = None, + client: S3Client | None = None, ) -> Generator[str, None, None]: """Get folders within folder.""" - for page in get_prefix_paginator(folder_key): + if folder_key != "" and not folder_key.endswith("/"): + folder_key += "/" + for page in get_prefix_paginator(folder_key, bucket_key, client=client): if "CommonPrefixes" in page: for obj in page["CommonPrefixes"]: - obj = obj["Prefix"].replace(folder_key, "") + obj = obj["Prefix"].removeprefix(folder_key) if key_filter and key_filter not in obj: continue yield obj diff --git a/S3MP/utils/local_file_utils.py b/S3MP/utils/local_file_utils.py index f201fb9..2feeb9b 100644 --- a/S3MP/utils/local_file_utils.py +++ b/S3MP/utils/local_file_utils.py @@ -4,6 +4,16 @@ from pathlib import Path +def has_file_extension(name: str) -> bool: + """Heuristic: treat a segment as a file if it has an alphabetic extension. + + Handles dotted version-like folder names (e.g. ``v0.1.0``) by requiring + at least one alphabetic character in the suffix. + """ + suffix = Path(name).suffix # e.g. 
".txt", ".0", "" + return bool(suffix) and any(c.isalpha() for c in suffix) + + def get_local_file_size_bytes(path: Path) -> int: """Get the size of a local file in bytes.""" return path.stat().st_size diff --git a/S3MP/utils/s3_utils.py b/S3MP/utils/s3_utils.py index d0b9879..3813e6c 100644 --- a/S3MP/utils/s3_utils.py +++ b/S3MP/utils/s3_utils.py @@ -85,8 +85,10 @@ def download_key( Config=S3MPConfig.transfer_config, # type: ignore[arg-type] ) else: - for child_key in s3_list_child_keys(key, bucket, client): - download_key(child_key, local_path / child_key.replace(key, "")) + folder_key = key if key.endswith("/") else f"{key}/" + for child_key in s3_list_child_keys(folder_key, bucket, client): + relative = child_key.removeprefix(folder_key) + download_key(child_key, local_path / relative, bucket, client) def upload_to_key( @@ -139,7 +141,10 @@ def key_is_file_on_s3( assert client is not None try: - client.head_object(Bucket=bucket.name, Key=key) + response = client.head_object(Bucket=bucket.name, Key=key) + # Zero-byte objects with trailing "/" are folder markers, not files + if key.endswith("/") and response["ContentLength"] == 0: + return False return True except Exception as e: # 404 occurs if the key is a "folder" or does not exist diff --git a/pyproject.toml b/pyproject.toml index 76017f1..546b531 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "S3MP" -version = "0.8.4" +version = "0.9.0" description = "" authors = [ {name = "Joshua Dean", email = "joshua.dean@sentera.com"}, diff --git a/tests/test_global_config.py b/tests/test_global_config.py new file mode 100644 index 0000000..fbc82fe --- /dev/null +++ b/tests/test_global_config.py @@ -0,0 +1,188 @@ +"""Tests for global_config.py — config lifecycle (no S3 required for most).""" + +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from S3MP.global_config import S3MPConfig, _S3MPConfigClass + +# ── helpers ─────────────────────────────────────────────────────────────── + + +@pytest.fixture(autouse=True) +def _reset_config(tmp_path): + """Reset S3MPConfig state before each test.""" + S3MPConfig._default_bucket_key = None + S3MPConfig._mirror_root = None + S3MPConfig._iam_role_arn = None + S3MPConfig._max_pool_connections = None + S3MPConfig._session_map = None + S3MPConfig._boto3_config = None + S3MPConfig.transfer_config = None + S3MPConfig.callback = None + yield + + +# ── default_bucket_key ──────────────────────────────────────────────────── + + +class TestDefaultBucketKey: + """Tests for default_bucket_key get/set.""" + + def test_raises_when_not_set(self): + with pytest.raises(ValueError, match="No default bucket key set"): + _ = S3MPConfig.default_bucket_key + + def test_set_and_get(self): + S3MPConfig.set_default_bucket_key("my-bucket") + assert S3MPConfig.default_bucket_key == "my-bucket" + + def test_overwrite(self): + S3MPConfig.set_default_bucket_key("first") + S3MPConfig.set_default_bucket_key("second") + assert S3MPConfig.default_bucket_key == "second" + + +# ── mirror_root ─────────────────────────────────────────────────────────── + + +class TestMirrorRoot: + """Tests for mirror_root get/set.""" + + def test_default_uses_temp_dir(self): + root = S3MPConfig.mirror_root + assert root.exists() # tempdir should exist + + def test_set_absolute_path(self, tmp_path): + S3MPConfig.set_mirror_root(tmp_path) + assert S3MPConfig.mirror_root == tmp_path + + def test_set_string_path(self, tmp_path): + S3MPConfig.set_mirror_root(str(tmp_path)) + assert 
S3MPConfig.mirror_root == tmp_path + def test_set_relative_path_windows(self): + with patch.object(sys.modules["S3MP.global_config"], "platform", "win32"): + S3MPConfig.set_mirror_root("data/mirror") + assert S3MPConfig.mirror_root == Path("C:\\data\\mirror") + + def test_set_relative_path_linux(self): + with patch.object(sys.modules["S3MP.global_config"], "platform", "linux"): + S3MPConfig.set_mirror_root("data/mirror") + assert S3MPConfig.mirror_root == Path("/data/mirror") + + +# ── max_pool_connections / boto3_config ─────────────────────────────────── + + +class TestMaxPoolConnections: + """Tests for max_pool_connections and boto3_config interaction.""" + + def test_default_is_none(self): + assert S3MPConfig.max_pool_connections is None + + def test_set_max_pool_connections(self): + S3MPConfig.set_max_pool_connections(50) + assert S3MPConfig.max_pool_connections == 50 + + def test_boto3_config_reflects_pool_connections(self): + S3MPConfig.set_max_pool_connections(25) + assert S3MPConfig.max_pool_connections == 25 + # Verify a fresh Config object is created (not None) + assert S3MPConfig.boto3_config is not None + + def test_boto3_config_without_pool_connections(self): + cfg = S3MPConfig.boto3_config + # Should still return a valid Config object + assert cfg is not None + + def test_set_pool_connections_clears_cache(self): + # Access boto3_config to cache it + _ = S3MPConfig.boto3_config + S3MPConfig.set_max_pool_connections(100) + # _boto3_config should have been cleared + assert S3MPConfig._boto3_config is None + + +# ── clear_boto3_cache ───────────────────────────────────────────────────── + + +class TestClearBoto3Cache: + """Tests for clear_boto3_cache.""" + + def test_clears_session_map(self): + # Seed a dummy entry so the clear is actually observable + S3MPConfig._session_map = {None: object()} + S3MPConfig.clear_boto3_cache() + assert S3MPConfig._session_map == {} + + def test_clears_boto3_config(self): + _ = S3MPConfig.boto3_config # Cache it + S3MPConfig.clear_boto3_cache() + assert S3MPConfig._boto3_config is None + + +# ── save_config / load_config roundtrip ─────────────────────────────────── + + +class TestConfigRoundtrip: + """Tests for save_config and load_config.""" + + def test_roundtrip_bucket_and_mirror(self, tmp_path): + config_path = tmp_path / "test_config.ini" + S3MPConfig.set_default_bucket_key("roundtrip-bucket") + S3MPConfig.set_mirror_root(tmp_path / "mirror") + S3MPConfig.save_config(config_path) + + # Reset and reload + S3MPConfig._default_bucket_key = None + S3MPConfig._mirror_root = None + S3MPConfig.load_config(config_path) + + assert S3MPConfig.default_bucket_key == "roundtrip-bucket" + assert S3MPConfig._mirror_root == Path(str(tmp_path / "mirror")) + + def test_roundtrip_max_pool_connections(self, tmp_path): + config_path = tmp_path / "test_config.ini" + S3MPConfig.set_default_bucket_key("bucket") + S3MPConfig.set_max_pool_connections(42) + S3MPConfig.save_config(config_path) + + S3MPConfig._max_pool_connections = None + S3MPConfig.load_config(config_path) + assert S3MPConfig.max_pool_connections == 42 + + def test_load_empty_config(self, tmp_path): + config_path = tmp_path / "empty_config.ini" + config_path.write_text("") + # Should not raise + S3MPConfig.load_config(config_path) + + def test_load_missing_file_is_noop(self, tmp_path): + config_path = tmp_path / "nonexistent.ini" + # ConfigParser.read is a no-op for missing files + S3MPConfig.load_config(config_path) + + def test_save_only_set_fields(self, tmp_path): + config_path = tmp_path / "partial.ini" + S3MPConfig.set_default_bucket_key("only-bucket") + 
S3MPConfig.save_config(config_path) + + content = config_path.read_text() + assert "only-bucket" in content + # mirror_root and iam_role_arn should not appear + assert "mirror_root" not in content + assert "iam_role_arn" not in content + + +# ── Singleton behavior ──────────────────────────────────────────────────── + + +class TestSingleton: + """Tests for Singleton metaclass.""" + + def test_same_instance(self): + a = _S3MPConfigClass() + b = _S3MPConfigClass() + assert a is b diff --git a/tests/test_keys.py b/tests/test_keys.py new file mode 100644 index 0000000..84e7e52 --- /dev/null +++ b/tests/test_keys.py @@ -0,0 +1,277 @@ +"""Tests for keys.py — pure unit tests (no S3 required).""" + +from copy import copy + +from S3MP.keys import ( + KeySegment, + build_s3_key, + get_arbitrary_keys_from_names, + get_filter_name, + get_segments_from_key, + replace_key_segments, + replace_key_segments_at_relative_depth, +) + +# ── KeySegment ──────────────────────────────────────────────────────────── + + +class TestKeySegment: + """Tests for KeySegment dataclass.""" + + def test_basic_construction(self): + seg = KeySegment(depth=0, name="folder") + assert seg.depth == 0 + assert seg.name == "folder" + assert seg.is_file is False + assert seg.incomplete_name is None + + def test_file_segment(self): + seg = KeySegment(depth=2, name="image.png", is_file=True) + assert seg.is_file is True + assert seg.name == "image.png" + + def test_incomplete_name(self): + seg = KeySegment(depth=1, incomplete_name=".txt") + assert seg.name is None + assert seg.incomplete_name == ".txt" + + def test_copy(self): + seg = KeySegment(depth=0, name="folder", is_file=True, incomplete_name="part") + seg_copy = copy(seg) + assert seg_copy.depth == seg.depth + assert seg_copy.name == seg.name + assert seg_copy.is_file == seg.is_file + assert seg_copy.incomplete_name == seg.incomplete_name + # Verify independence + seg_copy.name = "changed" + assert seg.name == "folder" + + def test_copy_method(self): + seg = KeySegment(depth=0, name="test") + seg_copy = seg.copy() + seg_copy.name = "modified" + assert seg.name == "test" + + def test_call_with_string(self): + seg = KeySegment(depth=0) + result = seg("my_folder") + assert result.name == "my_folder" + # Should return a new instance (copy) + assert result is not seg + + def test_call_with_kwargs(self): + seg = KeySegment(depth=0, name="original") + result = seg(is_file=True, name="new_name") + assert result.name == "new_name" + assert result.is_file is True + assert seg.name == "original" # Original unchanged + + def test_call_with_int(self): + seg = KeySegment(depth=0) + result = seg(42) + assert result.name == "42" + + def test_call_with_enum(self): + from enum import Enum + + class Color(Enum): + RED = "red" + + seg = KeySegment(depth=0) + result = seg(Color.RED) + assert result.name == "red" + + def test_repr(self): + seg = KeySegment(depth=1, name="test", is_file=True, incomplete_name="part") + r = repr(seg) + assert "KeySegment" in r + assert "depth=1" in r + assert "name=test" in r + assert "is_file=True" in r + assert "incomplete_name=part" in r + + +# ── get_arbitrary_keys_from_names ───────────────────────────────────────── + + +class TestGetArbitraryKeysFromNames: + """Tests for get_arbitrary_keys_from_names.""" + + def test_basic(self): + segments = get_arbitrary_keys_from_names(["a", "b", "c"]) + assert len(segments) == 3 + assert segments[0].depth == 0 + assert segments[0].name == "a" + assert segments[2].depth == 2 + assert segments[2].name == "c" + + def 
test_all_are_folders_by_default(self): + segments = get_arbitrary_keys_from_names(["x", "y"]) + assert all(not seg.is_file for seg in segments) + + def test_empty_list(self): + segments = get_arbitrary_keys_from_names([]) + assert segments == [] + + def test_single_name(self): + segments = get_arbitrary_keys_from_names(["only"]) + assert len(segments) == 1 + assert segments[0].name == "only" + assert segments[0].depth == 0 + + +# ── get_segments_from_key ───────────────────────────────────────────────── + + +class TestGetSegmentsFromKey: + """Tests for get_segments_from_key.""" + + def test_simple_key(self): + segments = get_segments_from_key("a/b/c") + assert len(segments) == 3 + assert segments[0].name == "a" + assert segments[1].name == "b" + assert segments[2].name == "c" + assert segments[2].depth == 2 + + def test_single_segment(self): + segments = get_segments_from_key("root") + assert len(segments) == 1 + assert segments[0].name == "root" + + def test_key_with_trailing_slash(self): + segments = get_segments_from_key("folder/subfolder/") + # split produces an empty trailing string + assert segments[-1].name == "" + + def test_key_with_extension(self): + segments = get_segments_from_key("data/file.csv") + assert segments[-1].name == "file.csv" + + def test_depths_are_sequential(self): + segments = get_segments_from_key("a/b/c/d/e") + assert [s.depth for s in segments] == [0, 1, 2, 3, 4] + + +# ── build_s3_key ────────────────────────────────────────────────────────── + + +class TestBuildS3Key: + """Tests for build_s3_key.""" + + def test_contiguous_segments(self): + segments = [ + KeySegment(0, "a"), + KeySegment(1, "b"), + KeySegment(2, "c.txt", is_file=True), + ] + key, depth = build_s3_key(segments) + assert key == "a/b/c.txt" + assert depth == 3 + + def test_gap_in_depths(self): + segments = [ + KeySegment(0, "a"), + KeySegment(2, "c"), # depth 1 is missing + ] + key, depth = build_s3_key(segments) + # Should stop at the first missing depth + assert depth == 1 + assert key == "a" + + def test_single_segment(self): + segments = [KeySegment(0, "root")] + key, depth = build_s3_key(segments) + assert key == "root" + assert depth == 1 + + def test_unordered_input_is_sorted(self): + segments = [ + KeySegment(2, "c"), + KeySegment(0, "a"), + KeySegment(1, "b"), + ] + key, depth = build_s3_key(segments) + assert key == "a/b/c" + + def test_roundtrip_with_get_segments(self): + """Build a key from segments, then parse it back.""" + original = "alpha/beta/gamma" + segments = get_segments_from_key(original) + key, depth = build_s3_key(segments) + assert key == original + assert depth == 3 + + +# ── replace_key_segments ────────────────────────────────────────────────── + + +class TestReplaceKeySegments: + """Tests for replace_key_segments (module-level function).""" + + def test_replace_single_segment(self): + result = replace_key_segments("a/b/c", [KeySegment(1, "replaced")]) + assert result == "a/replaced/c" + + def test_replace_last_segment(self): + result = replace_key_segments("a/b/c", [KeySegment(2, "new")]) + assert result == "a/b/new" + + def test_replace_first_segment(self): + result = replace_key_segments("a/b/c", [KeySegment(0, "root")]) + assert result == "root/b/c" + + def test_replace_multiple_segments(self): + result = replace_key_segments( + "a/b/c/d", + [KeySegment(1, "x"), KeySegment(3, "z")], + ) + assert result == "a/x/c/z" + + def test_with_max_len(self): + result = replace_key_segments("a/b/c/d", [KeySegment(1, "x")], max_len=3) + assert result == "a/x/c" + + +# ── 
replace_key_segments_at_relative_depth ──────────────────────────────── + + +class TestReplaceKeySegmentsAtRelativeDepth: + """Tests for replace_key_segments_at_relative_depth.""" + + def test_replace_at_deepest(self): + # depth=0 means deepest (last) segment + result = replace_key_segments_at_relative_depth("a/b/c", [KeySegment(0, "sibling")]) + assert result == "a/b/sibling" + + def test_append_child(self): + # depth=1 means one deeper than deepest + result = replace_key_segments_at_relative_depth("a/b", [KeySegment(1, "child")]) + assert result == "a/b/child" + + +# ── get_filter_name ─────────────────────────────────────────────────────── + + +class TestGetFilterName: + """Tests for get_filter_name.""" + + def test_match_by_depth_with_name(self): + segments = [KeySegment(0, "a"), KeySegment(1, "b"), KeySegment(2, "c")] + assert get_filter_name(segments, 1) == "b" + + def test_match_by_depth_with_incomplete_name(self): + segments = [KeySegment(0, "a"), KeySegment(1, incomplete_name=".txt")] + assert get_filter_name(segments, 1) == ".txt" + + def test_no_match_returns_none(self): + segments = [KeySegment(0, "a")] + assert get_filter_name(segments, 5) is None + + def test_incomplete_name_preferred_when_name_is_none(self): + seg = KeySegment(depth=0, name=None, incomplete_name="partial") + assert get_filter_name([seg], 0) == "partial" + + def test_name_preferred_over_incomplete_name(self): + seg = KeySegment(depth=0, name="exact", incomplete_name="partial") + assert get_filter_name([seg], 0) == "exact" diff --git a/tests/test_local_file_utils.py b/tests/test_local_file_utils.py new file mode 100644 index 0000000..df99c35 --- /dev/null +++ b/tests/test_local_file_utils.py @@ -0,0 +1,86 @@ +"""Tests for local_file_utils.py — pure unit tests (no S3 required).""" + +import json + +import pytest + +from S3MP.utils.local_file_utils import ( + delete_local_path, + get_local_file_size_bytes, + load_json, + save_json, +) + + +class TestGetLocalFileSizeBytes: + """Tests for get_local_file_size_bytes.""" + + def test_non_empty_file(self, tmp_path): + f = tmp_path / "data.bin" + f.write_bytes(b"hello world") + assert get_local_file_size_bytes(f) == 11 + + def test_empty_file(self, tmp_path): + f = tmp_path / "empty.txt" + f.write_bytes(b"") + assert get_local_file_size_bytes(f) == 0 + + def test_nonexistent_file_raises(self, tmp_path): + with pytest.raises((FileNotFoundError, OSError)): + get_local_file_size_bytes(tmp_path / "nope.txt") + + +class TestDeleteLocalPath: + """Tests for delete_local_path.""" + + def test_delete_file(self, tmp_path): + f = tmp_path / "to_delete.txt" + f.write_text("x") + delete_local_path(f) + assert not f.exists() + + def test_delete_empty_directory(self, tmp_path): + d = tmp_path / "empty_dir" + d.mkdir() + delete_local_path(d) + assert not d.exists() + + def test_delete_directory_with_files(self, tmp_path): + d = tmp_path / "dir_with_files" + d.mkdir() + (d / "child.txt").write_text("data") + delete_local_path(d) + assert not d.exists() + + def test_delete_nonexistent_is_noop(self, tmp_path): + # Should not raise + delete_local_path(tmp_path / "nonexistent") + + +class TestJsonRoundtrip: + """Tests for load_json / save_json.""" + + def test_roundtrip(self, tmp_path): + f = tmp_path / "data.json" + data = {"key": "value", "nested": {"a": 1}} + save_json(str(f), data) + loaded = load_json(str(f)) + assert loaded == data + + def test_save_with_indent(self, tmp_path): + f = tmp_path / "indented.json" + save_json(str(f), {"x": 1}, indent=2) + raw = f.read_text() + parsed = 
json.loads(raw) + assert parsed == {"x": 1} + # Check that indentation is present (not compact) + assert "\n" in raw + + def test_load_nonexistent_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + load_json(str(tmp_path / "missing.json")) + + def test_empty_dict(self, tmp_path): + f = tmp_path / "empty.json" + save_json(str(f), {}) + assert load_json(str(f)) == {} diff --git a/tests/test_mirror_path.py b/tests/test_mirror_path.py new file mode 100644 index 0000000..cdca9ed --- /dev/null +++ b/tests/test_mirror_path.py @@ -0,0 +1,677 @@ +"""Tests for MirrorPath class.""" + +import contextlib +from copy import copy +from pathlib import Path + +import boto3 +import pytest + +from S3MP.global_config import S3MPConfig +from S3MP.mirror_path import MirrorPath +from S3MP.types import S3Client + +TESTING_BUCKET_NAME = "s3mp-testing" + + +@pytest.fixture(autouse=True) +def _configure_s3mp(tmp_path): + """Configure S3MPConfig for testing, restoring previous values on teardown.""" + prev_bucket_key = S3MPConfig._default_bucket_key + prev_mirror_root = S3MPConfig._mirror_root + + S3MPConfig.set_default_bucket_key(TESTING_BUCKET_NAME) + S3MPConfig.set_mirror_root(tmp_path) + S3MPConfig.clear_boto3_cache() + yield + + if prev_bucket_key is not None: + S3MPConfig.set_default_bucket_key(prev_bucket_key) + if prev_mirror_root is not None: + S3MPConfig.set_mirror_root(prev_mirror_root) + S3MPConfig.clear_boto3_cache() + + +@pytest.fixture() +def s3_client(): + """Create an S3 client and ensure the test bucket exists.""" + client: S3Client = boto3.client("s3") + with contextlib.suppress(client.exceptions.BucketAlreadyOwnedByYou): + client.create_bucket(Bucket=TESTING_BUCKET_NAME) + return client + + +# ── from_s3_key construction ────────────────────────────────────────────── + + +class TestFromS3Key: + """Tests for MirrorPath.from_s3_key.""" + + def test_simple_file_key(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + assert mp.s3_key == "folder/file.txt" + assert mp.key_segments[-1].is_file is True + + def test_folder_key_with_trailing_slash(self): + mp = MirrorPath.from_s3_key("folder/subfolder/") + assert mp.s3_key == "folder/subfolder/" + assert mp.key_segments[-1].is_file is False + + def test_s3_url_parsing(self): + mp = MirrorPath.from_s3_key("s3://my-bucket/folder/file.txt") + assert mp.s3_key == "folder/file.txt" + assert mp.bucket_key == "my-bucket" + + def test_s3_url_bucket_overridden_by_kwarg(self): + mp = MirrorPath.from_s3_key("s3://url-bucket/folder/file.txt", bucket_key="kwarg-bucket") + assert mp.bucket_key == "kwarg-bucket" + + def test_s3_url_bucket_extracted_when_no_kwarg(self): + mp = MirrorPath.from_s3_key("s3://url-bucket/folder/file.txt") + assert mp.bucket_key == "url-bucket" + + def test_iam_role_arn_passed_through(self): + mp = MirrorPath.from_s3_key( + "folder/file.txt", iam_role_arn="arn:aws:iam::123:role/TestRole" + ) + assert mp.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + def test_key_segments_depth(self): + mp = MirrorPath.from_s3_key("a/b/c/d.txt") + assert len(mp.key_segments) == 4 + assert mp.key_segments[0].name == "a" + assert mp.key_segments[3].name == "d.txt" + assert mp.key_segments[3].depth == 3 + + def test_no_extension_treated_as_folder(self): + """A key without trailing / and no file extension is a folder.""" + mp = MirrorPath.from_s3_key("folder/subfolder") + assert mp.key_segments[-1].is_file is False + assert mp.s3_key == "folder/subfolder/" + + def test_dotted_version_folder_not_treated_as_file(self): + """Dotted version 
names like v0.1.0 are folders, not files.""" + mp = MirrorPath.from_s3_key("models/v0.1.0") + assert mp.key_segments[-1].is_file is False + assert mp.s3_key == "models/v0.1.0/" + + def test_alpha_extension_treated_as_file(self): + """Standard file extensions are recognised as files.""" + for key in ["data/file.csv", "images/photo.jpg", "archive.tar.gz"]: + mp = MirrorPath.from_s3_key(key) + assert mp.key_segments[-1].is_file is True, f"{key} should be a file" + + @pytest.mark.parametrize( + ("s3_key", "expected_is_file", "expected_s3_key"), + [ + # Files — alpha extension, no trailing / + ("folder/file.txt", True, "folder/file.txt"), + ("data.csv", True, "data.csv"), + ("a/b/archive.tar.gz", True, "a/b/archive.tar.gz"), + ("folder/.hidden.conf", True, "folder/.hidden.conf"), + # Files without common extensions but still alpha + ("logs/output.log", True, "logs/output.log"), + ("models/weights.pt", True, "models/weights.pt"), + # Folders — trailing / + ("folder/subfolder/", False, "folder/subfolder/"), + ("folder/file.txt/", False, "folder/file.txt/"), + # Folders — no extension, no trailing / + ("folder/subfolder", False, "folder/subfolder/"), + ("data", False, "data/"), + # Folders — numeric-only extension (dotted versions) + ("models/v0.1.0", False, "models/v0.1.0/"), + ("releases/2.0", False, "releases/2.0/"), + ("checkpoints/step.100", False, "checkpoints/step.100/"), + # Dotfiles — bare dotfile has no suffix, treated as folder + ("config/.env", False, "config/.env/"), + # Extensionless files can't be distinguished from folders — treated as folders + ("folder/LICENSE", False, "folder/LICENSE/"), + ("bin/my_script", False, "bin/my_script/"), + ], + ids=lambda v: str(v).replace("/", "|"), + ) + def test_file_vs_folder_classification(self, s3_key, expected_is_file, expected_s3_key): + """Parametrized: various key shapes produce the correct is_file and s3_key.""" + mp = MirrorPath.from_s3_key(s3_key) + assert mp.key_segments[-1].is_file is expected_is_file, f"is_file mismatch for {s3_key!r}" + assert mp.s3_key == expected_s3_key, f"s3_key mismatch for {s3_key!r}" + + +# ── s3_key property ────────────────────────────────────────────────────── + + +class TestS3Key: + """Tests for MirrorPath.s3_key property.""" + + def test_file_key_has_no_trailing_slash(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + assert not mp.s3_key.endswith("/") + + def test_folder_key_has_trailing_slash(self): + mp = MirrorPath.from_s3_key("folder/subfolder/") + assert mp.s3_key.endswith("/") + + +# ── s3_url property ────────────────────────────────────────────────────── + + +class TestS3Url: + """Tests for MirrorPath.s3_url property.""" + + def test_s3_url_format(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + assert mp.s3_url == f"s3://{TESTING_BUCKET_NAME}/folder/file.txt" + + +# ── local_path property ────────────────────────────────────────────────── + + +class TestLocalPath: + """Tests for MirrorPath.local_path property.""" + + def test_local_path_includes_bucket_key(self, tmp_path): + mp = MirrorPath.from_s3_key("folder/file.txt") + expected = tmp_path / TESTING_BUCKET_NAME / "folder" / "file.txt" + assert mp.local_path == expected + + def test_override_local_path(self, tmp_path): + mp = MirrorPath.from_s3_key("folder/file.txt") + override = tmp_path / "custom" / "path.txt" + mp.override_local_path(override) + assert mp.local_path == override + + +# ── __copy__ ───────────────────────────────────────────────────────────── + + +class TestCopy: + """Tests for MirrorPath.__copy__.""" + + def 
test_copy_preserves_fields(self): + mp = MirrorPath.from_s3_key( + "folder/file.txt", + bucket_key="test-bucket", + iam_role_arn="arn:aws:iam::123:role/TestRole", + ) + mp_copy = copy(mp) + assert mp_copy.s3_key == mp.s3_key + assert mp_copy.bucket_key == "test-bucket" + assert mp_copy.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + def test_copy_does_not_leak_local_path_override(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + mp.override_local_path(Path("/tmp/override.txt")) + mp_copy = copy(mp) + # The copy should NOT have the override + assert mp_copy._local_path_override is None + + def test_copy_key_segments_are_independent(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + mp_copy = copy(mp) + mp_copy.key_segments[0].name = "changed" + assert mp.key_segments[0].name == "folder" + + +# ── Navigation: sibling, child, parent, trim ───────────────────────────── + + +class TestNavigation: + """Tests for sibling, child, parent, and trim methods.""" + + def test_get_child(self): + mp = MirrorPath.from_s3_key("folder/subfolder/") + child = mp.get_child("file.txt") + assert "file.txt" in child.s3_key + + def test_get_child_raises_on_file(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + with pytest.raises(ValueError, match="Cannot get a child of file key"): + mp.get_child("sub.txt") + + def test_get_sibling(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + sibling = mp.get_sibling("other.txt") + assert "other.txt" in sibling.s3_key + assert "folder" in sibling.s3_key + + def test_get_parent(self): + mp = MirrorPath.from_s3_key("a/b/c.txt") + parent = mp.get_parent() + assert len(parent.key_segments) == 2 + assert parent.key_segments[-1].name == "b" + + def test_trim(self): + mp = MirrorPath.from_s3_key("a/b/c/d.txt") + trimmed = mp.trim(2) + assert len(trimmed.key_segments) == 2 + assert trimmed.key_segments[-1].name == "b" + + def test_child_propagates_bucket_key(self): + mp = MirrorPath.from_s3_key("folder/subfolder/", bucket_key="custom-bucket") + child = mp.get_child("file.txt") + assert child.bucket_key == "custom-bucket" + + def test_child_propagates_iam_role_arn(self): + mp = MirrorPath.from_s3_key( + "folder/subfolder/", iam_role_arn="arn:aws:iam::123:role/TestRole" + ) + child = mp.get_child("file.txt") + assert child.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + def test_sibling_propagates_bucket_key(self): + mp = MirrorPath.from_s3_key("folder/file.txt", bucket_key="custom-bucket") + sibling = mp.get_sibling("other.txt") + assert sibling.bucket_key == "custom-bucket" + + def test_parent_propagates_iam_role_arn(self): + mp = MirrorPath.from_s3_key("a/b/c.txt", iam_role_arn="arn:aws:iam::123:role/TestRole") + parent = mp.get_parent() + assert parent.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + def test_trim_propagates_bucket_key_and_iam_role_arn(self): + mp = MirrorPath.from_s3_key( + "a/b/c.txt", + bucket_key="custom-bucket", + iam_role_arn="arn:aws:iam::123:role/TestRole", + ) + trimmed = mp.trim(2) + assert trimmed.bucket_key == "custom-bucket" + assert trimmed.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + +# ── mirror_root propagation through navigation ─────────────────────────── + + +class TestMirrorRootPropagation: + """Regression: all navigation methods must preserve a custom mirror_root.""" + + def test_trim_preserves_mirror_root(self, tmp_path): + custom = tmp_path / "custom" + mp = MirrorPath.from_s3_key("a/b/c.txt", mirror_root=custom) + assert mp.trim(2).mirror_root == custom + + def 
+
+
+# ── mirror_root propagation through navigation ───────────────────────────
+
+
+class TestMirrorRootPropagation:
+    """Regression: all navigation methods must preserve a custom mirror_root."""
+
+    def test_trim_preserves_mirror_root(self, tmp_path):
+        custom = tmp_path / "custom"
+        mp = MirrorPath.from_s3_key("a/b/c.txt", mirror_root=custom)
+        assert mp.trim(2).mirror_root == custom
+
+    def test_get_parent_preserves_mirror_root(self, tmp_path):
+        custom = tmp_path / "custom"
+        mp = MirrorPath.from_s3_key("a/b/c.txt", mirror_root=custom)
+        assert mp.get_parent().mirror_root == custom
+
+    def test_get_sibling_preserves_mirror_root(self, tmp_path):
+        custom = tmp_path / "custom"
+        mp = MirrorPath.from_s3_key("a/b/c.txt", mirror_root=custom)
+        assert mp.get_sibling("d.txt").mirror_root == custom
+
+    def test_get_child_preserves_mirror_root(self, tmp_path):
+        custom = tmp_path / "custom"
+        mp = MirrorPath.from_s3_key("a/b/", mirror_root=custom)
+        assert mp.get_child("file.txt").mirror_root == custom
+
+    def test_replace_key_segments_preserves_mirror_root(self, tmp_path):
+        from S3MP.keys import KeySegment
+
+        custom = tmp_path / "custom"
+        mp = MirrorPath.from_s3_key("a/b/c.txt", mirror_root=custom)
+        replaced = mp.replace_key_segments([KeySegment(1, "x")])
+        assert replaced.mirror_root == custom
+
+    def test_get_children_on_s3_preserves_mirror_root(self, s3_client, tmp_path):
+        custom = tmp_path / "custom"
+        s3_client.put_object(Bucket=TESTING_BUCKET_NAME, Key="mr_prop/child.txt", Body=b"c")
+        mp = MirrorPath.from_s3_key("mr_prop/", mirror_root=custom)
+        children = mp.get_children_on_s3()
+        assert len(children) > 0
+        for child in children:
+            assert child.mirror_root == custom
+
+
+# ── from_local_path ──────────────────────────────────────────────────────
+
+
+class TestFromLocalPath:
+    """Tests for MirrorPath.from_local_path."""
+
+    def test_basic_roundtrip(self, tmp_path):
+        local = tmp_path / TESTING_BUCKET_NAME / "folder" / "file.txt"
+        mp = MirrorPath.from_local_path(local, mirror_root=tmp_path)
+        assert mp.s3_key == "folder/file.txt"
+        assert mp.bucket_key == TESTING_BUCKET_NAME
+
+    def test_preserves_kwargs(self, tmp_path):
+        local = tmp_path / "bucket" / "a" / "b.txt"
+        mp = MirrorPath.from_local_path(
+            local,
+            mirror_root=tmp_path,
+            bucket_key="custom-bucket",
+            iam_role_arn="arn:aws:iam::123:role/TestRole",
+        )
+        assert mp.bucket_key == "custom-bucket"
+        assert mp.iam_role_arn == "arn:aws:iam::123:role/TestRole"
+
+    def test_bucket_key_kwarg_overrides_path_bucket(self, tmp_path):
+        local = tmp_path / "path-bucket" / "folder" / "file.txt"
+        mp = MirrorPath.from_local_path(local, mirror_root=tmp_path, bucket_key="override-bucket")
+        assert mp.bucket_key == "override-bucket"
+        assert mp.s3_key == "folder/file.txt"
+
+    def test_local_path_roundtrip(self, tmp_path):
+        """from_s3_key -> local_path -> from_local_path should produce the same s3_key and bucket_key."""
+        original = MirrorPath.from_s3_key("data/images/photo.jpg")
+        reconstructed = MirrorPath.from_local_path(original.local_path, mirror_root=tmp_path)
+        assert reconstructed.s3_key == "data/images/photo.jpg"
+        assert reconstructed.bucket_key == TESTING_BUCKET_NAME
+
+    def test_mirror_root_propagated_to_local_path(self, tmp_path):
+        """from_local_path with explicit mirror_root should produce a local_path under that root."""
+        custom_root = tmp_path / "custom_mirror"
+        local = custom_root / "my-bucket" / "key" / "file.txt"
+        mp = MirrorPath.from_local_path(local, mirror_root=custom_root)
+        assert mp.local_path == custom_root / "my-bucket" / "key" / "file.txt"
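+
+    # Every roundtrip above relies on one layout rule: a mirrored file lives at
+    # mirror_root / bucket_key / s3_key. A minimal sketch of that mapping,
+    # assuming it is the only convention in play (helper is hypothetical):
+    #
+    #     def _expected_local(mirror_root: Path, bucket_key: str, s3_key: str) -> Path:
+    #         return mirror_root / bucket_key / Path(s3_key)
+    #
+    #     _expected_local(Path("/mirror"), "s3mp-testing", "a/b.txt")
+    #     # -> Path("/mirror/s3mp-testing/a/b.txt")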
MirrorPath.from_s3_key("folder/real.txt") + mp.local_path.parent.mkdir(parents=True, exist_ok=True) + mp.local_path.write_text("data") + assert mp.exists_in_mirror() + + +# ── replace_key_segments / get_key_segment ──────────────────────────────── + + +class TestMirrorPathKeySegmentOps: + """Tests for replace_key_segments and get_key_segment on MirrorPath.""" + + def test_get_key_segment(self): + mp = MirrorPath.from_s3_key("a/b/c.txt") + assert mp.get_key_segment(0).name == "a" + assert mp.get_key_segment(1).name == "b" + assert mp.get_key_segment(2).name == "c.txt" + + def test_replace_key_segments(self): + mp = MirrorPath.from_s3_key("a/b/c.txt") + from S3MP.keys import KeySegment + + replaced = mp.replace_key_segments([KeySegment(1, "replaced")]) + assert replaced.key_segments[1].name == "replaced" + # Other segments unchanged + assert replaced.key_segments[0].name == "a" + + def test_replace_key_segments_propagates_context(self): + mp = MirrorPath.from_s3_key( + "a/b/c.txt", + bucket_key="custom-bucket", + iam_role_arn="arn:aws:iam::123:role/TestRole", + ) + from S3MP.keys import KeySegment + + replaced = mp.replace_key_segments([KeySegment(1, "new")]) + assert replaced.bucket_key == "custom-bucket" + assert replaced.iam_role_arn == "arn:aws:iam::123:role/TestRole" + + def test_replace_key_segments_at_relative_depth(self): + mp = MirrorPath.from_s3_key("a/b/c.txt") + from S3MP.keys import KeySegment + + # depth=0 → replaces deepest (c.txt), same as sibling + replaced = mp.replace_key_segments_at_relative_depth([KeySegment(0, "sibling.txt")]) + assert replaced.key_segments[-1].name == "sibling.txt" + + def test_replace_key_segments_placeholder_depths_match_indices(self): + """Regression: placeholders inserted when extending must have depth == list index.""" + from S3MP.keys import KeySegment + + mp = MirrorPath.from_s3_key("a/b.txt") # 2 segments (depth 0, 1) + # Request a replacement at depth 4 — forces 2 placeholder segments at indices 2, 3 + replaced = mp.replace_key_segments([KeySegment(4, "deep.txt")]) + for idx, seg in enumerate(replaced.key_segments): + assert seg.depth == idx, f"segment at index {idx} has depth {seg.depth}" + + +# ── __repr__ ────────────────────────────────────────────────────────────── + + +class TestRepr: + """Tests for MirrorPath.__repr__.""" + + def test_repr_contains_key(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + r = repr(mp) + assert "MirrorPath" in r + assert "folder/file.txt" in r + + +# ── set_bucket_key / set_iam_role_arn ───────────────────────────────────── + + +class TestMutators: + """Tests for set_bucket_key and set_iam_role_arn.""" + + def test_set_bucket_key(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + assert mp.bucket_key == TESTING_BUCKET_NAME + mp.set_bucket_key("new-bucket") + assert mp.bucket_key == "new-bucket" + + def test_set_iam_role_arn(self): + mp = MirrorPath.from_s3_key("folder/file.txt") + assert mp.iam_role_arn is None + mp.set_iam_role_arn("arn:aws:iam::123:role/New") + assert mp.iam_role_arn == "arn:aws:iam::123:role/New" + + +# ── delete_local ────────────────────────────────────────────────────────── + + +class TestDeleteLocal: + """Tests for delete_local (pure filesystem, no S3).""" + + def test_delete_local_file(self, tmp_path): + mp = MirrorPath.from_s3_key("folder/deleteme.txt") + mp.local_path.parent.mkdir(parents=True, exist_ok=True) + mp.local_path.write_text("to delete") + assert mp.exists_in_mirror() + mp.delete_local() + assert not mp.exists_in_mirror() + + def 
+
+
+# ── __repr__ ─────────────────────────────────────────────────────────────
+
+
+class TestRepr:
+    """Tests for MirrorPath.__repr__."""
+
+    def test_repr_contains_key(self):
+        mp = MirrorPath.from_s3_key("folder/file.txt")
+        r = repr(mp)
+        assert "MirrorPath" in r
+        assert "folder/file.txt" in r
+
+
+# ── set_bucket_key / set_iam_role_arn ────────────────────────────────────
+
+
+class TestMutators:
+    """Tests for set_bucket_key and set_iam_role_arn."""
+
+    def test_set_bucket_key(self):
+        mp = MirrorPath.from_s3_key("folder/file.txt")
+        assert mp.bucket_key == TESTING_BUCKET_NAME
+        mp.set_bucket_key("new-bucket")
+        assert mp.bucket_key == "new-bucket"
+
+    def test_set_iam_role_arn(self):
+        mp = MirrorPath.from_s3_key("folder/file.txt")
+        assert mp.iam_role_arn is None
+        mp.set_iam_role_arn("arn:aws:iam::123:role/New")
+        assert mp.iam_role_arn == "arn:aws:iam::123:role/New"
+
+
+# ── delete_local ─────────────────────────────────────────────────────────
+
+
+class TestDeleteLocal:
+    """Tests for delete_local (pure filesystem, no S3)."""
+
+    def test_delete_local_file(self, tmp_path):
+        mp = MirrorPath.from_s3_key("folder/deleteme.txt")
+        mp.local_path.parent.mkdir(parents=True, exist_ok=True)
+        mp.local_path.write_text("to delete")
+        assert mp.exists_in_mirror()
+        mp.delete_local()
+        assert not mp.exists_in_mirror()
+
+    def test_delete_local_nonexistent_is_noop(self):
+        mp = MirrorPath.from_s3_key("folder/no_such_file.txt")
+        # Should not raise
+        mp.delete_local()
+
+
+# ── S3 operations (require running S3-compatible endpoint) ───────────────
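+#
+# These tests assume an S3-compatible endpoint is reachable with the ambient
+# credentials (for example moto's server mode or a local MinIO, selected via
+# the AWS_ENDPOINT_URL environment variable). That setup is an assumption of
+# this suite, not something S3MP itself provides.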
+
+
+class TestS3Operations:
+    """Tests for S3-backed MirrorPath operations."""
+
+    def _upload_test_file(self, s3_client, key: str, data: bytes = b"test"):
+        s3_client.put_object(Bucket=TESTING_BUCKET_NAME, Key=key, Body=data)
+
+    def test_exists_on_s3(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/exists.txt")
+        mp = MirrorPath.from_s3_key("mp_test/exists.txt")
+        assert mp.exists_on_s3()
+
+    def test_not_exists_on_s3(self, s3_client):
+        mp = MirrorPath.from_s3_key("mp_test/nonexistent_file.txt")
+        assert not mp.exists_on_s3()
+
+    def test_is_file_on_s3(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/isfile.txt")
+        mp = MirrorPath.from_s3_key("mp_test/isfile.txt")
+        assert mp.is_file_on_s3()
+
+    def test_download_to_mirror(self, s3_client, tmp_path):
+        self._upload_test_file(s3_client, "mp_test/download.txt", b"hello")
+        mp = MirrorPath.from_s3_key("mp_test/download.txt")
+        mp.download_to_mirror()
+        assert mp.local_path.exists()
+        assert mp.local_path.read_bytes() == b"hello"
+
+    def test_upload_from_mirror(self, s3_client, tmp_path):
+        mp = MirrorPath.from_s3_key("mp_test/upload.txt")
+        mp.local_path.parent.mkdir(parents=True, exist_ok=True)
+        mp.local_path.write_bytes(b"upload_data")
+        mp.upload_from_mirror(overwrite=True)
+        assert mp.exists_on_s3()
+
+    def test_download_skips_if_present(self, s3_client, tmp_path):
+        self._upload_test_file(s3_client, "mp_test/skip.txt", b"original")
+        mp = MirrorPath.from_s3_key("mp_test/skip.txt")
+        mp.local_path.parent.mkdir(parents=True, exist_ok=True)
+        mp.local_path.write_bytes(b"local_version")
+        mp.download_to_mirror(overwrite=False)
+        # Should not overwrite existing local file
+        assert mp.local_path.read_bytes() == b"local_version"
+
+    def test_download_overwrites_when_requested(self, s3_client, tmp_path):
+        self._upload_test_file(s3_client, "mp_test/overwrite.txt", b"s3_version")
+        mp = MirrorPath.from_s3_key("mp_test/overwrite.txt")
+        mp.local_path.parent.mkdir(parents=True, exist_ok=True)
+        mp.local_path.write_bytes(b"local_version")
+        mp.download_to_mirror(overwrite=True)
+        assert mp.local_path.read_bytes() == b"s3_version"
+
+    def test_delete_s3(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/delete_me.txt")
+        mp = MirrorPath.from_s3_key("mp_test/delete_me.txt")
+        assert mp.exists_on_s3()
+        mp.delete_s3()
+        assert not mp.exists_on_s3()
+
+    def test_get_children_on_s3(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/children/a.txt")
+        self._upload_test_file(s3_client, "mp_test/children/b.txt")
+        mp = MirrorPath.from_s3_key("mp_test/children/")
+        children = mp.get_children_on_s3()
+        child_keys = {c.s3_key for c in children}
+        assert "mp_test/children/a.txt" in child_keys
+        assert "mp_test/children/b.txt" in child_keys
+
+    def test_get_children_propagates_bucket_key(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/prop/child.txt")
+        mp = MirrorPath.from_s3_key("mp_test/prop/")
+        children = mp.get_children_on_s3()
+        for child in children:
+            assert child.bucket_key == TESTING_BUCKET_NAME
+
+    def test_copy_to_mp_s3_only(self, s3_client):
+        self._upload_test_file(s3_client, "mp_test/copy_src.txt", b"copy_data")
+        src = MirrorPath.from_s3_key("mp_test/copy_src.txt")
+        dest = MirrorPath.from_s3_key("mp_test/copy_dest.txt")
+        src.copy_to_mp_s3_only(dest)
+        assert dest.exists_on_s3()
+
+    def test_save_and_load_local(self, s3_client, tmp_path):
+        mp = MirrorPath.from_s3_key("mp_test/saveload.json")
+        mp.save_local('{"key": "value"}', upload=True, save_fn=_write_text, overwrite=True)
+        assert mp.exists_on_s3()
+
+        # Load it back
+        loaded = mp.load_local(download=True, load_fn=_read_text, overwrite=True)
+        assert loaded == '{"key": "value"}'
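+
+    # save_local/load_local appear to delegate serialization to the callbacks:
+    # save_fn(local_path, data) writes and load_fn(local_path) returns the
+    # parsed value. A sketch of a JSON pair matching that contract
+    # (hypothetical helpers, mirroring _write_text/_read_text below):
+    #
+    #     import json
+    #
+    #     def _write_json(path, data):
+    #         Path(path).write_text(json.dumps(data))
+    #
+    #     def _read_json(path):
+    #         return json.loads(Path(path).read_text())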
MirrorPath.from_s3_key("mp_test/cp_full_src.txt") + mp_dest = MirrorPath.from_s3_key("mp_test/cp_full_dest.txt") + mp_src.copy_to_mp(mp_dest) + assert mp_dest.exists_on_s3() + assert mp_dest.exists_in_mirror() + assert mp_dest.local_path.read_bytes() == b"full_copy" + + def test_copy_to_mp_from_mirror(self, s3_client, tmp_path): + mp_src = MirrorPath.from_s3_key("mp_test/cp_msrc.txt") + mp_src.local_path.parent.mkdir(parents=True, exist_ok=True) + mp_src.local_path.write_bytes(b"mirror_src_data") + mp_dest = MirrorPath.from_s3_key("mp_test/cp_mdest.txt") + mp_dest.local_path.parent.mkdir(parents=True, exist_ok=True) + mp_src.copy_to_mp(mp_dest, use_mirror_as_src=True) + assert mp_dest.exists_on_s3() + assert mp_dest.local_path.read_bytes() == b"mirror_src_data" + + +def _write_text(path: str, data: str): + """Helper: write text to a file.""" + Path(path).write_text(data) + + +def _read_text(path: str) -> str: + """Helper: read text from a file.""" + return Path(path).read_text() diff --git a/tests/test_prefix_queries.py b/tests/test_prefix_queries.py new file mode 100644 index 0000000..499027b --- /dev/null +++ b/tests/test_prefix_queries.py @@ -0,0 +1,158 @@ +"""Tests for prefix_queries.py — S3 integration tests.""" + +import contextlib + +import boto3 +import pytest + +from S3MP.prefix_queries import ( + get_files_within_folder, + get_folders_within_folder, +) +from S3MP.types import S3Client + +TESTING_BUCKET_NAME = "s3mp-testing" + + +@pytest.fixture() +def s3_env(): + """Create a client and set up a known folder structure for prefix tests.""" + client: S3Client = boto3.client("s3") + with contextlib.suppress(client.exceptions.BucketAlreadyOwnedByYou): + client.create_bucket(Bucket=TESTING_BUCKET_NAME) + + # Build structure: pq_test/ + # file_a.txt + # file_b.csv + # sub_one/ + # nested.txt + # sub_two/ + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/file_a.txt", Body=b"a") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/file_b.csv", Body=b"b") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/sub_one/", Body=b"") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/sub_one/nested.txt", Body=b"n") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/sub_two/", Body=b"") + return client + + +# ── get_files_within_folder ─────────────────────────────────────────────── + + +class TestGetFilesWithinFolder: + """Tests for get_files_within_folder.""" + + def test_lists_files(self, s3_env): + files = list(get_files_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env)) + # Returned names are relative to the folder_key + assert "file_a.txt" in files + assert "file_b.csv" in files + + def test_does_not_include_subfolder_files(self, s3_env): + files = list(get_files_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env)) + # nested.txt lives in sub_one/, should not appear at this level + assert all("nested" not in f for f in files) + + def test_key_filter(self, s3_env): + files = list( + get_files_within_folder( + "pq_test/", TESTING_BUCKET_NAME, key_filter=".csv", client=s3_env + ) + ) + assert "file_b.csv" in files + assert "file_a.txt" not in files + + def test_empty_folder(self, s3_env): + files = list( + get_files_within_folder("pq_test/sub_two/", TESTING_BUCKET_NAME, client=s3_env) + ) + assert files == [] + + def test_nested_folder_files(self, s3_env): + files = list( + get_files_within_folder("pq_test/sub_one/", TESTING_BUCKET_NAME, client=s3_env) + ) + assert "nested.txt" in files + + def 
+
+
+# ── get_files_within_folder ──────────────────────────────────────────────
+
+
+class TestGetFilesWithinFolder:
+    """Tests for get_files_within_folder."""
+
+    def test_lists_files(self, s3_env):
+        files = list(get_files_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        # Returned names are relative to the folder_key
+        assert "file_a.txt" in files
+        assert "file_b.csv" in files
+
+    def test_does_not_include_subfolder_files(self, s3_env):
+        files = list(get_files_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        # nested.txt lives in sub_one/, should not appear at this level
+        assert all("nested" not in f for f in files)
+
+    def test_key_filter(self, s3_env):
+        files = list(
+            get_files_within_folder(
+                "pq_test/", TESTING_BUCKET_NAME, key_filter=".csv", client=s3_env
+            )
+        )
+        assert "file_b.csv" in files
+        assert "file_a.txt" not in files
+
+    def test_empty_folder(self, s3_env):
+        files = list(
+            get_files_within_folder("pq_test/sub_two/", TESTING_BUCKET_NAME, client=s3_env)
+        )
+        assert files == []
+
+    def test_nested_folder_files(self, s3_env):
+        files = list(
+            get_files_within_folder("pq_test/sub_one/", TESTING_BUCKET_NAME, client=s3_env)
+        )
+        assert "nested.txt" in files
+
+    def test_folder_marker_not_returned_alongside_real_files(self, s3_env):
+        """Regression: a folder with a zero-byte marker AND real files must not yield '' or the marker."""
+        files = list(
+            get_files_within_folder("pq_test/sub_one/", TESTING_BUCKET_NAME, client=s3_env)
+        )
+        assert "" not in files
+        assert files == ["nested.txt"]
+
+
+# ── get_folders_within_folder ────────────────────────────────────────────
+
+
+class TestGetFoldersWithinFolder:
+    """Tests for get_folders_within_folder."""
+
+    def test_lists_folders(self, s3_env):
+        folders = list(get_folders_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        assert "sub_one/" in folders
+        assert "sub_two/" in folders
+
+    def test_does_not_include_files(self, s3_env):
+        folders = list(get_folders_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        assert all("file_a" not in f for f in folders)
+        assert all("file_b" not in f for f in folders)
+
+    def test_key_filter(self, s3_env):
+        folders = list(
+            get_folders_within_folder(
+                "pq_test/", TESTING_BUCKET_NAME, key_filter="sub_one", client=s3_env
+            )
+        )
+        assert "sub_one/" in folders
+        assert "sub_two/" not in folders
+
+    def test_empty_parent_returns_empty(self, s3_env):
+        folders = list(
+            get_folders_within_folder("pq_test/sub_two/", TESTING_BUCKET_NAME, client=s3_env)
+        )
+        assert folders == []
+
+
+# ── removeprefix correctness ─────────────────────────────────────────────
+
+
+class TestPrefixStripping:
+    """Ensure prefix stripping only removes the leading folder_key, not duplicates deeper in the path."""
+
+    def test_file_with_repeated_prefix_segment(self, s3_env):
+        # Key: "pq_test/pq_test/data.txt" — folder_key "pq_test/" appears again in the path
+        s3_env.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/pq_test/data.txt", Body=b"dup")
+        files = list(get_files_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        # Should NOT strip the inner "pq_test/" — that's part of the relative path
+        assert "pq_test/data.txt" not in files
+        # It's a nested folder, so it shouldn't appear as a direct file at all
+
+    def test_folder_with_repeated_prefix_segment(self, s3_env):
+        # Create a subfolder that repeats the parent name
+        s3_env.put_object(Bucket=TESTING_BUCKET_NAME, Key="pq_test/pq_test/", Body=b"")
+        folders = list(get_folders_within_folder("pq_test/", TESTING_BUCKET_NAME, client=s3_env))
+        assert "pq_test/" in folders
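+
+
+# The bug class these tests guard against: str.replace removes every
+# occurrence of folder_key, while str.removeprefix (Python 3.9+) strips only
+# a leading one. A sketch of the distinction:
+#
+#     key = "pq_test/pq_test/data.txt"
+#     key.removeprefix("pq_test/")  # -> "pq_test/data.txt"  (correct)
+#     key.replace("pq_test/", "")   # -> "data.txt"          (loses a segment)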
+
+
+# ── folder_key without trailing slash ────────────────────────────────────
+
+
+class TestNoTrailingSlash:
+    """Ensure functions work correctly when folder_key lacks a trailing '/'."""
+
+    def test_files_without_trailing_slash(self, s3_env):
+        files = list(get_files_within_folder("pq_test", TESTING_BUCKET_NAME, client=s3_env))
+        assert "file_a.txt" in files
+        assert "file_b.csv" in files
+        # No leading "/" in relative paths
+        assert all(not f.startswith("/") for f in files)
+
+    def test_folders_without_trailing_slash(self, s3_env):
+        folders = list(get_folders_within_folder("pq_test", TESTING_BUCKET_NAME, client=s3_env))
+        assert "sub_one/" in folders
+        assert "sub_two/" in folders
+
+    def test_empty_folder_without_trailing_slash(self, s3_env):
+        files = list(get_files_within_folder("pq_test/sub_two", TESTING_BUCKET_NAME, client=s3_env))
+        assert files == []
diff --git a/tests/test_s3_utils.py b/tests/test_s3_utils.py
index 0e9ff2b..79361fe 100644
--- a/tests/test_s3_utils.py
+++ b/tests/test_s3_utils.py
@@ -6,8 +6,17 @@
 import pytest
 
 from S3MP.types import S3Client
-from S3MP.utils import s3_utils
-from S3MP.utils.s3_utils import key_exists_on_s3, key_is_file_on_s3, s3_list_child_keys
+from S3MP.utils.s3_utils import (
+    delete_child_keys_on_s3,
+    delete_key_on_s3,
+    download_key,
+    key_exists_on_s3,
+    key_is_file_on_s3,
+    key_size_on_s3,
+    s3_list_child_keys,
+    s3_list_single_key,
+    upload_to_key,
+)
 
 TESTING_BUCKET_NAME = "s3mp-testing"
 
@@ -72,23 +81,22 @@
     # Test folder listing
     with pytest.warns(Warning):
         s3_list_child_keys("test_folder", bucket, client)
-    # Use s3_list_single_key to get the full response dict for testing
-    response = s3_utils.s3_list_single_key("test_folder/", bucket, client)
-    child_files = [obj["Key"] for obj in response.get("Contents", [])]
-    child_folders = [obj["Prefix"] for obj in response.get("CommonPrefixes", [])]
+
+    child_keys = s3_list_child_keys("test_folder/", bucket, client)
+    child_files = [k for k in child_keys if not k.endswith("/")]
+    child_folders = [k for k in child_keys if k.endswith("/")]
     list_assert(
         child_files,
         [
-            "test_folder/",
             "test_folder/test_file_1",
             "test_folder/test_file_2",
-            "test_folder/empty_subfolder/",
-            "test_folder/nonempty_subfolder/",
         ],
+        order_matters=False,
     )
     list_assert(
         child_folders,
         ["test_folder/empty_subfolder/", "test_folder/nonempty_subfolder/"],
+        order_matters=False,
    )
 
     # Test folder and file existences
@@ -136,3 +144,137 @@
 
 if __name__ == "__main__":
     test_folders_files()
+
+
+# ── Fixtures ─────────────────────────────────────────────────────────────
+
+
+@pytest.fixture()
+def s3_env():
+    """Create a client, resource, and bucket for S3 tests."""
+    client: S3Client = boto3.client("s3")
+    with contextlib.suppress(client.exceptions.BucketAlreadyOwnedByYou):
+        client.create_bucket(Bucket=TESTING_BUCKET_NAME)
+    bucket = boto3.resource("s3").Bucket(TESTING_BUCKET_NAME)
+    return client, bucket
+
+
+# ── s3_list_single_key ───────────────────────────────────────────────────
+
+
+class TestS3ListSingleKey:
+    """Tests for s3_list_single_key."""
+
+    def test_returns_contents_for_existing_file(self, s3_env):
+        client, bucket = s3_env
+        client.put_object(Bucket=TESTING_BUCKET_NAME, Key="single/file.txt", Body=b"data")
+        result = s3_list_single_key("single/file.txt", bucket, client)
+        assert "Contents" in result
+        assert result["Contents"][0]["Key"] == "single/file.txt"
+
+    def test_returns_max_one_result(self, s3_env):
+        client, bucket = s3_env
+        client.put_object(Bucket=TESTING_BUCKET_NAME, Key="single_multi/a.txt", Body=b"a")
+        client.put_object(Bucket=TESTING_BUCKET_NAME, Key="single_multi/b.txt", Body=b"b")
+        result = s3_list_single_key("single_multi/", bucket, client)
+        # MaxKeys=1, so at most 1 content entry
+        content_count = len(result.get("Contents", []))
+        prefix_count = len(result.get("CommonPrefixes", []))
+        assert content_count + prefix_count <= 1
+
+    def test_empty_for_nonexistent_key(self, s3_env):
+        client, bucket = s3_env
+        result = s3_list_single_key("nonexistent/key/xyz", bucket, client)
+        assert "Contents" not in result
+        assert "CommonPrefixes" not in result
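+
+    # s3_list_single_key presumably wraps list_objects_v2 with MaxKeys=1 and a
+    # "/" delimiter, so a response carries at most one of Contents or
+    # CommonPrefixes; that is enough to decide whether a key exists and whether
+    # it is a file or a folder, without pagination. This reading is inferred
+    # from the assertions above, not from the implementation.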
Key="size/empty.txt", Body=b"") + assert key_size_on_s3("size/empty.txt", bucket, client) == 0 + + def test_nonexistent_key_raises(self, s3_env): + client, bucket = s3_env + with pytest.raises(ValueError, match="does not exist"): + key_size_on_s3("size/no_such_key.txt", bucket, client) + + +# ── download_key / upload_to_key ────────────────────────────────────────── + + +class TestDownloadUploadKey: + """Tests for download_key and upload_to_key.""" + + def test_upload_and_download_file(self, s3_env, tmp_path): + client, bucket = s3_env + local_src = tmp_path / "upload_src.txt" + local_src.write_bytes(b"upload content") + upload_to_key("dl_up/roundtrip.txt", local_src, bucket, client) + assert key_exists_on_s3("dl_up/roundtrip.txt", bucket, client) + + local_dest = tmp_path / "download_dest.txt" + download_key("dl_up/roundtrip.txt", local_dest, bucket, client) + assert local_dest.read_bytes() == b"upload content" + + def test_download_creates_parent_dirs(self, s3_env, tmp_path): + client, bucket = s3_env + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="dl_up/nested.txt", Body=b"nested") + dest = tmp_path / "deep" / "path" / "nested.txt" + download_key("dl_up/nested.txt", dest, bucket, client) + assert dest.exists() + assert dest.read_bytes() == b"nested" + + def test_download_nonexistent_raises(self, s3_env, tmp_path): + client, bucket = s3_env + with pytest.raises(ValueError, match="does not exist"): + download_key("dl_up/no_such.txt", tmp_path / "out.txt", bucket, client) + + +# ── delete_key_on_s3 / delete_child_keys_on_s3 ─────────────────────────── + + +class TestDeleteOnS3: + """Tests for delete_key_on_s3 and delete_child_keys_on_s3.""" + + def test_delete_file_key(self, s3_env): + client, bucket = s3_env + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del/file.txt", Body=b"x") + assert key_exists_on_s3("del/file.txt", bucket, client) + delete_key_on_s3("del/file.txt", bucket, client) + assert not key_exists_on_s3("del/file.txt", bucket, client) + + def test_delete_nonexistent_is_noop(self, s3_env): + client, bucket = s3_env + # Should not raise + delete_key_on_s3("del/nonexistent_xyz.txt", bucket, client) + + def test_delete_folder_removes_children(self, s3_env): + client, bucket = s3_env + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_folder/", Body=b"") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_folder/a.txt", Body=b"a") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_folder/b.txt", Body=b"b") + delete_key_on_s3("del_folder/", bucket, client) + assert not key_exists_on_s3("del_folder/a.txt", bucket, client) + assert not key_exists_on_s3("del_folder/b.txt", bucket, client) + + def test_delete_child_keys(self, s3_env): + client, bucket = s3_env + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_children/", Body=b"") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_children/c1.txt", Body=b"c1") + client.put_object(Bucket=TESTING_BUCKET_NAME, Key="del_children/c2.txt", Body=b"c2") + delete_child_keys_on_s3("del_children/", bucket, client) + assert not key_exists_on_s3("del_children/c1.txt", bucket, client) + assert not key_exists_on_s3("del_children/c2.txt", bucket, client) diff --git a/uv.lock b/uv.lock index 6ebdb68..c3e957d 100644 --- a/uv.lock +++ b/uv.lock @@ -1515,7 +1515,7 @@ wheels = [ [[package]] name = "s3mp" -version = "0.8.4" +version = "0.9.0" source = { editable = "." } dependencies = [ { name = "aioboto3" },