diff --git a/README.md b/README.md index de067dd..ba22a0e 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,56 @@ client = APIDeploymentsClient( The retry logic uses exponential backoff with full jitter and respects the `Retry-After` header on 429 responses. +## Unstract CLI + +Installing `unstract-client` also provides the `unstract` command: + +```bash +pip install unstract-client +unstract --help +``` + +### `unstract clone` + +Clones an organization's resources to another org, on the same or a different +deployment (e.g. promote **dev** → **QA** → **prod**). Covers adapters, +connectors, workflows, pipelines, API deployments, Prompt Studio projects and +their files, user groups, and sharing state (users matched by email, groups by +name). + +Authenticates with each org admin's **Platform API key**; prefer the env vars +so keys never land in shell history: + +```bash +export UNSTRACT_SRC_PLATFORM_KEY="" +export UNSTRACT_TGT_PLATFORM_KEY="" + +unstract clone \ + --source-url https://dev.example.com --source-org org_dev123 \ + --target-url https://qa.example.com --target-org org_qa456 \ + --dry-run +``` + +Drop `--dry-run` to perform the clone. + +| Option | Description | +|--------|-------------| +| `--dry-run` | Plan only; nothing is written. | +| `--include` / `--exclude` | Comma-separated phases to run / skip. | +| `--on-name-conflict` | `adopt` (default) reuses like-named target resources; `abort` stops. | +| `--clone-group-members` | Also add group members on target, matched by email. | +| `--source-key` / `--target-key` | Platform API keys, if not set via env vars. | +| `--api-prefix` | Backend URL prefix (default `api/v1`). | + +Re-runs are idempotent: existing target resources are adopted by name, so a +failed run can be resumed by re-running the same command. + +| Exit code | Meaning | +|------|---------| +| `0` | Success. | +| `1` | Completed with failures — see the printed report. | +| `2` | Could not run (setup error or `--on-name-conflict=abort` collision). | + ## Questions and Feedback On Slack, [join great conversations](https://join-slack.unstract.com/) around LLMs, their ecosystem and leveraging them to automate the previously unautomatable! diff --git a/pyproject.toml b/pyproject.toml index a208e3a..0cb4dff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,8 @@ authors = [ dependencies = [ "requests>=2.32.3", "tenacity>=8.2.0", + "click>=8.1", + "rich>=13.7", ] requires-python = ">=3.11" readme = "README.md" @@ -25,14 +27,8 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] -[project.optional-dependencies] -clone = [ - "click>=8.1", - "rich>=13.7", -] - [project.scripts] -unstract-clone = "unstract.clone.cli:main" +unstract = "unstract.cli:main" [build-system] requires = ["hatchling"] diff --git a/src/unstract/cli.py b/src/unstract/cli.py new file mode 100644 index 0000000..9634111 --- /dev/null +++ b/src/unstract/cli.py @@ -0,0 +1,30 @@ +"""Top-level ``unstract`` command group. + +Subcommands live in their own subpackages and are registered here so a +single console script (``unstract``) fronts all of them. ``unstract.clone`` +keeps its own group + ``main`` so ``python -m unstract.clone`` still works. +""" + +from __future__ import annotations + +from typing import Any + +import click + +from unstract.clone.cli import clone_cmd + + +@click.group(name="unstract") +def cli() -> None: + """Unstract command-line tools.""" + + +cli.add_command(clone_cmd, name="clone") + + +def main(argv: list[str] | None = None) -> Any: + return cli(args=argv, standalone_mode=True) + + +if __name__ == "__main__": + main() diff --git a/src/unstract/clone/cli.py b/src/unstract/clone/cli.py index d2ed358..43f3a09 100644 --- a/src/unstract/clone/cli.py +++ b/src/unstract/clone/cli.py @@ -1,9 +1,12 @@ """Click-based CLI for ``unstract.clone``. -Single ``clone`` command. Platform keys can be passed via flags -(``--source-key`` / ``--target-key``) or env vars -(``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) — env vars -are preferred so the key never lands in shell history. +Single ``clone`` command, registered on the top-level ``unstract`` group +(``unstract.cli``) — the canonical invocation is ``unstract clone``. The +local group here only backs ``python -m unstract.clone``. + +Platform keys can be passed via flags (``--source-key`` / ``--target-key``) +or env vars (``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) +— env vars are preferred so the key never lands in shell history. """ from __future__ import annotations @@ -141,6 +144,12 @@ def cli() -> None: show_default=True, help="Per-phase worker count. 1 = strictly sequential.", ) +@click.option( + "--clone-group-members", + is_flag=True, + help="Also add group members on target, matched by email. " + "Members missing on target are skipped and reported.", +) @click.option("-v", "--verbose", is_flag=True, help="Debug logging") def clone_cmd( source_url: str, @@ -158,6 +167,7 @@ def clone_cmd( max_file_size: str, skip_files: bool, concurrency: int, + clone_group_members: bool, verbose: bool, ) -> None: """Clone configured resources from one org to another.""" @@ -178,6 +188,7 @@ def clone_cmd( file_strategy=effective_strategy, max_file_size=cap_bytes if cap_bytes is not None else DEFAULT_MAX_FILE_SIZE, concurrency=concurrency, + clone_group_members=clone_group_members, ) source = OrgEndpoint( diff --git a/src/unstract/clone/client.py b/src/unstract/clone/client.py index ad873da..d64ec78 100644 --- a/src/unstract/clone/client.py +++ b/src/unstract/clone/client.py @@ -113,6 +113,39 @@ def get_post_schema(self, entity_path: str) -> frozenset[str]: self._post_schema_cache[entity_path] = writable return writable + # ----- org users & groups ----- + + def list_users(self) -> list[dict[str, Any]]: + """List org member rows (each carries ``id`` and ``email``).""" + result = self._request("GET", "users/") + return (result or {}).get("members", []) + + def list_groups(self) -> list[dict[str, Any]]: + """List org groups; no server-side name filter — callers match in memory.""" + result = self._request("GET", "groups/") + return result if isinstance(result, list) else result.get("results", []) + + def create_group(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a group; response has no ``id`` — re-list to learn the pk.""" + return self._request("POST", "groups/", json=payload) + + def list_group_members(self, group_id: Any) -> list[dict[str, Any]]: + """List a group's member rows (each carries ``email``).""" + result = self._request("GET", f"groups/{group_id}/members/") + return result if isinstance(result, list) else result.get("results", []) + + def add_group_members(self, group_id: Any, user_ids: list[int]) -> Any: + """Bulk-add members by user pk; idempotent server-side.""" + return self._request( + "POST", f"groups/{group_id}/members/", json={"user_ids": user_ids} + ) + + # ----- sharing ----- + + def share_resource(self, share_path: str, payload: dict[str, Any]) -> Any: + """Replace-style share update; axes omitted from ``payload`` are untouched.""" + return self._request("POST", share_path, json=payload) + # ----- adapters ----- def list_adapters( diff --git a/src/unstract/clone/context.py b/src/unstract/clone/context.py index 833668f..8441748 100644 --- a/src/unstract/clone/context.py +++ b/src/unstract/clone/context.py @@ -12,8 +12,9 @@ from __future__ import annotations +import threading from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from unstract.clone.client import PlatformClient @@ -53,6 +54,8 @@ class CloneOptions: max_file_size: int = DEFAULT_MAX_FILE_SIZE # Per-phase worker fan-out. 1 = sequential (no executor). concurrency: int = DEFAULT_CONCURRENCY + # Group phase: also add members (matched by email) to cloned groups. + clone_group_members: bool = False def includes(self, phase_name: str) -> bool: if self.include is not None and phase_name not in self.include: @@ -107,3 +110,7 @@ class CloneContext: # Source prompt_registry_ids whose CustomTool was skipped; used to # cascade-skip dependent workflows downstream. skipped_custom_tool_registry_ids: set[str] = field(default_factory=set) + # Per-run memo for users/groups directory listings (sharing replication + # touches them once per endpoint, never per resource). + share_cache: dict[str, Any] = field(default_factory=dict) + share_cache_lock: threading.Lock = field(default_factory=threading.Lock) diff --git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py index a7c81a0..9fbcd04 100644 --- a/src/unstract/clone/orchestrator.py +++ b/src/unstract/clone/orchestrator.py @@ -23,6 +23,7 @@ ConnectorPhase, CustomToolPhase, FilesPhase, + GroupPhase, PipelinePhase, TagPhase, ToolInstancePhase, @@ -35,12 +36,15 @@ logger = logging.getLogger(__name__) # Strict dependency order. Each entry: (phase_name, phase_class). -# Adapter, connector, tag are independent leaf phases. Downstream phases -# (custom_tool, workflow, tool_instance, workflow_endpoint) land later -# and consume the remap entries these produce. Pipeline + api_deployment -# come last: both FK the workflow and api_deployment additionally -# requires endpoints to be configured before the serializer accepts it. +# Group runs first: every shareable phase consumes its remap entries when +# replicating shared_groups. Adapter, connector, tag are independent leaf +# phases. Downstream phases (custom_tool, workflow, tool_instance, +# workflow_endpoint) land later and consume the remap entries these +# produce. Pipeline + api_deployment come last: both FK the workflow and +# api_deployment additionally requires endpoints to be configured before +# the serializer accepts it. PHASES: list[tuple[str, type[Phase]]] = [ + ("group", GroupPhase), ("adapter", AdapterPhase), ("connector", ConnectorPhase), ("tag", TagPhase), diff --git a/src/unstract/clone/phases/__init__.py b/src/unstract/clone/phases/__init__.py index 03f0952..0c3a9a6 100644 --- a/src/unstract/clone/phases/__init__.py +++ b/src/unstract/clone/phases/__init__.py @@ -13,6 +13,7 @@ from unstract.clone.phases.connector import ConnectorPhase from unstract.clone.phases.custom_tool import CustomToolPhase from unstract.clone.phases.files import FilesPhase +from unstract.clone.phases.group import GroupPhase from unstract.clone.phases.pipeline import PipelinePhase from unstract.clone.phases.tag import TagPhase from unstract.clone.phases.tool_instance import ToolInstancePhase @@ -25,6 +26,7 @@ "ConnectorPhase", "CustomToolPhase", "FilesPhase", + "GroupPhase", "Phase", "PipelinePhase", "TagPhase", diff --git a/src/unstract/clone/phases/adapter.py b/src/unstract/clone/phases/adapter.py index c98bf63..28082c5 100644 --- a/src/unstract/clone/phases/adapter.py +++ b/src/unstract/clone/phases/adapter.py @@ -26,6 +26,7 @@ class AdapterPhase(Phase): name = "adapter" + share_path_template = "adapter/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -123,3 +124,7 @@ def _clone_one( with lock: self.ctx.remap.record("adapter", src_id, tgt["id"]) + # Source detail (fetched above) carries the share axes. + self.apply_share( + src=src, tgt_id=tgt["id"], label=name, result=result, lock=lock + ) diff --git a/src/unstract/clone/phases/api_deployment.py b/src/unstract/clone/phases/api_deployment.py index 0d2575f..b707045 100644 --- a/src/unstract/clone/phases/api_deployment.py +++ b/src/unstract/clone/phases/api_deployment.py @@ -29,6 +29,7 @@ class APIDeploymentPhase(Phase): name = "api_deployment" + share_path_template = "api/deployment/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -150,6 +151,15 @@ def _clone_one( with lock: self.ctx.remap.record("api_deployment", src_id, tgt["id"]) + # List rows omit the share axes — fetch source detail when needed. + self.apply_share( + src=src, + tgt_id=tgt["id"], + label=api_name, + result=result, + lock=lock, + src_detail_fn=lambda: self.ctx.source.get_api_deployment(src_id), + ) def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: try: diff --git a/src/unstract/clone/phases/base.py b/src/unstract/clone/phases/base.py index e83e00e..c7710d3 100644 --- a/src/unstract/clone/phases/base.py +++ b/src/unstract/clone/phases/base.py @@ -12,6 +12,7 @@ from unstract.clone.context import CloneContext from unstract.clone.exceptions import CloneError from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.sharing import apply_share_state T = TypeVar("T") @@ -22,6 +23,8 @@ # either noise (silently overwritten) or a 400 (when a source-org value # doesn't validate against the target org). Strip them universally — # the phase OPTIONS schema covers the entity-specific writable subset. +# ``shared_users`` stays stripped on create — share state is replicated +# post-create instead (see sharing.py). SERVER_MANAGED: frozenset[str] = frozenset( { "id", @@ -55,6 +58,9 @@ class Phase(ABC): """Abstract phase. One subclass per entity type.""" name: str = "" + # Share endpoint template for shareable resource types, e.g. + # "adapter/{id}/share/" ({id} = target pk). None = not shareable. + share_path_template: str | None = None def __init__(self, ctx: CloneContext): self.ctx = ctx @@ -64,6 +70,34 @@ def run(self, report: CloneReport) -> PhaseResult: """Migrate all entities of this phase's type. Idempotent across runs.""" raise NotImplementedError + def apply_share( + self, + *, + src: dict[str, Any], + tgt_id: str, + label: str, + result: PhaseResult, + lock: threading.Lock, + src_detail_fn: Callable[[], dict[str, Any]] | None = None, + ) -> None: + """Replicate ``src``'s share state onto the target entity. + + Pass ``src_detail_fn`` when ``src`` may be a stripped list-row — + the helper fetches the detail only if a share axis is missing. + No-op for phases without ``share_path_template``; never raises. + """ + if self.share_path_template is None: + return + apply_share_state( + self.ctx, + share_path=self.share_path_template.format(id=tgt_id), + entity_label=f"{self.name} '{label}'", + src=src, + result=result, + lock=lock, + src_detail_fn=src_detail_fn, + ) + def parallel_map( self, items: Iterable[T], diff --git a/src/unstract/clone/phases/connector.py b/src/unstract/clone/phases/connector.py index 2a30b93..3243542 100644 --- a/src/unstract/clone/phases/connector.py +++ b/src/unstract/clone/phases/connector.py @@ -43,6 +43,7 @@ def _has_oauth_tokens(metadata: dict[str, Any]) -> bool: class ConnectorPhase(Phase): name = "connector" + share_path_template = "connector/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -157,3 +158,7 @@ def _clone_one( with lock: self.ctx.remap.record("connector", src_id, tgt["id"]) + # Source detail (fetched above) carries the share axes. + self.apply_share( + src=src, tgt_id=tgt["id"], label=name, result=result, lock=lock + ) diff --git a/src/unstract/clone/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py index 03eec4f..53c2936 100644 --- a/src/unstract/clone/phases/custom_tool.py +++ b/src/unstract/clone/phases/custom_tool.py @@ -18,6 +18,8 @@ 5. Republishes ``PromptStudioRegistry`` via the export action and records the ``custom_tool`` + ``prompt_studio_registry`` remaps so downstream ToolInstancePhase can rewrite ``ToolInstance.tool_id``. + Skipped for tools with no source registry entry (never exported — + e.g. empty projects, which the backend refuses to export). Adapter id discovery for the fresh path needs all four of LLM, vector_db, embedding, x2text. If any source adapter can't be resolved @@ -59,6 +61,7 @@ def _extract_adapter_name(value: Any) -> str | None: class CustomToolPhase(Phase): name = "custom_tool" + share_path_template = "prompt-studio/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -148,9 +151,44 @@ def _clone_one( with lock: self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + # Neither the export blob nor list rows carry share axes — + # share state comes from the source detail. + self.apply_share( + src={}, + tgt_id=tgt_tool_id, + label=tool_name, + result=result, + lock=lock, + src_detail_fn=lambda: self.ctx.source.get_custom_tool(src_tool_id), + ) + if self.ctx.options.dry_run: return + # Tools never exported on source (e.g. empty projects — backend + # blocks their export) have no registry entry and no workflow + # references; republishing would fail the same backend guard. + try: + src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) + except Exception as e: + logger.warning( + "source registry lookup failed for tool '%s' " + "(downstream ToolInstance clone may skip): %s", + tool_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append(f"registry remap lookup {tool_name}: {e}") + return + + if not src_regs: + logger.info( + "tool '%s' was never exported on source; skipping registry republish", + tool_name, + ) + return + try: self.ctx.target.export_custom_tool(tgt_tool_id) logger.info( @@ -164,11 +202,10 @@ def _clone_one( return try: - src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id) except Exception as e: logger.warning( - "registry remap lookup failed for tool '%s' " + "target registry lookup failed for tool '%s' " "(downstream ToolInstance clone may skip): %s", tool_name, e, @@ -263,9 +300,7 @@ def _create_fresh( ) return None - adapter_ids = self._resolve_target_adapter_ids( - default_profile, tool_name - ) + adapter_ids = self._resolve_target_adapter_ids(default_profile, tool_name) if adapter_ids is None: with lock: result.failed += 1 diff --git a/src/unstract/clone/phases/group.py b/src/unstract/clone/phases/group.py new file mode 100644 index 0000000..1b54f7b --- /dev/null +++ b/src/unstract/clone/phases/group.py @@ -0,0 +1,202 @@ +"""Migrate org user groups from source org to target org. + +Groups are matched by name and a like-named target group is always reused +(idempotent merge) — ``--on-name-conflict`` does not apply because merging +a sharing container cannot lose configuration. Runs first so downstream +phases can remap group ids when replicating share state. + +With ``--clone-group-members`` each group's members are matched to +target-org users by email and bulk-added; misses surface as report +warnings. Service accounts never migrate. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.sharing import is_service_account, target_user_id_by_email + +logger = logging.getLogger(__name__) + +GROUP_PATH = "groups/" + + +class GroupPhase(Phase): + name = "group" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + src_groups = self.ctx.source.list_groups() + except Exception as e: + logger.exception("Failed to list source groups: %s", e) + result.failed += 1 + result.errors.append(f"list source groups: {e}") + return result + try: + tgt_groups = self.ctx.target.list_groups() + except Exception as e: + logger.exception("Failed to list target groups: %s", e) + result.failed += 1 + result.errors.append(f"list target groups: {e}") + return result + + # Single listing for the whole phase — the endpoint has no name + # filter. Mutated under lock as creates land. + target_by_name: dict[str, dict[str, Any]] = {g["name"]: g for g in tgt_groups} + + logger.info("Found %d group(s) in source org", len(src_groups)) + self.parallel_map( + src_groups, + lambda src, lock: self._clone_one(src, target_by_name, result, lock), + ) + return result + + def _clone_one( + self, + src: dict[str, Any], + target_by_name: dict[str, dict[str, Any]], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + name = src["name"] + src_id = src["id"] + + with lock: + tgt = target_by_name.get(name) + + if tgt is not None: + with lock: + result.adopted += 1 + logger.info("reusing group '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info("[dry-run] would create group '%s' src=%s", name, src_id) + if self.ctx.options.clone_group_members: + # Still computed so would-skip members show up in the report. + self._clone_members(src, None, result, lock) + return + else: + tgt = self._create_group(src, result, lock) + if tgt is None: + return + with lock: + result.created += 1 + target_by_name[name] = tgt + logger.info("created group '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + + with lock: + self.ctx.remap.record("group", str(src_id), str(tgt["id"])) + + if self.ctx.options.clone_group_members: + self._clone_members(src, tgt, result, lock) + + def _create_group( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> dict[str, Any] | None: + name = src["name"] + payload = {"name": name, "description": src.get("description") or ""} + try: + self.ctx.target.create_group(payload) + # Create response has no id — re-list and match by name. + created = next( + (g for g in self.ctx.target.list_groups() if g["name"] == name), + None, + ) + except Exception as e: + logger.exception("Failed to create group %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return None + if created is None: + logger.error("group '%s' created but missing from target listing", name) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: not found in target after POST") + return None + return created + + def _clone_members( + self, + src_group: dict[str, Any], + tgt_group: dict[str, Any] | None, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + """Email-match source members to target users and bulk-add the hits. + + ``tgt_group`` is None only on the dry-run would-create path; the + matching (and its warnings) still run for report visibility. + """ + name = src_group["name"] + try: + members = self.ctx.source.list_group_members(src_group["id"]) + except Exception as e: + logger.exception("Failed to list members of group %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"list members {name}: {e}") + return + if not members: + return + + tgt_by_email = target_user_id_by_email(self.ctx) + if tgt_by_email is None: + with lock: + result.warnings.append( + f"group '{name}': target users listing unavailable — " + f"{len(members)} member(s) not migrated" + ) + return + + to_add: list[int] = [] + skipped: list[str] = [] + for member in members: + email = (member.get("email") or "").lower() + if not email or is_service_account(member): + continue + tgt_uid = tgt_by_email.get(email) + if tgt_uid is None: + skipped.append(email) + else: + to_add.append(tgt_uid) + + if skipped: + with lock: + result.warnings.extend( + f"group '{name}': member {email} not found in target org — skipped" + for email in skipped + ) + + if self.ctx.options.dry_run: + logger.info( + "[dry-run] would add %d member(s) to group '%s' (%d skipped)", + len(to_add), + name, + len(skipped), + ) + return + if not to_add: + return + + try: + # Bulk-add is idempotent server-side, so adopt re-runs are safe. + self.ctx.target.add_group_members(tgt_group["id"], to_add) + except Exception as e: + logger.exception("Failed to add members to group %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"add members {name}: {e}") + return + logger.info( + "added %d member(s) to group '%s' (%d skipped)", + len(to_add), + name, + len(skipped), + ) diff --git a/src/unstract/clone/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py index cdad5f0..121f3a4 100644 --- a/src/unstract/clone/phases/pipeline.py +++ b/src/unstract/clone/phases/pipeline.py @@ -30,6 +30,7 @@ class PipelinePhase(Phase): name = "pipeline" + share_path_template = "pipeline/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -147,6 +148,15 @@ def _clone_one( with lock: self.ctx.remap.record("pipeline", src_id, tgt["id"]) + # List rows carry the share axes; detail fn is a safety net. + self.apply_share( + src=src, + tgt_id=tgt["id"], + label=name, + result=result, + lock=lock, + src_detail_fn=lambda: self.ctx.source.get_pipeline(src_id), + ) def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: try: diff --git a/src/unstract/clone/phases/workflow.py b/src/unstract/clone/phases/workflow.py index 2612eaf..7e3d3f7 100644 --- a/src/unstract/clone/phases/workflow.py +++ b/src/unstract/clone/phases/workflow.py @@ -31,6 +31,7 @@ class WorkflowPhase(Phase): name = "workflow" + share_path_template = "workflow/{id}/share/" def run(self, report: CloneReport) -> PhaseResult: result = report.get_phase(self.name) @@ -133,9 +134,7 @@ def _clone_one( try: src_detail = self.ctx.source.get_workflow(src_id) except Exception as e: - logger.exception( - "Failed to GET source workflow %s detail: %s", name, e - ) + logger.exception("Failed to GET source workflow %s detail: %s", name, e) with lock: result.failed += 1 result.errors.append(f"GET source detail {name}: {e}") @@ -158,3 +157,12 @@ def _clone_one( with lock: self.ctx.remap.record("workflow", src_id, tgt["id"]) + # List rows carry the share axes; detail fn is a safety net. + self.apply_share( + src=src, + tgt_id=tgt["id"], + label=name, + result=result, + lock=lock, + src_detail_fn=lambda: self.ctx.source.get_workflow(src_id), + ) diff --git a/src/unstract/clone/report.py b/src/unstract/clone/report.py index a0f63d1..bfc52ea 100644 --- a/src/unstract/clone/report.py +++ b/src/unstract/clone/report.py @@ -22,6 +22,9 @@ class PhaseResult: skipped: int = 0 failed: int = 0 errors: list[str] = field(default_factory=list) + # Non-fatal anomalies (e.g. share-state members missing on target). + # Surfaced in the report but never affect counts or the exit code. + warnings: list[str] = field(default_factory=list) duration_s: float = 0.0 @@ -77,6 +80,7 @@ def render(self) -> str: ) # Actionable summary first so it doesn't scroll past the table. self._render_failures_summary(console_print=console.print, rich=True) + self._render_warnings_summary(console_print=console.print, rich=True) self._render_endpoints(console.print) table = Table(title="Clone Report", header_style="bold cyan") table.add_column("Phase", style="bold", justify="left") @@ -155,6 +159,7 @@ def _fmt_duration_plain(seconds: float) -> str: def _render_plain(self) -> str: lines: list[str] = [] self._render_failures_summary(console_print=lines.append, rich=False) + self._render_warnings_summary(console_print=lines.append, rich=False) lines.extend(["Clone Report", "=" * 60]) self._render_endpoints(lines.append) header = ( @@ -204,6 +209,7 @@ def as_dict(self) -> dict[str, Any]: "skipped": p.skipped, "failed": p.failed, "errors": list(p.errors), + "warnings": list(p.warnings), "duration_s": p.duration_s, } for p in self.phases @@ -319,6 +325,36 @@ def _render_failures_summary(self, console_print: Any, rich: bool) -> None: else: console_print(tail) + def _render_warnings_summary(self, console_print: Any, rich: bool) -> None: + rows: list[tuple[str, str]] = [] + for p in self.phases: + for warning in p.warnings: + rows.append((p.name, warning)) + if not rows: + return + header = "Warnings (non-fatal; operator follow-up may be needed)" + if rich: + console_print(f"[yellow]{header}:[/yellow]") + else: + console_print(f"{header}:") + shown = rows[: self._FAILURE_MAX_ROWS] + for phase_name, warning in shown: + truncated = self._truncate(warning, self._FAILURE_LINE_MAX_CHARS) + if rich: + console_print( + f" - [bold cyan]{phase_name}[/bold cyan]: {truncated}", + highlight=False, + ) + else: + console_print(f" - {phase_name}: {truncated}") + remaining = len(rows) - len(shown) + if remaining > 0: + tail = f" ... +{remaining} more — see logs" + if rich: + console_print(f"[dim]{tail}[/dim]") + else: + console_print(tail) + @staticmethod def _truncate(text: str, limit: int) -> str: text = text.replace("\n", " ") diff --git a/src/unstract/clone/sharing.py b/src/unstract/clone/sharing.py new file mode 100644 index 0000000..798a997 --- /dev/null +++ b/src/unstract/clone/sharing.py @@ -0,0 +1,209 @@ +"""Replicate a resource's share state onto its cloned counterpart. + +Share state is server-managed on create, so it is mirrored post-create via +the resource's share endpoint: groups map through the ``group`` remap (axis +omitted with a warning when the group phase is excluded), the org flag is +copied as-is, and users map by email. Users missing on the target are +skipped with a warning; service accounts and the source owner are skipped +silently. Users/groups listings are memoised per run in +``CloneContext.share_cache``. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from unstract.clone.context import CloneContext + from unstract.clone.report import PhaseResult + +logger = logging.getLogger(__name__) + +SHARE_AXES: tuple[str, ...] = ("shared_users", "shared_groups", "shared_to_org") +# Platform-key identities; they exist per-org and never map across orgs. +SERVICE_ACCOUNT_EMAIL_SUFFIX = "@platform.internal" + +_FETCH_FAILED = object() # cache sentinel so a failing listing isn't re-hit + + +def is_service_account(row: dict[str, Any]) -> bool: + """True if a user/member listing row is a service account. + + Email-suffix fallback covers older backends without the flag; + mis-classification is benign — a service-account email never matches + across orgs, so worst case is a spurious skip-warning. + """ + flag = row.get("is_service_account") + if flag is not None: + return bool(flag) + return (row.get("email") or "").lower().endswith(SERVICE_ACCOUNT_EMAIL_SUFFIX) + + +def _cached(ctx: CloneContext, key: str, build: Callable[[], Any]) -> Any: + with ctx.share_cache_lock: + if key in ctx.share_cache: + return ctx.share_cache[key] + # Build outside the lock (HTTP call); worst case two threads race and + # one result wins — the listings are read-only so that's harmless. + try: + value = build() + except Exception as e: + logger.warning("share replication: %s listing failed: %s", key, e) + value = _FETCH_FAILED + with ctx.share_cache_lock: + ctx.share_cache.setdefault(key, value) + return ctx.share_cache[key] + + +def source_user_by_id(ctx: CloneContext) -> dict[str, dict[str, Any]] | None: + """Map source user pk (as str) -> listing row; ``None`` if the listing failed.""" + value = _cached( + ctx, + "source_user_by_id", + lambda: {str(m["id"]): m for m in ctx.source.list_users() if m.get("email")}, + ) + return None if value is _FETCH_FAILED else value + + +def target_user_id_by_email(ctx: CloneContext) -> dict[str, int] | None: + """Map lowercased email -> target user pk; ``None`` if the listing failed.""" + value = _cached( + ctx, + "target_user_id_by_email", + lambda: { + m["email"].lower(): int(m["id"]) + for m in ctx.target.list_users() + if m.get("email") + }, + ) + return None if value is _FETCH_FAILED else value + + +def apply_share_state( + ctx: CloneContext, + *, + share_path: str, + entity_label: str, + src: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + src_detail_fn: Callable[[], dict[str, Any]] | None = None, +) -> None: + """Mirror ``src``'s share state onto the target resource at ``share_path``. + + ``src`` may be a stripped list-row; when any share axis is missing and + ``src_detail_fn`` is given, the source detail is fetched once. No-ops + when the effective share state is empty. Never raises — failures land + in ``result.errors`` (counted) and skips in ``result.warnings``. + """ + share_src = src + if src_detail_fn is not None and not all(k in share_src for k in SHARE_AXES): + try: + share_src = src_detail_fn() + except Exception as e: + logger.warning("share %s: source detail fetch failed: %s", entity_label, e) + with lock: + result.warnings.append( + f"share {entity_label}: source detail fetch failed — " + f"share state not replicated: {e}" + ) + return + + shared_to_org = bool(share_src.get("shared_to_org")) + src_group_ids = list(share_src.get("shared_groups") or []) + src_user_ids = list(share_src.get("shared_users") or []) + owner_id = share_src.get("created_by") + + payload: dict[str, Any] = {"shared_to_org": shared_to_org} + + group_warnings: list[str] = [] + if src_group_ids and not ctx.options.includes("group"): + # Axis omitted entirely so the target's group shares are untouched. + group_warnings.append( + f"share {entity_label}: group phase excluded — " + f"{len(src_group_ids)} group share(s) not replicated" + ) + mapped_groups: list[int] | None = None + else: + mapped_groups = [] + for gid in src_group_ids: + tgt_gid = ctx.remap.resolve("group", str(gid)) + if tgt_gid is None: + group_warnings.append( + f"share {entity_label}: source group id {gid} has no " + "target mapping — skipped" + ) + else: + mapped_groups.append(int(tgt_gid)) + payload["shared_groups"] = mapped_groups + + user_warnings: list[str] = [] + mapped_users: list[int] = [] + if src_user_ids: + src_users = source_user_by_id(ctx) + tgt_by_email = target_user_id_by_email(ctx) + if src_users is None or tgt_by_email is None: + user_warnings.append( + f"share {entity_label}: users listing unavailable — " + f"{len(src_user_ids)} user share(s) not replicated" + ) + else: + for uid in src_user_ids: + if owner_id is not None and uid == owner_id: + continue # ownership is server-managed on target + row = src_users.get(str(uid)) + if row is None: + user_warnings.append( + f"share {entity_label}: source user id {uid} not in " + "source users listing — skipped" + ) + continue + if is_service_account(row): + continue + email = row["email"] + tgt_uid = tgt_by_email.get(email.lower()) + if tgt_uid is None: + user_warnings.append( + f"share {entity_label}: user {email} not found in " + "target org — skipped" + ) + else: + mapped_users.append(tgt_uid) + payload["shared_users"] = mapped_users + + with lock: + result.warnings.extend(group_warnings) + result.warnings.extend(user_warnings) + + if not mapped_users and not payload.get("shared_groups") and not shared_to_org: + logger.debug("share %s: nothing to replicate", entity_label) + return + + if ctx.options.dry_run: + logger.info( + "[dry-run] would share %s: users=%s groups=%s org=%s", + entity_label, + mapped_users, + payload.get("shared_groups"), + shared_to_org, + ) + return + + try: + ctx.target.share_resource(share_path, payload) + except Exception as e: + logger.exception("Failed to apply share state for %s: %s", entity_label, e) + with lock: + result.failed += 1 + result.errors.append(f"share {entity_label}: {e}") + return + logger.info( + "shared %s: users=%s groups=%s org=%s", + entity_label, + mapped_users, + payload.get("shared_groups"), + shared_to_org, + ) diff --git a/tests/clone/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py index 30edfb1..5c6746e 100644 --- a/tests/clone/test_custom_tool_phase.py +++ b/tests/clone/test_custom_tool_phase.py @@ -363,6 +363,31 @@ def test_frictionless_adapter_dependence_skips_tool_and_records_for_cascade(): assert SRC_REG in ctx.skipped_custom_tool_registry_ids +def test_never_exported_source_tool_skips_registry_republish(): + """A source tool with no registry entry (e.g. an empty project — the + backend blocks exporting those) clones cleanly without republishing. + """ + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Empty Project") + del src.registries_by_tool["src-tool-x"] + src.export_blobs["src-tool-x"]["prompts"] = [] + _seed_source_adapters(src) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + # No registry on source → republish must not fire (it would hit the + # backend's empty-project export guard). + assert tgt.export_tool_calls == [] + # Tool remap still recorded; registry remap absent. + assert ctx.remap.resolve("custom_tool", "src-tool-x") is not None + assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) is None + + def test_missing_target_adapter_fails_tool_cleanly(): src = FakeClient() tgt = FakeClient() diff --git a/tests/clone/test_group_phase.py b/tests/clone/test_group_phase.py new file mode 100644 index 0000000..9ef1490 --- /dev/null +++ b/tests/clone/test_group_phase.py @@ -0,0 +1,167 @@ +"""Tests for ``GroupPhase``. + +Groups merge idempotently by name (no rename, no abort) and optionally +clone members matched by email — missing members surface as warnings. +""" + +from __future__ import annotations + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.phases.group import GroupPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + def __init__( + self, + groups: list[dict] | None = None, + members: dict[int, list[dict]] | None = None, + users: list[dict] | None = None, + ): + self.groups: list[dict] = list(groups or []) + self.members: dict[int, list[dict]] = dict(members or {}) + self.users: list[dict] = list(users or []) + self.group_posts: list[dict] = [] + self.member_posts: list[tuple[int, list[int]]] = [] + self._next_id = 100 + + def list_groups(self): + return list(self.groups) + + def create_group(self, payload): + new = { + "id": self._next_id, + "name": payload["name"], + "description": payload.get("description", ""), + } + self._next_id += 1 + self.groups.append(new) + self.group_posts.append(payload) + # Backend echoes only name/description (no id) — callers re-list. + return {"name": payload["name"], "description": payload.get("description", "")} + + def list_group_members(self, group_id): + return list(self.members.get(group_id, [])) + + def add_group_members(self, group_id, user_ids): + self.member_posts.append((group_id, list(user_ids))) + return {"added_user_ids": list(user_ids)} + + def list_users(self): + return list(self.users) + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def _grp(id_, name, description="d"): + return {"id": id_, "name": name, "description": description} + + +def test_creates_missing_groups_and_records_remap(): + src = FakeClient([_grp(1, "devs"), _grp(2, "qa")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + result = GroupPhase(ctx).run(CloneReport()) + + assert result.created == 2 + assert result.failed == 0 + created_names = {g["name"] for g in tgt.groups} + assert created_names == {"devs", "qa"} + # remap is keyed by stringified int pks + tgt_devs = next(g for g in tgt.groups if g["name"] == "devs") + assert ctx.remap.resolve("group", "1") == str(tgt_devs["id"]) + + +def test_reuses_like_named_group_even_with_abort_conflict_mode(): + """Idempotent merge: never error or rename on name collision.""" + src = FakeClient([_grp(1, "devs")]) + tgt = FakeClient([_grp(50, "devs")]) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + result = GroupPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert result.failed == 0 + assert tgt.group_posts == [] + assert ctx.remap.resolve("group", "1") == "50" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_grp(1, "devs")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + + result = GroupPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.group_posts == [] + assert ctx.remap.resolve("group", "1") is None + + +def test_members_not_cloned_by_default(): + src = FakeClient( + [_grp(1, "devs")], + members={1: [{"user_id": 7, "email": "a@x.com"}]}, + ) + tgt = FakeClient(users=[{"id": "70", "email": "a@x.com"}]) + ctx = _ctx(src, tgt) + + GroupPhase(ctx).run(CloneReport()) + + assert tgt.member_posts == [] + + +def test_member_cloning_matches_by_email_and_skips_missing(): + src = FakeClient( + [_grp(1, "devs")], + members={ + 1: [ + {"user_id": 7, "email": "alice@x.com"}, + {"user_id": 8, "email": "ghost@x.com"}, # not in target org + # service acct via email-suffix fallback (no flag in row) + {"user_id": 9, "email": "svc@platform.internal"}, + # service acct via backend flag (email alone wouldn't tell) + {"user_id": 10, "email": "bot@x.com", "is_service_account": True}, + ] + }, + ) + tgt = FakeClient( + users=[ + {"id": "70", "email": "Alice@X.com"}, # case-insensitive match + {"id": "71", "email": "bob@x.com"}, + {"id": "72", "email": "bot@x.com"}, + ] + ) + ctx = _ctx(src, tgt, clone_group_members=True) + + result = GroupPhase(ctx).run(CloneReport()) + + tgt_group_id = next(g for g in tgt.groups if g["name"] == "devs")["id"] + assert tgt.member_posts == [(tgt_group_id, [70])] + assert any("ghost@x.com" in w for w in result.warnings) + # service accounts are skipped silently, not warned about + assert not any("platform.internal" in w for w in result.warnings) + assert not any("bot@x.com" in w for w in result.warnings) + + +def test_member_cloning_dry_run_warns_but_never_posts(): + src = FakeClient( + [_grp(1, "devs")], + members={1: [{"user_id": 8, "email": "ghost@x.com"}]}, + ) + tgt = FakeClient(users=[{"id": "70", "email": "alice@x.com"}]) + ctx = _ctx(src, tgt, dry_run=True, clone_group_members=True) + + result = GroupPhase(ctx).run(CloneReport()) + + assert tgt.member_posts == [] + assert any("ghost@x.com" in w for w in result.warnings) diff --git a/tests/clone/test_sharing.py b/tests/clone/test_sharing.py new file mode 100644 index 0000000..9ad0c4b --- /dev/null +++ b/tests/clone/test_sharing.py @@ -0,0 +1,237 @@ +"""Tests for share-state replication (``unstract.clone.sharing``). + +Covers payload building (user email mapping, group remap, org flag), +the skip-when-group-phase-excluded axis, the empty-state short circuit, +dry-run behaviour and the SERVER_MANAGED guarantee that ``shared_users`` +never rides the create POST. +""" + +from __future__ import annotations + +import threading + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.phases.base import build_post_payload +from unstract.clone.report import PhaseResult +from unstract.clone.sharing import apply_share_state + + +class FakeClient: + def __init__(self, users: list[dict] | None = None): + self.users: list[dict] = list(users or []) + self.share_posts: list[tuple[str, dict]] = [] + + def list_users(self): + return list(self.users) + + def share_resource(self, share_path, payload): + self.share_posts.append((share_path, payload)) + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def _apply(ctx, src, result=None, src_detail_fn=None): + result = result if result is not None else PhaseResult(name="adapter") + apply_share_state( + ctx, + share_path="adapter/tgt-1/share/", + entity_label="adapter 'demo'", + src=src, + result=result, + lock=threading.Lock(), + src_detail_fn=src_detail_fn, + ) + return result + + +def test_share_payload_maps_users_groups_and_org_flag(): + src_client = FakeClient( + users=[ + {"id": "7", "email": "alice@x.com"}, + # service account via email-suffix fallback (no flag in row) + {"id": "8", "email": "svc@platform.internal"}, + # service account via backend flag (email alone wouldn't tell) + {"id": "9", "email": "bot@x.com", "is_service_account": True}, + ] + ) + tgt_client = FakeClient( + users=[ + {"id": "70", "email": "ALICE@x.com"}, + {"id": "71", "email": "bot@x.com"}, + ] + ) + ctx = _ctx(src_client, tgt_client) + ctx.remap.record("group", "1", "10") + + result = _apply( + ctx, + { + "created_by": 99, + "shared_users": [7, 8, 9, 99], # svc accounts + owner must be dropped + "shared_groups": [1], + "shared_to_org": True, + }, + ) + + assert tgt_client.share_posts == [ + ( + "adapter/tgt-1/share/", + {"shared_to_org": True, "shared_groups": [10], "shared_users": [70]}, + ) + ] + assert result.failed == 0 + assert result.warnings == [] + + +def test_share_skips_users_missing_on_target_with_warning(): + src_client = FakeClient(users=[{"id": "7", "email": "ghost@x.com"}]) + tgt_client = FakeClient(users=[]) + ctx = _ctx(src_client, tgt_client) + + result = _apply( + ctx, + {"shared_users": [7], "shared_groups": [], "shared_to_org": True}, + ) + + # shared_to_org still forces a POST; the unmapped user is just dropped. + assert tgt_client.share_posts == [ + ( + "adapter/tgt-1/share/", + {"shared_to_org": True, "shared_groups": [], "shared_users": []}, + ) + ] + assert any("ghost@x.com" in w for w in result.warnings) + + +def test_share_unmapped_group_id_is_skipped_with_warning(): + ctx = _ctx(FakeClient(), FakeClient()) + ctx.remap.record("group", "1", "10") + + result = _apply( + ctx, + {"shared_users": [], "shared_groups": [1, 2], "shared_to_org": False}, + ) + + assert ctx.target.share_posts == [ + ( + "adapter/tgt-1/share/", + {"shared_to_org": False, "shared_groups": [10], "shared_users": []}, + ) + ] + assert any("group id 2" in w for w in result.warnings) + + +def test_share_group_axis_omitted_when_group_phase_excluded(): + ctx = _ctx(FakeClient(), FakeClient(), exclude=("group",)) + + result = _apply( + ctx, + {"shared_users": [], "shared_groups": [1], "shared_to_org": True}, + ) + + (path, payload) = ctx.target.share_posts[0] + assert "shared_groups" not in payload # axis untouched on target + assert payload["shared_to_org"] is True + assert any("group phase excluded" in w for w in result.warnings) + + +def test_share_empty_state_skips_the_post(): + ctx = _ctx(FakeClient(), FakeClient()) + + result = _apply( + ctx, + # Owner-only shared_users counts as empty. + { + "created_by": 99, + "shared_users": [99], + "shared_groups": [], + "shared_to_org": False, + }, + ) + + assert ctx.target.share_posts == [] + assert result.failed == 0 + + +def test_share_dry_run_never_posts(): + src_client = FakeClient(users=[{"id": "7", "email": "alice@x.com"}]) + tgt_client = FakeClient(users=[{"id": "70", "email": "alice@x.com"}]) + ctx = _ctx(src_client, tgt_client, dry_run=True) + + _apply( + ctx, + {"shared_users": [7], "shared_groups": [], "shared_to_org": True}, + ) + + assert tgt_client.share_posts == [] + + +def test_share_fetches_source_detail_when_axes_missing_from_list_row(): + ctx = _ctx(FakeClient(), FakeClient()) + detail = { + "shared_users": [], + "shared_groups": [], + "shared_to_org": True, + } + calls = [] + + def fetch_detail(): + calls.append(1) + return detail + + _apply(ctx, {"id": "src-1", "name": "demo"}, src_detail_fn=fetch_detail) + + assert calls == [1] + assert ctx.target.share_posts[0][1]["shared_to_org"] is True + + +def test_share_users_listing_caches_across_resources(): + src_client = FakeClient(users=[{"id": "7", "email": "alice@x.com"}]) + tgt_client = FakeClient(users=[{"id": "70", "email": "alice@x.com"}]) + src_calls = [] + orig = src_client.list_users + src_client.list_users = lambda: (src_calls.append(1), orig())[1] + ctx = _ctx(src_client, tgt_client) + + share = {"shared_users": [7], "shared_groups": [], "shared_to_org": False} + _apply(ctx, dict(share)) + _apply(ctx, dict(share)) + + assert len(src_calls) == 1 # memoised per run, not per resource + assert len(tgt_client.share_posts) == 2 + + +def test_share_post_failure_lands_in_errors(): + src_client = FakeClient(users=[{"id": "7", "email": "alice@x.com"}]) + tgt_client = FakeClient(users=[{"id": "70", "email": "alice@x.com"}]) + + def boom(share_path, payload): + raise RuntimeError("503") + + tgt_client.share_resource = boom + ctx = _ctx(src_client, tgt_client) + + result = _apply( + ctx, {"shared_users": [7], "shared_groups": [], "shared_to_org": False} + ) + + assert result.failed == 1 + assert any("share adapter 'demo'" in e for e in result.errors) + + +def test_server_managed_still_strips_shared_users_on_create(): + src = { + "name": "demo", + "shared_users": [1, 2], + "shared_to_org": True, + } + # Even a (hypothetically) writable shared_users never rides the POST. + payload = build_post_payload(src, frozenset({"name", "shared_users"})) + assert payload == {"name": "demo"} diff --git a/tests/test_cli_top_level.py b/tests/test_cli_top_level.py new file mode 100644 index 0000000..e55d26d --- /dev/null +++ b/tests/test_cli_top_level.py @@ -0,0 +1,50 @@ +"""Tests for the top-level ``unstract`` command group (``unstract.cli``).""" + +from __future__ import annotations + +from click.testing import CliRunner + +from unstract.cli import cli +from unstract.clone.report import CloneReport, Endpoint + + +def test_clone_invocation_via_top_level_group(monkeypatch): + captured: dict = {} + + def fake_clone(source, target, options=None): + captured["source"] = source + captured["target"] = target + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + # The clone command's callback resolves run_clone from unstract.clone.cli. + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + ], + ) + + assert result.exit_code == 0, result.output + assert captured["source"].organization_id == "src" + assert captured["target"].organization_id == "tgt" diff --git a/uv.lock b/uv.lock index 8710285..99dffef 100644 --- a/uv.lock +++ b/uv.lock @@ -799,14 +799,10 @@ wheels = [ name = "unstract-client" source = { editable = "." } dependencies = [ - { name = "requests" }, - { name = "tenacity" }, -] - -[package.optional-dependencies] -clone = [ { name = "click" }, + { name = "requests" }, { name = "rich" }, + { name = "tenacity" }, ] [package.dev-dependencies] @@ -841,12 +837,11 @@ test = [ [package.metadata] requires-dist = [ - { name = "click", marker = "extra == 'clone'", specifier = ">=8.1" }, + { name = "click", specifier = ">=8.1" }, { name = "requests", specifier = ">=2.32.3" }, - { name = "rich", marker = "extra == 'clone'", specifier = ">=13.7" }, + { name = "rich", specifier = ">=13.7" }, { name = "tenacity", specifier = ">=8.2.0" }, ] -provides-extras = ["clone"] [package.metadata.requires-dev] dev = [