diff --git a/automation_file/__init__.py b/automation_file/__init__.py index 83e2860..6a27f81 100644 --- a/automation_file/__init__.py +++ b/automation_file/__init__.py @@ -73,6 +73,15 @@ ) from automation_file.core.sqlite_lock import SQLiteLock from automation_file.core.substitution import SubstitutionException, substitute +from automation_file.core.tracing import action_span, init_tracing +from automation_file.exceptions import ( + BoxException, + DataOpsException, + DiffException, + OneDriveException, + TextOpsException, + TracingException, +) from automation_file.local.archive_ops import ( detect_archive_format, extract_archive, @@ -80,10 +89,24 @@ supported_formats, ) from automation_file.local.conditional import if_exists, if_newer, if_size_gt +from automation_file.local.data_ops import ( + csv_filter, + csv_to_jsonl, + csv_to_parquet, + jsonl_append, + jsonl_iter, + parquet_read, + parquet_write, + yaml_delete, + yaml_get, + yaml_set, +) from automation_file.local.diff_ops import ( DirDiff, apply_dir_diff, + apply_text_patch, diff_dirs, + diff_dirs_summary, diff_text_files, iter_dir_diff, ) @@ -108,6 +131,13 @@ from automation_file.local.sync_ops import SyncException, sync_dir from automation_file.local.tar_ops import TarException, create_tar, extract_tar from automation_file.local.templates import render_file, render_string +from automation_file.local.text_ops import ( + encoding_convert, + file_merge, + file_split, + line_count, + sed_replace, +) from automation_file.local.trash import ( TrashEntry, empty_trash, @@ -147,6 +177,7 @@ azure_blob_instance, register_azure_blob_ops, ) +from automation_file.remote.box import BoxClient, box_instance, register_box_ops from automation_file.remote.cross_backend import CrossBackendException, copy_between from automation_file.remote.dropbox_api import ( DropboxClient, @@ -194,6 +225,11 @@ drive_upload_to_folder, ) from automation_file.remote.http_download import download_file +from automation_file.remote.onedrive 
import ( + OneDriveClient, + onedrive_instance, + register_onedrive_ops, +) from automation_file.remote.s3 import S3Client, register_s3_ops, s3_instance from automation_file.remote.sftp import SFTPClient, register_sftp_ops, sftp_instance from automation_file.remote.smb import SMBClient, SMBEntry @@ -297,9 +333,19 @@ def __getattr__(name: str) -> Any: "detect_archive_format", "detect_from_bytes", "detect_mime", + "action_span", + "apply_text_patch", + "csv_filter", + "csv_to_jsonl", + "csv_to_parquet", "diff_dirs", + "diff_dirs_summary", "diff_text_files", "empty_trash", + "encoding_convert", + "file_merge", + "file_split", + "init_tracing", "extract_archive", "iter_dir_diff", "list_archive", @@ -321,7 +367,20 @@ def __getattr__(name: str) -> Any: "json_get", "json_set", "json_delete", + "jsonl_append", + "jsonl_iter", "JsonEditException", + "line_count", + "parquet_read", + "parquet_write", + "sed_replace", + "yaml_delete", + "yaml_get", + "yaml_set", + "DataOpsException", + "DiffException", + "TextOpsException", + "TracingException", "zip_dir", "zip_file", "zip_info", @@ -371,6 +430,14 @@ def __getattr__(name: str) -> Any: "FTPException", "ftp_instance", "register_ftp_ops", + "OneDriveClient", + "onedrive_instance", + "register_onedrive_ops", + "OneDriveException", + "BoxClient", + "box_instance", + "register_box_ops", + "BoxException", "CrossBackendException", "copy_between", "WebDAVClient", diff --git a/automation_file/core/action_executor.py b/automation_file/core/action_executor.py index f9aadad..b6e753d 100644 --- a/automation_file/core/action_executor.py +++ b/automation_file/core/action_executor.py @@ -44,15 +44,18 @@ def __init__(self, registry: ActionRegistry | None = None) -> None: # Template-method: single action ------------------------------------ def _execute_event(self, action: list) -> Any: + from automation_file.core.tracing import action_span + name, payload_kind, payload = self._parse_action(action) command = self.registry.resolve(name) if 
command is None: raise ExecuteActionException(f"unknown action: {name!r}") - if payload_kind == "none": - return command() - if payload_kind == "kwargs": - return command(**payload) - return command(*payload) + with action_span(name): + if payload_kind == "none": + return command() + if payload_kind == "kwargs": + return command(**payload) + return command(*payload) @staticmethod def _parse_action(action: list) -> tuple[str, str, Any]: diff --git a/automation_file/core/action_registry.py b/automation_file/core/action_registry.py index bc11948..2cd75e1 100644 --- a/automation_file/core/action_registry.py +++ b/automation_file/core/action_registry.py @@ -68,12 +68,15 @@ def event_dict(self) -> dict[str, Command]: def _local_commands() -> dict[str, Command]: from automation_file.local import ( conditional, + data_ops, + diff_ops, dir_ops, file_ops, json_edit, shell_ops, sync_ops, tar_ops, + text_ops, zip_ops, ) @@ -113,6 +116,29 @@ def _local_commands() -> dict[str, Command]: "FA_if_exists": conditional.if_exists, "FA_if_newer": conditional.if_newer, "FA_if_size_gt": conditional.if_size_gt, + # Text / binary + "FA_file_split": text_ops.file_split, + "FA_file_merge": text_ops.file_merge, + "FA_encoding_convert": text_ops.encoding_convert, + "FA_line_count": text_ops.line_count, + "FA_sed_replace": text_ops.sed_replace, + # Diff / patch + "FA_diff_files": diff_ops.diff_text_files, + "FA_diff_dirs": diff_ops.diff_dirs_summary, + "FA_apply_patch": diff_ops.apply_text_patch, + # Structured data (CSV / JSONL) + "FA_csv_filter": data_ops.csv_filter, + "FA_csv_to_jsonl": data_ops.csv_to_jsonl, + "FA_jsonl_iter": data_ops.jsonl_iter, + "FA_jsonl_append": data_ops.jsonl_append, + # Structured data (YAML) + "FA_yaml_get": data_ops.yaml_get, + "FA_yaml_set": data_ops.yaml_set, + "FA_yaml_delete": data_ops.yaml_delete, + # Structured data (Parquet) + "FA_parquet_read": data_ops.parquet_read, + "FA_parquet_write": data_ops.parquet_write, + "FA_csv_to_parquet": 
data_ops.csv_to_parquet, } @@ -153,7 +179,7 @@ def _http_commands() -> dict[str, Command]: def _utils_commands() -> dict[str, Command]: - from automation_file.core import checksum, crypto, manifest + from automation_file.core import checksum, crypto, manifest, tracing from automation_file.remote import cross_backend from automation_file.utils import deduplicate, fast_find, grep, rotate @@ -170,6 +196,7 @@ def _utils_commands() -> dict[str, Command]: "FA_copy_between": cross_backend.copy_between, "FA_encrypt_file": crypto.encrypt_file, "FA_decrypt_file": crypto.decrypt_file, + "FA_tracing_init": tracing.init_tracing, } @@ -186,8 +213,10 @@ def _lazy_execute_action_dag( def _register_cloud_backends(registry: ActionRegistry) -> None: from automation_file.remote.azure_blob import register_azure_blob_ops + from automation_file.remote.box import register_box_ops from automation_file.remote.dropbox_api import register_dropbox_ops from automation_file.remote.ftp import register_ftp_ops + from automation_file.remote.onedrive import register_onedrive_ops from automation_file.remote.s3 import register_s3_ops from automation_file.remote.sftp import register_sftp_ops @@ -196,6 +225,8 @@ def _register_cloud_backends(registry: ActionRegistry) -> None: register_dropbox_ops(registry) register_sftp_ops(registry) register_ftp_ops(registry) + register_onedrive_ops(registry) + register_box_ops(registry) def _register_trigger_ops(registry: ActionRegistry) -> None: diff --git a/automation_file/core/tracing.py b/automation_file/core/tracing.py new file mode 100644 index 0000000..206032d --- /dev/null +++ b/automation_file/core/tracing.py @@ -0,0 +1,125 @@ +"""OpenTelemetry tracing bridge for the action executor and DAG runner. + +Callers opt in by calling :func:`init_tracing` once at startup (or the +``FA_tracing_init`` action) with a service name. 
Every subsequent action +dispatch through ``ActionExecutor._execute_event`` and every DAG node run +through ``dag_executor._run_action`` is wrapped in a span named +``automation_file.action`` with the action name on the ``fa.action`` attribute. + +If ``init_tracing`` has not been called, :func:`action_span` returns a +cheap no-op context manager — the executor pays only a single flag check and +never calls ``trace.get_tracer``, so tracing is effectively zero-overhead +for callers who never enable it. +""" + +from __future__ import annotations + +import contextlib +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any + +from opentelemetry import trace +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter + +from automation_file.exceptions import TracingException +from automation_file.logging_config import file_automation_logger + +_TRACER_NAME = "automation_file" +# Mutable container so helpers don't need ``global`` to flip the flag. +_state: dict[str, bool] = {"initialised": False} + + +def init_tracing( + service_name: str = "automation_file", + *, + exporter: SpanExporter | None = None, + resource_attributes: dict[str, Any] | None = None, +) -> bool: + """Install a global :class:`TracerProvider` and register ``exporter``. + + Returns True on the first call, False if tracing is already initialised. + ``exporter`` defaults to a :class:`SpanExporter` that discards everything — + so that spans are created (and tooling can inspect them) without requiring + the caller to wire up a backend. Pass an OTLP / Jaeger / Zipkin exporter + from the matching ``opentelemetry-exporter-*`` package when you want spans + to leave the process.
+ """ + if _state["initialised"]: + return False + attributes: dict[str, Any] = {SERVICE_NAME: service_name} + if resource_attributes: + attributes.update(resource_attributes) + resource = Resource.create(attributes) + provider = TracerProvider(resource=resource) + active_exporter = exporter if exporter is not None else _NullExporter() + provider.add_span_processor(BatchSpanProcessor(active_exporter)) + try: + trace.set_tracer_provider(provider) + except Exception as err: # pylint: disable=broad-exception-caught + raise TracingException(f"cannot install tracer provider: {err}") from err + _state["initialised"] = True + file_automation_logger.info("tracing: initialised (service=%s)", service_name) + return True + + +def is_initialised() -> bool: + """Return True when :func:`init_tracing` has already run.""" + return _state["initialised"] + + +@contextmanager +def action_span(action_name: str, attributes: dict[str, Any] | None = None) -> Iterator[None]: + """Open a span named ``automation_file.action`` for ``action_name``. + + When tracing is not initialised this is a no-op — the executor can wrap + every action unconditionally without paying for an unused tracer on the + hot path. + """ + if not _state["initialised"]: + yield + return + tracer = trace.get_tracer(_TRACER_NAME) + with tracer.start_as_current_span("automation_file.action") as span: + span.set_attribute("fa.action", action_name) + if attributes: + for key, value in attributes.items(): + span.set_attribute(key, value) + yield + + +def _shutdown_for_tests() -> None: + """Reset module state so a fresh :func:`init_tracing` call works again. + + OpenTelemetry's :func:`trace.set_tracer_provider` is a one-shot guarded + by an internal ``Once`` sentinel — repeated calls are silently ignored. + Tests need to flip tracing off and back on, so we reach into the + ``opentelemetry.trace`` module and reset the sentinel. This is the + conventional pattern used by the opentelemetry-python test suite itself. 
+ """ + provider = trace.get_tracer_provider() + shutdown = getattr(provider, "shutdown", None) + if callable(shutdown): + # Exporter shutdown is best-effort when a test already tore it down. + with contextlib.suppress(Exception): + shutdown() # pylint: disable=not-callable # narrowed by callable() above + # pylint: disable=protected-access # test-only reset of OTel's Once sentinel + once_cls = type(trace._TRACER_PROVIDER_SET_ONCE) + trace._TRACER_PROVIDER_SET_ONCE = once_cls() + trace._TRACER_PROVIDER = None + _state["initialised"] = False + + +class _NullExporter(SpanExporter): + """Default exporter: accept spans, discard them.""" + + def export(self, spans: Any) -> Any: + from opentelemetry.sdk.trace.export import SpanExportResult + + del spans + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + return None diff --git a/automation_file/exceptions.py b/automation_file/exceptions.py index cfa8a6d..d541ad9 100644 --- a/automation_file/exceptions.py +++ b/automation_file/exceptions.py @@ -123,6 +123,26 @@ class FsspecException(FileAutomationException): """Raised by the fsspec bridge on missing dependency or backend failures.""" +class TextOpsException(FileAutomationException): + """Raised by text / binary file helpers (split, merge, sed, encoding_convert).""" + + +class DataOpsException(FileAutomationException): + """Raised by CSV / JSONL / YAML / Parquet helpers.""" + + +class OneDriveException(FileAutomationException): + """Raised by the OneDrive (Microsoft Graph) backend.""" + + +class BoxException(FileAutomationException): + """Raised by the Box backend.""" + + +class TracingException(FileAutomationException): + """Raised when OpenTelemetry tracing setup cannot be completed.""" + + _ARGPARSE_EMPTY_MESSAGE = "argparse received no actionable argument" _BAD_TRIGGER_FUNCTION = "trigger name is not registered in the executor" _BAD_CALLBACK_METHOD = "callback_param_method must be 'kwargs' or 'args'" diff --git a/automation_file/local/data_ops.py 
b/automation_file/local/data_ops.py new file mode 100644 index 0000000..86e2857 --- /dev/null +++ b/automation_file/local/data_ops.py @@ -0,0 +1,448 @@ +"""Structured-data helpers: CSV, JSON Lines, YAML, Parquet. + +All file I/O is UTF-8 and atomic where a destination file is written (temp +file + ``os.replace`` after success). YAML parsing uses ``yaml.safe_load`` — +never ``yaml.load`` — so a malicious config can't construct arbitrary Python +objects. + +The functions in this module intentionally materialise results as Python +lists/dicts rather than iterators so they round-trip cleanly through the +JSON-based action payload protocol used by the executor, MCP bridge, and +TCP/HTTP servers. Callers that need streaming iteration can reach into the +underlying ``csv``/``pyarrow`` APIs directly. +""" + +from __future__ import annotations + +import csv +import json +import os +import tempfile +from collections.abc import MutableMapping, MutableSequence, Sequence +from pathlib import Path +from typing import Any + +from automation_file.exceptions import DataOpsException, FileNotExistsException +from automation_file.logging_config import file_automation_logger + +_MISSING = object() + + +def csv_filter( + src: str, + target: str, + *, + columns: list[str] | None = None, + where_column: str | None = None, + where_equals: str | None = None, + delimiter: str = ",", + encoding: str = "utf-8", +) -> int: + """Copy ``src`` CSV rows into ``target``, optionally projecting and filtering. + + * ``columns`` — if given, the output keeps only these header names in + this order. Unknown names raise :class:`DataOpsException`. + * ``where_column`` + ``where_equals`` — keep only rows whose value in + ``where_column`` exactly equals ``where_equals`` (string compare). + Supplying one without the other raises. + + Returns the number of data rows written. 
+ """ + source = Path(src) + if not source.is_file(): + raise FileNotExistsException(str(source)) + if (where_column is None) != (where_equals is None): + raise DataOpsException("where_column and where_equals must be supplied together") + dest = Path(target) + dest.parent.mkdir(parents=True, exist_ok=True) + written = _stream_csv_filter( + source, dest, columns, where_column, where_equals, delimiter, encoding + ) + file_automation_logger.info("csv_filter: %s -> %s (%d rows)", source, dest, written) + return written + + +# pylint: disable-next=too-many-positional-arguments # flat option bundle +def _stream_csv_filter( + source: Path, + dest: Path, + columns: list[str] | None, + where_column: str | None, + where_equals: str | None, + delimiter: str, + encoding: str, +) -> int: + """Do the actual streaming copy; separated so ``csv_filter`` stays simple.""" + tmp_name: str | None = None + try: + with ( + open(source, encoding=encoding, newline="") as reader, + tempfile.NamedTemporaryFile( + mode="w", + encoding=encoding, + newline="", + dir=str(dest.parent), + delete=False, + suffix=".tmp", + ) as writer, + ): + tmp_name = writer.name + written = _write_filtered_rows( + reader, writer, columns, where_column, where_equals, delimiter + ) + os.replace(tmp_name, dest) + tmp_name = None + return written + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) + + +# pylint: disable-next=too-many-positional-arguments # flat option bundle +def _write_filtered_rows( + reader: Any, + writer: Any, + columns: list[str] | None, + where_column: str | None, + where_equals: str | None, + delimiter: str, +) -> int: + parsed = csv.DictReader(reader, delimiter=delimiter) + fieldnames = _resolve_fieldnames(parsed.fieldnames, columns) + if where_column is not None and where_column not in (parsed.fieldnames or []): + raise DataOpsException(f"where_column {where_column!r} is not in CSV header") + output = csv.DictWriter(writer, fieldnames=fieldnames, 
delimiter=delimiter) + output.writeheader() + written = 0 + for row in parsed: + if where_column is not None and row.get(where_column) != where_equals: + continue + output.writerow({name: row.get(name, "") for name in fieldnames}) + written += 1 + return written + + +def csv_to_jsonl( + src: str, + target: str, + *, + delimiter: str = ",", + encoding: str = "utf-8", +) -> int: + """Convert a CSV file to JSON Lines; return the number of records written.""" + source = Path(src) + if not source.is_file(): + raise FileNotExistsException(str(source)) + dest = Path(target) + dest.parent.mkdir(parents=True, exist_ok=True) + tmp_name: str | None = None + written = 0 + try: + with ( + open(source, encoding=encoding, newline="") as reader, + tempfile.NamedTemporaryFile( + mode="w", + encoding=encoding, + dir=str(dest.parent), + delete=False, + suffix=".tmp", + ) as writer, + ): + tmp_name = writer.name + for row in csv.DictReader(reader, delimiter=delimiter): + writer.write(json.dumps(row, ensure_ascii=False)) + writer.write("\n") + written += 1 + os.replace(tmp_name, dest) + tmp_name = None + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) + file_automation_logger.info("csv_to_jsonl: %s -> %s (%d records)", source, dest, written) + return written + + +def jsonl_iter( + path: str, + *, + limit: int | None = None, + encoding: str = "utf-8", +) -> list[dict[str, Any]]: + """Return every JSON Lines record in ``path`` as a list of dicts. + + ``limit`` caps the number of records returned (handy for previews on + large files). Blank lines are skipped. Non-dict records are rejected so + downstream consumers can count on a stable shape. 
+ """ + source = Path(path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + records: list[dict[str, Any]] = [] + with open(source, encoding=encoding) as reader: + for line_no, raw in enumerate(reader, start=1): + if limit is not None and len(records) >= limit: + break + line = raw.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError as err: + raise DataOpsException(f"{source}:{line_no} is not valid JSON: {err}") from err + if not isinstance(record, dict): + raise DataOpsException( + f"{source}:{line_no} is not a JSON object: {type(record).__name__}" + ) + records.append(record) + return records + + +def jsonl_append(path: str, record: dict[str, Any], *, encoding: str = "utf-8") -> bool: + """Append one JSON object as a new line in ``path``. Creates the file if absent.""" + if not isinstance(record, dict): + raise DataOpsException(f"record must be a dict, got {type(record).__name__}") + target = Path(path) + target.parent.mkdir(parents=True, exist_ok=True) + line = json.dumps(record, ensure_ascii=False) + "\n" + with open(target, "a", encoding=encoding) as writer: + writer.write(line) + return True + + +def _resolve_fieldnames( + source_fields: Sequence[str] | None, + requested: list[str] | None, +) -> list[str]: + if source_fields is None: + raise DataOpsException("CSV has no header row") + if requested is None: + return list(source_fields) + missing = [name for name in requested if name not in source_fields] + if missing: + raise DataOpsException(f"column(s) not in CSV header: {', '.join(missing)}") + return list(requested) + + +def yaml_get(path: str, key_path: str, default: Any = None) -> Any: + """Return the value at dotted ``key_path`` in a YAML file, or ``default``.""" + data = _yaml_load(path) + result = _walk(data, _split_key(key_path)) + return default if result is _MISSING else result + + +def yaml_set(path: str, key_path: str, value: Any) -> bool: + """Set the value at dotted
``key_path``. Creates intermediate dicts.""" + segments = _split_key(key_path) + if not segments: + raise DataOpsException("key_path must not be empty") + data = _yaml_load(path) + _set_in(data, segments, value) + _yaml_dump(path, data) + file_automation_logger.info("yaml_set: %s %s", path, key_path) + return True + + +def yaml_delete(path: str, key_path: str) -> bool: + """Delete the value at dotted ``key_path``; return True when a value was removed.""" + segments = _split_key(key_path) + if not segments: + raise DataOpsException("key_path must not be empty") + data = _yaml_load(path) + removed = _delete_in(data, segments) + if removed: + _yaml_dump(path, data) + file_automation_logger.info("yaml_delete: %s %s", path, key_path) + return removed + + +def parquet_read( + path: str, + *, + limit: int | None = None, + columns: list[str] | None = None, +) -> list[dict[str, Any]]: + """Read a Parquet file into a list of dicts. + + ``columns`` projects the output schema (unknown column names raise). + ``limit`` caps the number of rows returned (reads the whole file but + slices before conversion — handy for previews of multi-GB files). 
+ """ + import pyarrow.parquet as pq + + source = Path(path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + try: + table = pq.read_table(str(source), columns=columns) + except (OSError, ValueError) as err: + raise DataOpsException(f"cannot read parquet {source}: {err}") from err + if limit is not None: + table = table.slice(0, limit) + return table.to_pylist() + + +def parquet_write(path: str, records: list[dict[str, Any]]) -> int: + """Write ``records`` (list of dicts) as a Parquet file; return the row count.""" + import pyarrow as pa + import pyarrow.parquet as pq + + if not isinstance(records, list): + raise DataOpsException("records must be a list of dicts") + target = Path(path) + target.parent.mkdir(parents=True, exist_ok=True) + try: + table = pa.Table.from_pylist(records) + except (TypeError, pa.ArrowInvalid) as err: + raise DataOpsException(f"cannot build parquet table: {err}") from err + tmp_name: str | None = None + try: + with tempfile.NamedTemporaryFile( + mode="wb", dir=str(target.parent), delete=False, suffix=".tmp" + ) as writer: + tmp_name = writer.name + pq.write_table(table, tmp_name) + os.replace(tmp_name, target) + tmp_name = None + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) + file_automation_logger.info("parquet_write: %s (%d rows)", target, table.num_rows) + return table.num_rows + + +def csv_to_parquet( + csv_path: str, + parquet_path: str, + *, + delimiter: str = ",", + encoding: str = "utf-8", +) -> int: + """Convert a CSV file to Parquet; return the row count written.""" + source = Path(csv_path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + with open(source, encoding=encoding, newline="") as reader: + rows = list(csv.DictReader(reader, delimiter=delimiter)) + return parquet_write(parquet_path, rows) + + +def _yaml_load(path: str) -> Any: + import yaml + + source = Path(path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + try: 
+ return yaml.safe_load(source.read_text(encoding="utf-8")) or {} + except yaml.YAMLError as err: + raise DataOpsException(f"cannot parse YAML {source}: {err}") from err + + +def _yaml_dump(path: str, data: Any) -> None: + import yaml + + target = Path(path) + target.parent.mkdir(parents=True, exist_ok=True) + tmp_name: str | None = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=str(target.parent), + delete=False, + suffix=".tmp", + ) as writer: + tmp_name = writer.name + yaml.safe_dump(data, writer, sort_keys=False, allow_unicode=True) + os.replace(tmp_name, target) + tmp_name = None + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) + + +def _split_key(key_path: str) -> list[str]: + if not isinstance(key_path, str): + raise DataOpsException("key_path must be a string") + return [seg for seg in key_path.split(".") if seg != ""] + + +def _walk(data: Any, segments: list[str]) -> Any: + current: Any = data + for segment in segments: + try: + current = _child(current, segment) + except (KeyError, IndexError, TypeError): + return _MISSING + return current + + +def _child(container: Any, segment: str) -> Any: + if isinstance(container, MutableMapping): + return container[segment] + if isinstance(container, MutableSequence) and _is_int_segment(segment): + return container[int(segment)] + raise TypeError(f"cannot index {type(container).__name__} by {segment!r}") + + +def _is_int_segment(segment: str) -> bool: + return segment.removeprefix("-").isdecimal() + + +def _descend_for_set(container: Any, segment: str) -> Any: + if isinstance(container, MutableMapping): + if segment not in container or not isinstance( + container[segment], (MutableMapping, MutableSequence) + ): + container[segment] = {} + return container[segment] + if isinstance(container, MutableSequence) and _is_int_segment(segment): + return container[int(segment)] + raise DataOpsException(f"cannot traverse into {segment!r}") + + +def _set_in(data:
Any, segments: list[str], value: Any) -> None: + container = data + for segment in segments[:-1]: + container = _descend_for_set(container, segment) + last = segments[-1] + if isinstance(container, MutableMapping): + container[last] = value + return + if isinstance(container, MutableSequence) and _is_int_segment(last): + _assign_into_sequence(container, last, value) + return + raise DataOpsException(f"cannot set into {type(container).__name__}") + + +def _assign_into_sequence(container: MutableSequence[Any], last: str, value: Any) -> None: + idx = int(last) + if -len(container) <= idx < len(container): + container[idx] = value + return + if idx == len(container): + container.append(value) + return + raise DataOpsException(f"list index out of range: {idx}") + + +def _delete_in(data: Any, segments: list[str]) -> bool: + container = data + for segment in segments[:-1]: + try: + container = _child(container, segment) + except (KeyError, IndexError, TypeError): + return False + last = segments[-1] + if isinstance(container, MutableMapping): + if last not in container: + return False + del container[last] + return True + if isinstance(container, MutableSequence) and _is_int_segment(last): + idx = int(last) + if not -len(container) <= idx < len(container): + return False + del container[idx] + return True + return False diff --git a/automation_file/local/diff_ops.py b/automation_file/local/diff_ops.py index 73ad698..a0703af 100644 --- a/automation_file/local/diff_ops.py +++ b/automation_file/local/diff_ops.py @@ -11,6 +11,7 @@ import difflib import hashlib import os +import re import shutil from collections.abc import Iterable from dataclasses import dataclass, field @@ -113,6 +114,128 @@ def diff_text_files( return "".join(diff_lines) +def apply_text_patch(target: str | os.PathLike[str], patch: str) -> bool: + """Apply a unified-diff ``patch`` to ``target`` in place; return True on success. 
+ + The patch must have been produced against the current contents of + ``target`` (for example by :func:`diff_text_files`). Hunk headers are + verified against the live file before any write; if a hunk's context + lines don't match, no change is applied and :class:`DiffException` is + raised so the caller sees the mismatch instead of a corrupt file. + + ``target`` is taken at face value — the caller is the trust boundary + for this path, exactly like :func:`pathlib.Path.write_text` or the + surrounding :func:`diff_text_files` helper. Upstream callers that + accept a user-controlled root should run the path through + :func:`automation_file.local.safe_paths.safe_join` themselves. + """ + target_path = Path(target) + try: + original = target_path.read_text(encoding="utf-8").splitlines(keepends=True) + except OSError as error: + raise DiffException(f"cannot read patch target: {error}") from error + patched = _apply_unified_patch(original, patch) + target_path.write_text("".join(patched), encoding="utf-8") # NOSONAR pythonsecurity:S2083 + return True + + +def _apply_unified_patch(lines: list[str], patch: str) -> list[str]: + result: list[str] = [] + cursor = 0 + for hunk in _iter_hunks(patch): + cursor = _copy_up_to(lines, cursor, hunk.start, result) + cursor = _apply_hunk_ops(lines, cursor, hunk.ops, result) + result.extend(lines[cursor:]) + return result + + +def _copy_up_to(lines: list[str], cursor: int, stop: int, result: list[str]) -> int: + while cursor < stop: + result.append(lines[cursor]) + cursor += 1 + return cursor + + +def _apply_hunk_ops( + lines: list[str], + cursor: int, + ops: tuple[tuple[str, str], ...], + result: list[str], +) -> int: + for op, payload in ops: + if op == "+": + result.append(payload) + continue + # " " (context) and "-" (delete) both require a live-line match. 
+ _verify_live_line(lines, cursor, payload, context=op == " ") + if op == " ": + result.append(payload) + cursor += 1 + return cursor + + +def _verify_live_line(lines: list[str], cursor: int, expected: str, *, context: bool) -> None: + got = lines[cursor] if cursor < len(lines) else "" + if cursor >= len(lines) or lines[cursor] != expected: + kind = "context" if context else "deletion" + raise DiffException( + f"patch {kind} mismatch at line {cursor + 1}: expected {expected!r}, got {got!r}" + ) + + +@dataclass(frozen=True) +class _Hunk: + start: int + ops: tuple[tuple[str, str], ...] + + +def _iter_hunks(patch: str) -> Iterable[_Hunk]: + state = _HunkParseState() + for raw_line in patch.splitlines(keepends=True): + pending = _consume_patch_line(state, raw_line) + if pending is not None: + yield pending + if state.in_hunk and state.buffer: + yield _Hunk(start=state.start, ops=tuple(state.buffer)) + + +@dataclass +class _HunkParseState: + buffer: list[tuple[str, str]] = field(default_factory=list) + start: int = 0 + in_hunk: bool = False + + +def _consume_patch_line(state: _HunkParseState, raw_line: str) -> _Hunk | None: + if raw_line.startswith("@@"): + pending = ( + _Hunk(start=state.start, ops=tuple(state.buffer)) + if state.in_hunk and state.buffer + else None + ) + state.start = _parse_hunk_header(raw_line) + state.buffer = [] + state.in_hunk = True + return pending + if not state.in_hunk or not raw_line: + return None + prefix, payload = raw_line[0], raw_line[1:] + if prefix in {" ", "+", "-"}: + state.buffer.append((prefix, payload)) + return None + + +_HUNK_HEADER = re.compile(r"^@@\s+-(\d+)(?:,(\d+))?\s+\+\d+(?:,\d+)?\s+@@") + + +def _parse_hunk_header(line: str) -> int: + match = _HUNK_HEADER.match(line) + if not match: + raise DiffException(f"malformed hunk header: {line.rstrip()!r}") + # 1-based start -> 0-based index; an empty old range ("-N,0") already names the line before the insert.
+ return int(match.group(1)) if match.group(2) == "0" else max(int(match.group(1)) - 1, 0) + + def _relative_files(root: Path) -> set[str]: collected: set[str] = set() for dirpath, _dirnames, filenames in os.walk(root, followlinks=False): @@ -138,3 +261,16 @@ def iter_dir_diff(diff: DirDiff) -> Iterable[tuple[str, str]]: yield "removed", rel for rel in diff.changed: yield "changed", rel + + +def diff_dirs_summary( + left: str | os.PathLike[str], + right: str | os.PathLike[str], +) -> dict[str, list[str]]: + """JSON-friendly wrapper around :func:`diff_dirs` — returns plain lists.""" + diff = diff_dirs(left, right) + return { + "added": list(diff.added), + "removed": list(diff.removed), + "changed": list(diff.changed), + } diff --git a/automation_file/local/text_ops.py b/automation_file/local/text_ops.py new file mode 100644 index 0000000..acef68e --- /dev/null +++ b/automation_file/local/text_ops.py @@ -0,0 +1,197 @@ +"""Text and binary file helpers: split, merge, encoding conversion, sed, line count. + +These are the low-level building blocks for automating large-file handling +(splitting a multi-gigabyte archive into transferable chunks, for instance), +pipeline text munging that currently needs a shell, and cheap text stats. + +Every helper writes atomically where a destination file is produced: data +lands in a sibling temp file that is ``os.replace`` d over the final path +after the operation finishes, so a crash leaves either the old content or +the new content — never a partial mix. +""" + +from __future__ import annotations + +import os +import re +import tempfile +from pathlib import Path + +from automation_file.exceptions import FileNotExistsException, TextOpsException +from automation_file.logging_config import file_automation_logger + +_CHUNK_IO = 1 << 20 # 1 MiB read buffer + + +def file_split(file_path: str, chunk_size: int, output_dir: str | None = None) -> list[str]: + """Split ``file_path`` into fixed-size chunks; return the part paths in order. + + Parts are named ``<name>.part000``, ``<name>.part001``, ...
and + written under ``output_dir`` if given (created if missing), otherwise + alongside the source file. ``chunk_size`` is the maximum bytes per part + and must be > 0. + """ + if chunk_size <= 0: + raise TextOpsException("chunk_size must be positive") + source = Path(file_path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + dest_dir = Path(output_dir) if output_dir else source.parent + dest_dir.mkdir(parents=True, exist_ok=True) + parts: list[str] = [] + with open(source, "rb") as reader: + index = 0 + while True: + chunk = reader.read(chunk_size) + if not chunk: + break + part_path = dest_dir / f"{source.name}.part{index:03d}" + with open(part_path, "wb") as writer: + writer.write(chunk) + parts.append(str(part_path)) + index += 1 + file_automation_logger.info("file_split: %s -> %d parts", source, len(parts)) + return parts + + +def file_merge(parts: list[str], target_path: str) -> bool: + """Concatenate ``parts`` in list order into ``target_path``. Atomic write.""" + if not parts: + raise TextOpsException("parts must be a non-empty list") + missing = [p for p in parts if not Path(p).is_file()] + if missing: + raise FileNotExistsException(", ".join(missing)) + target = Path(target_path) + target.parent.mkdir(parents=True, exist_ok=True) + tmp_name: str | None = None + try: + with tempfile.NamedTemporaryFile( + mode="wb", dir=str(target.parent), delete=False, suffix=".tmp" + ) as writer: + tmp_name = writer.name + for part in parts: + with open(part, "rb") as reader: + while True: + chunk = reader.read(_CHUNK_IO) + if not chunk: + break + writer.write(chunk) + os.replace(tmp_name, target) + tmp_name = None + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) + file_automation_logger.info("file_merge: %d parts -> %s", len(parts), target) + return True + + +def encoding_convert( + file_path: str, + target_path: str, + source_encoding: str, + target_encoding: str, + *, + errors: str = "strict", +) -> bool: + 
"""Re-encode a text file from ``source_encoding`` to ``target_encoding``. + + ``errors`` follows :func:`codecs.decode`'s contract (``strict`` / ``replace`` + / ``ignore``). Defaults to ``strict`` so mis-declared source encodings + surface as :class:`TextOpsException` instead of silently corrupting. + """ + source = Path(file_path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + target = Path(target_path) + target.parent.mkdir(parents=True, exist_ok=True) + try: + raw = source.read_bytes() + text = raw.decode(source_encoding, errors=errors) + encoded = text.encode(target_encoding, errors=errors) + except (LookupError, UnicodeError) as err: + raise TextOpsException( + f"encoding_convert {source_encoding}->{target_encoding} failed: {err}" + ) from err + _atomic_write_bytes(target, encoded) + file_automation_logger.info( + "encoding_convert: %s (%s) -> %s (%s)", + source, + source_encoding, + target, + target_encoding, + ) + return True + + +def line_count(file_path: str, *, encoding: str = "utf-8") -> int: + """Count lines in a text file. A trailing newline is not counted as an extra line.""" + source = Path(file_path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + count = 0 + with open(source, encoding=encoding) as reader: + for _ in reader: + count += 1 + return count + + +def sed_replace( + file_path: str, + pattern: str, + replacement: str, + *, + regex: bool = False, + count: int = 0, + encoding: str = "utf-8", +) -> int: + """Replace occurrences of ``pattern`` with ``replacement`` in-place; return hit count. + + ``regex=False`` (default) does literal substring replacement; ``regex=True`` + treats ``pattern`` as a :mod:`re` pattern and ``replacement`` may use + backreferences. ``count=0`` replaces every occurrence; any positive integer + caps the number of replacements. 
+ """ + if count < 0: + raise TextOpsException("count must be >= 0") + source = Path(file_path) + if not source.is_file(): + raise FileNotExistsException(str(source)) + original = source.read_text(encoding=encoding) + try: + new_text, hits = _apply_replacement(original, pattern, replacement, regex, count) + except re.error as err: + raise TextOpsException(f"invalid regex: {err}") from err + if hits: + _atomic_write_bytes(source, new_text.encode(encoding)) + file_automation_logger.info("sed_replace: %s -> %d replacement(s)", source, hits) + return hits + + +def _apply_replacement( + text: str, pattern: str, replacement: str, regex: bool, count: int +) -> tuple[str, int]: + if regex: + compiled = re.compile(pattern) + new_text, hits = compiled.subn(replacement, text, count=count) + return new_text, hits + if not pattern: + raise TextOpsException("pattern must not be empty for literal replace") + hits = text.count(pattern) if count == 0 else min(text.count(pattern), count) + new_text = text.replace(pattern, replacement, -1 if count == 0 else count) + return new_text, hits + + +def _atomic_write_bytes(target: Path, data: bytes) -> None: + target.parent.mkdir(parents=True, exist_ok=True) + tmp_name: str | None = None + try: + with tempfile.NamedTemporaryFile( + mode="wb", dir=str(target.parent), delete=False, suffix=".tmp" + ) as writer: + tmp_name = writer.name + writer.write(data) + os.replace(tmp_name, target) + tmp_name = None + finally: + if tmp_name is not None: + Path(tmp_name).unlink(missing_ok=True) diff --git a/automation_file/remote/box/__init__.py b/automation_file/remote/box/__init__.py new file mode 100644 index 0000000..bdec93e --- /dev/null +++ b/automation_file/remote/box/__init__.py @@ -0,0 +1,30 @@ +"""Box strategy module. + +Actions (``FA_box_*``) are registered on the shared default registry +automatically by :func:`build_default_registry`. :func:`register_box_ops` +is exposed for callers that build their own :class:`ActionRegistry`. 
def register_box_ops(registry: ActionRegistry) -> None:
    """Add the ``FA_box_*`` action names to ``registry`` in one call."""
    box_commands = {
        "FA_box_later_init": box_instance.later_init,
        "FA_box_upload_file": upload_ops.box_upload_file,
        "FA_box_upload_dir": upload_ops.box_upload_dir,
        "FA_box_download_file": download_ops.box_download_file,
        "FA_box_delete_file": delete_ops.box_delete_file,
        "FA_box_delete_folder": delete_ops.box_delete_folder,
        "FA_box_list_folder": list_ops.box_list_folder,
    }
    registry.register_many(box_commands)
class BoxClient:
    """Holder for a ``boxsdk.Client`` that is built on demand by :meth:`later_init`."""

    def __init__(self) -> None:
        # Stays None until later_init() succeeds.
        self.client: Any = None

    def later_init(
        self,
        access_token: str,
        *,
        client_id: str = "",
        client_secret: str = "",
    ) -> Any:
        """Create the ``boxsdk.Client`` from an OAuth2 access token and return it.

        ``client_id`` / ``client_secret`` are passed straight to
        ``boxsdk.OAuth2`` and default to empty strings.

        Raises:
            BoxException: ``access_token`` is not a non-empty string, or
                ``boxsdk`` cannot be imported.
        """
        token_is_valid = isinstance(access_token, str) and bool(access_token)
        if not token_is_valid:
            raise BoxException("access_token must be a non-empty string")
        sdk = _import_boxsdk()
        credentials = sdk.OAuth2(
            client_id=client_id,
            client_secret=client_secret,
            access_token=access_token,
        )
        self.client = sdk.Client(credentials)
        file_automation_logger.info("BoxClient: client ready")
        return self.client

    def require_client(self) -> Any:
        """Return the initialised client, or raise ``BoxException`` if there is none."""
        current = self.client
        if current is not None:
            return current
        raise BoxException("BoxClient not initialised; call later_init() first")
def box_download_file(file_id: str, target_path: str) -> bool:
    """Download Box file ``file_id`` to ``target_path``.

    The payload is written to a sibling temp file and moved over
    ``target_path`` with ``os.replace`` only after the download finished.
    A failed or interrupted download therefore leaves any existing file at
    ``target_path`` untouched and no partial file behind.

    Raises:
        BoxException: the client is not initialised, or the download /
            local write failed (original error attached as ``__cause__``).
    """
    # Local imports: this module does not import os/tempfile at file level.
    import os
    import tempfile

    client = box_instance.require_client()
    target = Path(target_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    tmp_name: str | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb", dir=str(target.parent), delete=False, suffix=".tmp"
        ) as writer:
            tmp_name = writer.name
            client.file(file_id=file_id).download_to(writer)
        os.replace(tmp_name, target)
        tmp_name = None
    except Exception as error:  # pylint: disable=broad-except
        raise BoxException(f"box_download_file failed: {error}") from error
    finally:
        # Still set only when the replace did not happen.
        if tmp_name is not None:
            Path(tmp_name).unlink(missing_ok=True)
    file_automation_logger.info("box_download_file: %s -> %s", file_id, target)
    return True
def box_list_folder(folder_id: str = "0", limit: int = 100) -> list[dict[str, Any]]:
    """List entries in a Box folder; return basic metadata per entry.

    ``folder_id="0"`` is Box's root. ``limit`` caps how many entries are
    returned. It is passed to ``get_items`` and also enforced locally,
    because the collection returned by the SDK may keep fetching further
    pages while it is iterated. Each entry is
    ``{"id": str, "name": str, "type": "file"|"folder"}``.

    Raises:
        BoxException: the client is not initialised or the listing failed
            (including an invalid ``limit``).
    """
    from itertools import islice

    client = box_instance.require_client()
    try:
        folder = client.folder(folder_id=folder_id)
        items = folder.get_items(limit=limit)
        entries = [
            {
                "id": str(getattr(item, "id", "")),
                "name": getattr(item, "name", ""),
                "type": getattr(item, "type", "file"),
            }
            # Stop consuming the iterator after ``limit`` entries so no
            # additional pages are requested.
            for item in islice(items, limit)
        ]
    except Exception as error:  # pylint: disable=broad-except
        raise BoxException(f"box_list_folder failed: {error}") from error
    file_automation_logger.info("box_list_folder: %s -> %d entries", folder_id, len(entries))
    return entries
def box_upload_dir(dir_path: str, parent_folder_id: str = "0") -> list[str]:
    """Upload every file under ``dir_path`` into the single folder ``parent_folder_id``.

    The directory tree is flattened: each file is uploaded under its
    relative path with ``/`` as separator, used as the Box file name. No
    folders are created on Box. Returns the ``uploaded`` list reported by
    ``walk_and_upload``.
    """
    # NOTE(review): Box may reject file names containing "/" — confirm that
    # nested files actually upload with this naming scheme.

    def _flat_name(_prefix: str, rel: str) -> str:
        # Normalise Windows separators; the prefix is intentionally unused.
        return rel.replace("\\", "/")

    def _upload_one(local: Path, flat_name: str) -> bool:
        # box_upload_file raises on failure, so reaching the return means
        # the upload succeeded.
        box_upload_file(str(local), parent_folder_id, flat_name)
        return True

    outcome = walk_and_upload(dir_path, "", _flat_name, _upload_one)
    file_automation_logger.info(
        "box_upload_dir: %s -> folder %s (%d files)",
        outcome.source,
        parent_folder_id,
        len(outcome.uploaded),
    )
    return outcome.uploaded
class OneDriveClient:
    """Lazy wrapper holding an access token and a :class:`requests.Session`."""

    def __init__(self) -> None:
        # Both stay None until later_init() / device_code_login() succeeds.
        self._access_token: str | None = None
        self._session: requests.Session | None = None

    def later_init(self, access_token: str) -> bool:
        """Install a pre-obtained OAuth2 access token. Returns True on success.

        Creates a fresh session whose ``Authorization`` header carries the
        bearer token.

        Raises:
            OneDriveException: ``access_token`` is not a non-empty string.
        """
        if not isinstance(access_token, str) or not access_token:
            raise OneDriveException("access_token must be a non-empty string")
        self._access_token = access_token
        self._session = requests.Session()
        self._session.headers["Authorization"] = f"Bearer {access_token}"
        file_automation_logger.info("OneDriveClient: access token installed")
        return True

    def device_code_login(
        self,
        client_id: str,
        *,
        tenant_id: str | None = None,
        scopes: tuple[str, ...] | None = None,
        timeout: int = 300,
    ) -> dict[str, Any]:
        """Run MSAL's device-code flow and install the resulting token.

        Blocks until the user completes the login, the device code expires,
        or ``timeout`` seconds elapse — whichever comes first. Returns the
        raw MSAL token dict. The sign-in message for the user is written to
        the automation log; it is not part of the return value.

        Raises:
            OneDriveException: ``msal`` is unavailable, the flow could not
                be started, or no access token was obtained.
        """
        import time

        msal = _import_msal()
        authority = (
            f"https://login.microsoftonline.com/{tenant_id}" if tenant_id else _DEFAULT_AUTHORITY
        )
        app = msal.PublicClientApplication(client_id=client_id, authority=authority)
        flow = app.initiate_device_flow(scopes=list(scopes or _DEFAULT_SCOPES))
        if "user_code" not in flow:
            raise OneDriveException(
                f"device-code flow init failed: {flow.get('error_description', flow)}"
            )
        file_automation_logger.info("OneDriveClient: %s", flow.get("message", ""))
        # MSAL polls until time.time() passes flow["expires_at"], which is an
        # absolute epoch timestamp. Storing a duration there (as the previous
        # code did) lies in the past and ends polling after one attempt.
        # Cap the flow's own expiry by our timeout instead.
        deadline = time.time() + timeout
        flow["expires_at"] = min(float(flow.get("expires_at", deadline)), deadline)
        result = app.acquire_token_by_device_flow(flow)
        access_token = result.get("access_token")
        if not access_token:
            raise OneDriveException(
                f"device-code login failed: {result.get('error_description', result)}"
            )
        self.later_init(access_token)
        return result

    def require_session(self) -> requests.Session:
        """Return the active session, or raise ``OneDriveException`` if none exists."""
        if self._session is None:
            raise OneDriveException(
                "OneDriveClient not initialised; call later_init() or device_code_login() first"
            )
        return self._session

    def graph_request(
        self,
        method: str,
        path: str,
        *,
        timeout: float = 30.0,
        **request_kwargs: Any,
    ) -> requests.Response:
        """Issue a Microsoft Graph API request and return the successful response.

        A ``path`` that starts with ``http`` is used verbatim as the URL
        (e.g. an ``@odata.nextLink``); anything else is appended to
        ``https://graph.microsoft.com/v1.0``. ``request_kwargs`` is forwarded
        to :meth:`requests.Session.request` — ``params``, ``json``, ``data``,
        and ``headers`` are the common ones.

        Raises:
            OneDriveException: the client is not initialised, the request
                failed at transport level, or the response status is not OK
                (message includes the first 200 characters of the body).
        """
        session = self.require_session()
        url = path if path.startswith("http") else f"{_GRAPH_BASE}{path}"
        try:
            response = session.request(method, url, timeout=timeout, **request_kwargs)
        except requests.RequestException as error:
            raise OneDriveException(f"graph request failed: {error}") from error
        if not response.ok:
            raise OneDriveException(
                f"graph {method} {path} returned {response.status_code}: {response.text[:200]}"
            )
        return response

    def close(self) -> bool:
        """Close the session (if any) and forget the token. Always returns True."""
        if self._session is not None:
            self._session.close()
            self._session = None
        self._access_token = None
        return True
def onedrive_download_file(remote_path: str, target_path: str) -> bool:
    """Download ``remote_path`` to ``target_path`` via Microsoft Graph.

    ``remote_path`` is relative to the drive root; a leading ``/`` is
    ignored. Characters that would otherwise change the URL's meaning
    (``#``, ``?``, spaces, ...) are percent-encoded. ``%`` is left untouched
    so callers that already pass percent-encoded paths keep working.
    The whole payload is held in memory before it is written.

    Raises:
        OneDriveException: raised by ``graph_request`` when the client is
            not initialised or the request fails.
    """
    from urllib.parse import quote

    # safe="/%": keep path separators and pre-existing escapes as they are.
    encoded = quote(remote_path.lstrip("/"), safe="/%")
    response = onedrive_instance.graph_request(
        "GET", f"/me/drive/root:/{encoded}:/content", timeout=120.0
    )
    target = Path(target_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(response.content)
    file_automation_logger.info(
        "onedrive_download_file: %s -> %s (%d bytes)", remote_path, target, len(response.content)
    )
    return True
def onedrive_upload_file(file_path: str, remote_path: str) -> bool:
    """Upload a local file to ``remote_path`` under ``/me/drive/root``.

    ``remote_path`` is treated as a posix-style path relative to the drive
    root (e.g. ``Documents/report.pdf``); a leading ``/`` is ignored.
    Characters that would otherwise change the URL's meaning (``#``, ``?``,
    spaces, ...) are percent-encoded. ``%`` is left untouched so callers
    that already pass percent-encoded paths keep working. Files larger than
    ``_SIMPLE_UPLOAD_MAX`` bytes are rejected instead of being sent.

    Raises:
        FileNotExistsException: ``file_path`` is not an existing file.
        OneDriveException: the file exceeds the simple-upload cap, or the
            Graph request fails.
    """
    from urllib.parse import quote

    local = Path(file_path)
    if not local.is_file():
        raise FileNotExistsException(str(local))
    size = local.stat().st_size
    if size > _SIMPLE_UPLOAD_MAX:
        raise OneDriveException(
            f"{local} is {size} bytes; simple upload is capped at {_SIMPLE_UPLOAD_MAX}"
        )
    data = local.read_bytes()
    # safe="/%": keep path separators and pre-existing escapes as they are.
    encoded = quote(remote_path.lstrip("/"), safe="/%")
    onedrive_instance.graph_request(
        "PUT",
        f"/me/drive/root:/{encoded}:/content",
        data=data,
        headers={"Content-Type": "application/octet-stream"},
    )
    file_automation_logger.info("onedrive_upload_file: %s -> %s", local, remote_path)
    return True
class BoxTab(RemoteBackendTab):
    """Form-driven Box operations.

    One group initialises the shared Box client from an access token; the
    other exposes upload / download / delete / list buttons that read their
    arguments from the line edits and hand the call to ``run_action``.
    """

    def _init_group(self) -> QGroupBox:
        """Build the "Client" group: token field plus the initialise button."""
        box = QGroupBox("Client")
        form = QFormLayout(box)
        self._token = QLineEdit()
        # Mask the token so it is not readable on screen.
        self._token.setEchoMode(QLineEdit.EchoMode.Password)
        self._token.setPlaceholderText("OAuth2 access token")
        form.addRow("Access token", self._token)
        form.addRow(self.make_button("Initialise Box client", self._on_init))
        return box

    def _ops_group(self) -> QGroupBox:
        """Build the "Operations" group: input fields and one button per action."""
        box = QGroupBox("Operations")
        form = QFormLayout(box)
        self._local = QLineEdit()
        self._folder_id = QLineEdit("0")
        self._file_id = QLineEdit()
        self._recursive = QCheckBox("Recursive delete")
        form.addRow("Local path", self._local)
        form.addRow("Folder id", self._folder_id)
        form.addRow("File id", self._file_id)
        form.addRow(self._recursive)
        form.addRow(self.make_button("Upload file", self._on_upload_file))
        form.addRow(self.make_button("Upload dir", self._on_upload_dir))
        form.addRow(self.make_button("Download", self._on_download))
        form.addRow(self.make_button("Delete file", self._on_delete_file))
        form.addRow(self.make_button("Delete folder", self._on_delete_folder))
        form.addRow(self.make_button("List folder", self._on_list))
        return box

    @staticmethod
    def _value(field: QLineEdit) -> str:
        """Return the field's text with surrounding whitespace removed."""
        return field.text().strip()

    def _on_init(self) -> None:
        """Initialise the shared Box client with the entered token."""
        token = self._value(self._token)
        self.run_action(
            box_instance.later_init,
            "box.later_init",
            kwargs={"access_token": token},
        )

    def _on_upload_file(self) -> None:
        """Upload the local file into the folder (root ``"0"`` when blank)."""
        local = self._value(self._local)
        self.run_action(
            box_upload_file,
            f"box_upload_file {local}",
            kwargs={
                "file_path": local,
                "parent_folder_id": self._value(self._folder_id) or "0",
            },
        )

    def _on_upload_dir(self) -> None:
        """Upload the local directory into the folder (root ``"0"`` when blank)."""
        local = self._value(self._local)
        self.run_action(
            box_upload_dir,
            f"box_upload_dir {local}",
            kwargs={
                "dir_path": local,
                "parent_folder_id": self._value(self._folder_id) or "0",
            },
        )

    def _on_download(self) -> None:
        """Download the file id to the path in the "Local path" field."""
        file_id = self._value(self._file_id)
        self.run_action(
            box_download_file,
            f"box_download_file {file_id}",
            kwargs={
                "file_id": file_id,
                "target_path": self._value(self._local),
            },
        )

    def _on_delete_file(self) -> None:
        """Delete the file with the entered id."""
        file_id = self._value(self._file_id)
        self.run_action(
            box_delete_file,
            f"box_delete_file {file_id}",
            kwargs={"file_id": file_id},
        )

    def _on_delete_folder(self) -> None:
        """Delete the folder with the entered id.

        Deliberately no fallback to root here: a blank field must not turn
        into a delete request for folder ``"0"``.
        """
        folder_id = self._value(self._folder_id)
        self.run_action(
            box_delete_folder,
            f"box_delete_folder {folder_id}",
            kwargs={
                "folder_id": folder_id,
                "recursive": self._recursive.isChecked(),
            },
        )

    def _on_list(self) -> None:
        """List the folder, falling back to root ``"0"`` when the field is blank."""
        # Same fallback as the upload handlers; previously an empty id was
        # passed through and overrode box_list_folder's own "0" default.
        folder_id = self._value(self._folder_id) or "0"
        self.run_action(
            box_list_folder,
            f"box_list_folder {folder_id}",
            kwargs={"folder_id": folder_id},
        )
b/automation_file/ui/tabs/home_tab.py @@ -15,8 +15,10 @@ ) from automation_file.remote.azure_blob.client import azure_blob_instance +from automation_file.remote.box.client import box_instance from automation_file.remote.dropbox_api.client import dropbox_instance from automation_file.remote.google_drive.client import driver_instance +from automation_file.remote.onedrive.client import onedrive_instance from automation_file.remote.s3.client import s3_instance from automation_file.remote.sftp.client import sftp_instance from automation_file.ui.log_widget import LogPanel @@ -36,6 +38,8 @@ class _BackendProbe(NamedTuple): _BackendProbe("Azure Blob", lambda: azure_blob_instance.service is not None), _BackendProbe("Dropbox", lambda: dropbox_instance.client is not None), _BackendProbe("SFTP", lambda: getattr(sftp_instance, "_sftp", None) is not None), + _BackendProbe("OneDrive", lambda: getattr(onedrive_instance, "_session", None) is not None), + _BackendProbe("Box", lambda: box_instance.client is not None), ) diff --git a/automation_file/ui/tabs/onedrive_tab.py b/automation_file/ui/tabs/onedrive_tab.py new file mode 100644 index 0000000..81ed90f --- /dev/null +++ b/automation_file/ui/tabs/onedrive_tab.py @@ -0,0 +1,119 @@ +"""OneDrive tab.""" + +from __future__ import annotations + +from PySide6.QtWidgets import ( + QFormLayout, + QGroupBox, + QLineEdit, + QPushButton, +) + +from automation_file.remote.onedrive.client import onedrive_instance +from automation_file.remote.onedrive.delete_ops import onedrive_delete_item +from automation_file.remote.onedrive.download_ops import onedrive_download_file +from automation_file.remote.onedrive.list_ops import onedrive_list_folder +from automation_file.remote.onedrive.upload_ops import ( + onedrive_upload_dir, + onedrive_upload_file, +) +from automation_file.ui.tabs.base import RemoteBackendTab + + +class OneDriveTab(RemoteBackendTab): + """Form-driven OneDrive operations (Microsoft Graph).""" + + def _init_group(self) -> QGroupBox: 
+ box = QGroupBox("Client") + form = QFormLayout(box) + self._token = QLineEdit() + self._token.setEchoMode(QLineEdit.EchoMode.Password) + self._token.setPlaceholderText("OAuth2 access token (from MSAL)") + form.addRow("Access token", self._token) + btn = QPushButton("Initialise OneDrive client") + btn.clicked.connect(self._on_init) + form.addRow(btn) + self._client_id = QLineEdit() + self._tenant_id = QLineEdit() + self._tenant_id.setPlaceholderText("(optional)") + form.addRow("MSAL client id", self._client_id) + form.addRow("MSAL tenant id", self._tenant_id) + device_btn = QPushButton("Device-code login") + device_btn.clicked.connect(self._on_device_code) + form.addRow(device_btn) + return box + + def _ops_group(self) -> QGroupBox: + box = QGroupBox("Operations") + form = QFormLayout(box) + self._local = QLineEdit() + self._remote = QLineEdit() + form.addRow("Local path", self._local) + form.addRow("Remote path", self._remote) + form.addRow(self.make_button("Upload file", self._on_upload_file)) + form.addRow(self.make_button("Upload dir", self._on_upload_dir)) + form.addRow(self.make_button("Download", self._on_download)) + form.addRow(self.make_button("Delete", self._on_delete)) + form.addRow(self.make_button("List folder", self._on_list)) + return box + + def _on_init(self) -> None: + token = self._token.text().strip() + self.run_action( + onedrive_instance.later_init, + "onedrive.later_init", + kwargs={"access_token": token}, + ) + + def _on_device_code(self) -> None: + client_id = self._client_id.text().strip() + tenant_id = self._tenant_id.text().strip() or None + self.run_action( + onedrive_instance.device_code_login, + "onedrive.device_code_login", + kwargs={"client_id": client_id, "tenant_id": tenant_id}, + ) + + def _on_upload_file(self) -> None: + self.run_action( + onedrive_upload_file, + f"onedrive_upload_file {self._local.text().strip()}", + kwargs={ + "file_path": self._local.text().strip(), + "remote_path": self._remote.text().strip(), + }, + ) + + 
def _on_upload_dir(self) -> None: + self.run_action( + onedrive_upload_dir, + f"onedrive_upload_dir {self._local.text().strip()}", + kwargs={ + "dir_path": self._local.text().strip(), + "remote_prefix": self._remote.text().strip(), + }, + ) + + def _on_download(self) -> None: + self.run_action( + onedrive_download_file, + f"onedrive_download_file {self._remote.text().strip()}", + kwargs={ + "remote_path": self._remote.text().strip(), + "target_path": self._local.text().strip(), + }, + ) + + def _on_delete(self) -> None: + self.run_action( + onedrive_delete_item, + f"onedrive_delete_item {self._remote.text().strip()}", + kwargs={"remote_path": self._remote.text().strip()}, + ) + + def _on_list(self) -> None: + self.run_action( + onedrive_list_folder, + f"onedrive_list_folder {self._remote.text().strip()}", + kwargs={"remote_path": self._remote.text().strip()}, + ) diff --git a/automation_file/ui/tabs/transfer_tab.py b/automation_file/ui/tabs/transfer_tab.py index 92a91d5..34cca43 100644 --- a/automation_file/ui/tabs/transfer_tab.py +++ b/automation_file/ui/tabs/transfer_tab.py @@ -22,9 +22,11 @@ from automation_file.ui.log_widget import LogPanel from automation_file.ui.tabs.azure_tab import AzureBlobTab from automation_file.ui.tabs.base import BaseTab +from automation_file.ui.tabs.box_tab import BoxTab from automation_file.ui.tabs.drive_tab import GoogleDriveTab from automation_file.ui.tabs.dropbox_tab import DropboxTab from automation_file.ui.tabs.http_tab import HTTPDownloadTab +from automation_file.ui.tabs.onedrive_tab import OneDriveTab from automation_file.ui.tabs.s3_tab import S3Tab from automation_file.ui.tabs.sftp_tab import SFTPTab @@ -41,6 +43,8 @@ class _BackendEntry(NamedTuple): _BackendEntry("Azure Blob", AzureBlobTab), _BackendEntry("Dropbox", DropboxTab), _BackendEntry("SFTP", SFTPTab), + _BackendEntry("OneDrive", OneDriveTab), + _BackendEntry("Box", BoxTab), ) diff --git a/dev.toml b/dev.toml index 067963a..1d0f3a3 100644 --- a/dev.toml +++ 
b/dev.toml @@ -28,6 +28,12 @@ dependencies = [ "cryptography>=46.0.7", "prometheus_client>=0.25.0", "defusedxml>=0.7.1", + "PyYAML>=6.0", + "pyarrow>=15.0.0", + "opentelemetry-api>=1.25.0", + "opentelemetry-sdk>=1.25.0", + "msal>=1.28.0", + "boxsdk>=3.14.0", "tomli>=2.0.1; python_version<\"3.11\"" ] classifiers = [ diff --git a/requirements.txt b/requirements.txt index 4564960..e1006ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,10 @@ protobuf tqdm watchdog defusedxml>=0.7.1 +PyYAML>=6.0 +pyarrow>=15.0.0 +opentelemetry-api>=1.25.0 +opentelemetry-sdk>=1.25.0 +msal>=1.28.0 +boxsdk>=3.14.0 tomli; python_version<"3.11" \ No newline at end of file diff --git a/stable.toml b/stable.toml index d8702c1..9c30413 100644 --- a/stable.toml +++ b/stable.toml @@ -28,6 +28,12 @@ dependencies = [ "cryptography>=46.0.7", "prometheus_client>=0.25.0", "defusedxml>=0.7.1", + "PyYAML>=6.0", + "pyarrow>=15.0.0", + "opentelemetry-api>=1.25.0", + "opentelemetry-sdk>=1.25.0", + "msal>=1.28.0", + "boxsdk>=3.14.0", "tomli>=2.0.1; python_version<\"3.11\"" ] classifiers = [ diff --git a/tests/test_backends.py b/tests/test_backends.py index 3bf7a3d..c5f5134 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -18,6 +18,8 @@ ("automation_file.remote.azure_blob", "azure_blob_instance"), ("automation_file.remote.dropbox_api", "dropbox_instance"), ("automation_file.remote.sftp", "sftp_instance"), + ("automation_file.remote.onedrive", "onedrive_instance"), + ("automation_file.remote.box", "box_instance"), ] @@ -40,6 +42,10 @@ def test_default_registry_contains_every_backend() -> None: "FA_dropbox_list_folder", "FA_sftp_upload_file", "FA_sftp_list_dir", + "FA_onedrive_upload_file", + "FA_onedrive_list_folder", + "FA_box_upload_file", + "FA_box_list_folder", ] for name in expected: assert name in registry, f"{name} missing from default registry" @@ -80,3 +86,23 @@ def test_register_sftp_ops_adds_entries() -> None: registry = ActionRegistry() 
register_sftp_ops(registry) assert "FA_sftp_upload_file" in registry + + +def test_register_onedrive_ops_adds_entries() -> None: + from automation_file.core.action_registry import ActionRegistry + from automation_file.remote.onedrive import register_onedrive_ops + + registry = ActionRegistry() + register_onedrive_ops(registry) + for name in ("FA_onedrive_upload_file", "FA_onedrive_list_folder", "FA_onedrive_close"): + assert name in registry + + +def test_register_box_ops_adds_entries() -> None: + from automation_file.core.action_registry import ActionRegistry + from automation_file.remote.box import register_box_ops + + registry = ActionRegistry() + register_box_ops(registry) + for name in ("FA_box_upload_file", "FA_box_list_folder", "FA_box_delete_file"): + assert name in registry diff --git a/tests/test_box_ops.py b/tests/test_box_ops.py new file mode 100644 index 0000000..ab88d07 --- /dev/null +++ b/tests/test_box_ops.py @@ -0,0 +1,171 @@ +"""Box backend tests. + +Live Box endpoints are outside CI; these tests verify registry wiring, +the Client singleton's guard clauses, and the error-path wrapping that +converts ``boxsdk`` failures into :class:`BoxException`. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from automation_file import ( + BoxClient, + BoxException, + box_instance, + build_default_registry, + register_box_ops, +) +from automation_file.core.action_registry import ActionRegistry +from automation_file.exceptions import FileNotExistsException +from automation_file.remote.box import delete_ops, download_ops, list_ops, upload_ops + + +class _FakeItem: + def __init__(self, item_id: str, name: str, item_type: str = "file") -> None: + self.id = item_id + self.name = name + self.type = item_type + + +class _FakeFile: + def __init__(self, file_id: str) -> None: + self.id = file_id + + def download_to(self, writer: Any) -> None: + writer.write(b"contents") + + def delete(self) -> None: + return None + + +class _FakeFolder: + def __init__(self, folder_id: str) -> None: + self.id = folder_id + self._uploads: list[tuple[str, str]] = [] + + def upload(self, file_path: str, file_name: str) -> _FakeFile: + self._uploads.append((file_path, file_name)) + return _FakeFile("new-id") + + def get_items(self, limit: int = 100) -> list[_FakeItem]: + del limit + return [_FakeItem("1", "a.txt"), _FakeItem("2", "subdir", "folder")] + + def delete(self, recursive: bool = False) -> None: + del recursive + + +class _FakeBoxClient: + def __init__(self) -> None: + self._files: dict[str, _FakeFile] = {} + self._folders: dict[str, _FakeFolder] = {} + + def file(self, file_id: str) -> _FakeFile: + self._files.setdefault(file_id, _FakeFile(file_id)) + return self._files[file_id] + + def folder(self, folder_id: str) -> _FakeFolder: + self._folders.setdefault(folder_id, _FakeFolder(folder_id)) + return self._folders[folder_id] + + +@pytest.fixture(name="fake_box") +def _fake_box(monkeypatch: pytest.MonkeyPatch) -> _FakeBoxClient: + fake = _FakeBoxClient() + monkeypatch.setattr(box_instance, "client", fake, raising=False) + return fake + + +def 
test_require_client_raises_when_not_initialised() -> None: + client = BoxClient() + with pytest.raises(BoxException): + client.require_client() + + +def test_later_init_rejects_empty_token() -> None: + client = BoxClient() + with pytest.raises(BoxException): + client.later_init("") + + +def test_default_registry_contains_box() -> None: + registry = build_default_registry() + for name in ( + "FA_box_upload_file", + "FA_box_list_folder", + "FA_box_delete_file", + "FA_box_delete_folder", + ): + assert name in registry + + +def test_register_box_ops_adds_entries() -> None: + registry = ActionRegistry() + register_box_ops(registry) + assert "FA_box_upload_file" in registry + + +def test_upload_file_rejects_missing_source(tmp_path: Path, fake_box: _FakeBoxClient) -> None: + del fake_box + with pytest.raises(FileNotExistsException): + upload_ops.box_upload_file(str(tmp_path / "gone.txt")) + + +def test_upload_file_returns_id(tmp_path: Path, fake_box: _FakeBoxClient) -> None: + del fake_box + src = tmp_path / "report.txt" + src.write_text("ok", encoding="utf-8") + file_id = upload_ops.box_upload_file(str(src)) + assert file_id == "new-id" + + +def test_upload_dir_uploads_each_file(tmp_path: Path, fake_box: _FakeBoxClient) -> None: + (tmp_path / "a.txt").write_text("a", encoding="utf-8") + (tmp_path / "sub").mkdir() + (tmp_path / "sub" / "b.txt").write_text("b", encoding="utf-8") + uploaded_keys = upload_ops.box_upload_dir(str(tmp_path)) + assert sorted(uploaded_keys) == ["a.txt", "sub/b.txt"] + folder = fake_box.folder("0") + flat_names = sorted(name for _, name in folder._uploads) # pylint: disable=protected-access + assert flat_names == ["a.txt", "sub/b.txt"] + + +def test_download_writes_target(tmp_path: Path, fake_box: _FakeBoxClient) -> None: + del fake_box + target = tmp_path / "out" / "f.txt" + assert download_ops.box_download_file("42", str(target)) is True + assert target.read_bytes() == b"contents" + + +def test_list_folder_returns_entries(fake_box: 
_FakeBoxClient) -> None: + del fake_box + entries = list_ops.box_list_folder() + assert entries == [ + {"id": "1", "name": "a.txt", "type": "file"}, + {"id": "2", "name": "subdir", "type": "folder"}, + ] + + +def test_delete_file_uses_client(fake_box: _FakeBoxClient) -> None: + del fake_box + assert delete_ops.box_delete_file("7") is True + + +def test_delete_folder_uses_client(fake_box: _FakeBoxClient) -> None: + del fake_box + assert delete_ops.box_delete_folder("7", recursive=True) is True + + +def test_errors_in_sdk_surface_as_box_exception( + monkeypatch: pytest.MonkeyPatch, fake_box: _FakeBoxClient +) -> None: + def blow(*_a: Any, **_k: Any) -> None: + raise RuntimeError("simulated SDK error") + + monkeypatch.setattr(fake_box, "folder", blow) + with pytest.raises(BoxException): + list_ops.box_list_folder() diff --git a/tests/test_data_ops.py b/tests/test_data_ops.py new file mode 100644 index 0000000..12f1331 --- /dev/null +++ b/tests/test_data_ops.py @@ -0,0 +1,126 @@ +"""Tests for automation_file.local.data_ops (CSV + JSONL).""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from automation_file import ( + DataOpsException, + build_default_registry, + csv_filter, + csv_to_jsonl, + jsonl_append, + jsonl_iter, +) +from automation_file.exceptions import FileNotExistsException + + +def _write_csv(path: Path, rows: list[list[str]]) -> None: + path.write_text("\n".join(",".join(r) for r in rows) + "\n", encoding="utf-8") + + +def test_csv_filter_passthrough(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id", "name"], ["1", "alice"], ["2", "bob"]]) + dest = tmp_path / "b.csv" + assert csv_filter(str(src), str(dest)) == 2 + assert dest.read_text(encoding="utf-8").splitlines() == ["id,name", "1,alice", "2,bob"] + + +def test_csv_filter_projects_columns(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id", "name", "team"], ["1", "alice", "red"], ["2", "bob", 
"blue"]]) + dest = tmp_path / "b.csv" + csv_filter(str(src), str(dest), columns=["name", "id"]) + assert dest.read_text(encoding="utf-8").splitlines() == ["name,id", "alice,1", "bob,2"] + + +def test_csv_filter_where_clause(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id", "team"], ["1", "red"], ["2", "blue"], ["3", "red"]]) + dest = tmp_path / "b.csv" + written = csv_filter(str(src), str(dest), where_column="team", where_equals="red") + assert written == 2 + assert dest.read_text(encoding="utf-8").splitlines() == ["id,team", "1,red", "3,red"] + + +def test_csv_filter_rejects_unknown_column(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id"], ["1"]]) + with pytest.raises(DataOpsException): + csv_filter(str(src), str(tmp_path / "b.csv"), columns=["missing"]) + + +def test_csv_filter_requires_paired_where(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id"], ["1"]]) + with pytest.raises(DataOpsException): + csv_filter(str(src), str(tmp_path / "b.csv"), where_column="id") + + +def test_csv_filter_rejects_missing_source(tmp_path: Path) -> None: + with pytest.raises(FileNotExistsException): + csv_filter(str(tmp_path / "gone.csv"), str(tmp_path / "out.csv")) + + +def test_csv_to_jsonl_basic(tmp_path: Path) -> None: + src = tmp_path / "a.csv" + _write_csv(src, [["id", "name"], ["1", "alice"], ["2", "bob"]]) + dest = tmp_path / "a.jsonl" + assert csv_to_jsonl(str(src), str(dest)) == 2 + lines = dest.read_text(encoding="utf-8").splitlines() + assert json.loads(lines[0]) == {"id": "1", "name": "alice"} + assert json.loads(lines[1]) == {"id": "2", "name": "bob"} + + +def test_jsonl_iter_parses_records(tmp_path: Path) -> None: + path = tmp_path / "x.jsonl" + path.write_text( + '{"a": 1}\n' + "\n" # blank line should be skipped + '{"a": 2}\n', + encoding="utf-8", + ) + records = jsonl_iter(str(path)) + assert records == [{"a": 1}, {"a": 2}] + + +def test_jsonl_iter_respects_limit(tmp_path: Path) -> 
None: + path = tmp_path / "x.jsonl" + path.write_text('{"a":1}\n{"a":2}\n{"a":3}\n', encoding="utf-8") + assert jsonl_iter(str(path), limit=2) == [{"a": 1}, {"a": 2}] + + +def test_jsonl_iter_rejects_non_object(tmp_path: Path) -> None: + path = tmp_path / "x.jsonl" + path.write_text('{"a":1}\n["not","object"]\n', encoding="utf-8") + with pytest.raises(DataOpsException): + jsonl_iter(str(path)) + + +def test_jsonl_iter_rejects_bad_json(tmp_path: Path) -> None: + path = tmp_path / "x.jsonl" + path.write_text("{bad\n", encoding="utf-8") + with pytest.raises(DataOpsException): + jsonl_iter(str(path)) + + +def test_jsonl_append_appends(tmp_path: Path) -> None: + path = tmp_path / "x.jsonl" + assert jsonl_append(str(path), {"a": 1}) is True + assert jsonl_append(str(path), {"a": 2}) is True + assert jsonl_iter(str(path)) == [{"a": 1}, {"a": 2}] + + +def test_jsonl_append_rejects_non_dict(tmp_path: Path) -> None: + with pytest.raises(DataOpsException): + jsonl_append(str(tmp_path / "x.jsonl"), ["not", "a", "dict"]) # type: ignore[arg-type] + + +def test_data_ops_registered() -> None: + registry = build_default_registry() + for name in ("FA_csv_filter", "FA_csv_to_jsonl", "FA_jsonl_iter", "FA_jsonl_append"): + assert name in registry diff --git a/tests/test_data_ops_yaml_parquet.py b/tests/test_data_ops_yaml_parquet.py new file mode 100644 index 0000000..a8afb19 --- /dev/null +++ b/tests/test_data_ops_yaml_parquet.py @@ -0,0 +1,141 @@ +"""Tests for automation_file.local.data_ops YAML + Parquet helpers.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from automation_file import ( + DataOpsException, + build_default_registry, + csv_to_parquet, + parquet_read, + parquet_write, + yaml_delete, + yaml_get, + yaml_set, +) +from automation_file.exceptions import FileNotExistsException + +# --- YAML ----------------------------------------------------------------- + + +def _write_yaml(path: Path, body: str) -> None: + path.write_text(body, 
encoding="utf-8") + + +def test_yaml_get_nested(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "a:\n b:\n c: 42\n") + assert yaml_get(str(path), "a.b.c") == 42 + + +def test_yaml_get_missing_returns_default(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "root: 1\n") + assert yaml_get(str(path), "missing.key", default="fallback") == "fallback" + + +def test_yaml_get_list_index(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "items:\n - a\n - b\n - c\n") + assert yaml_get(str(path), "items.1") == "b" + + +def test_yaml_set_creates_intermediate_dicts(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "existing: kept\n") + yaml_set(str(path), "new.nested.key", "value") + assert yaml_get(str(path), "new.nested.key") == "value" + assert yaml_get(str(path), "existing") == "kept" + + +def test_yaml_set_rejects_empty_key_path(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "a: 1\n") + with pytest.raises(DataOpsException): + yaml_set(str(path), "", "x") + + +def test_yaml_delete_returns_true_when_removed(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "a:\n b: 1\n c: 2\n") + assert yaml_delete(str(path), "a.b") is True + assert yaml_get(str(path), "a.b") is None + assert yaml_get(str(path), "a.c") == 2 + + +def test_yaml_delete_returns_false_when_missing(tmp_path: Path) -> None: + path = tmp_path / "c.yaml" + _write_yaml(path, "a: 1\n") + assert yaml_delete(str(path), "nope") is False + + +def test_yaml_load_rejects_malformed(tmp_path: Path) -> None: + path = tmp_path / "bad.yaml" + _write_yaml(path, "a: [unterminated\n") + with pytest.raises(DataOpsException): + yaml_get(str(path), "a") + + +def test_yaml_handles_missing_file(tmp_path: Path) -> None: + with pytest.raises(FileNotExistsException): + yaml_get(str(tmp_path / "gone.yaml"), "a") + + +# --- Parquet -------------------------------------------------------------- + 
+ +def test_parquet_write_and_read_roundtrip(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + records = [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}] + assert parquet_write(str(path), records) == 2 + assert parquet_read(str(path)) == records + + +def test_parquet_read_respects_limit(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + parquet_write(str(path), [{"i": n} for n in range(5)]) + assert parquet_read(str(path), limit=2) == [{"i": 0}, {"i": 1}] + + +def test_parquet_read_projects_columns(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + parquet_write(str(path), [{"id": 1, "name": "alice", "team": "red"}]) + assert parquet_read(str(path), columns=["id", "team"]) == [{"id": 1, "team": "red"}] + + +def test_parquet_write_rejects_non_list(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + with pytest.raises(DataOpsException): + parquet_write(str(path), {"not": "a list"}) # type: ignore[arg-type] + + +def test_parquet_read_rejects_missing_file(tmp_path: Path) -> None: + with pytest.raises(FileNotExistsException): + parquet_read(str(tmp_path / "gone.parquet")) + + +def test_csv_to_parquet_roundtrip(tmp_path: Path) -> None: + csv_path = tmp_path / "a.csv" + csv_path.write_text("id,name\n1,alice\n2,bob\n", encoding="utf-8") + parquet_path = tmp_path / "a.parquet" + assert csv_to_parquet(str(csv_path), str(parquet_path)) == 2 + assert parquet_read(str(parquet_path)) == [ + {"id": "1", "name": "alice"}, + {"id": "2", "name": "bob"}, + ] + + +def test_yaml_parquet_actions_registered() -> None: + registry = build_default_registry() + for name in ( + "FA_yaml_get", + "FA_yaml_set", + "FA_yaml_delete", + "FA_parquet_read", + "FA_parquet_write", + "FA_csv_to_parquet", + ): + assert name in registry diff --git a/tests/test_diff_ops.py b/tests/test_diff_ops.py index e694e6a..3408b46 100644 --- a/tests/test_diff_ops.py +++ b/tests/test_diff_ops.py @@ -6,11 +6,14 @@ import pytest +from automation_file import 
build_default_registry from automation_file.exceptions import DiffException, PathTraversalException from automation_file.local.diff_ops import ( DirDiff, apply_dir_diff, + apply_text_patch, diff_dirs, + diff_dirs_summary, diff_text_files, iter_dir_diff, ) @@ -87,3 +90,42 @@ def test_iter_dir_diff_labels_entries() -> None: assert ("added", "a") in entries assert ("removed", "b") in entries assert ("changed", "c") in entries + + +def test_diff_dirs_summary_returns_plain_dict(tmp_path: Path) -> None: + left = tmp_path / "a" + right = tmp_path / "b" + _populate(left, {"keep.txt": "same", "remove.txt": "bye"}) + _populate(right, {"keep.txt": "same", "add.txt": "hi"}) + summary = diff_dirs_summary(str(left), str(right)) + assert summary == {"added": ["add.txt"], "removed": ["remove.txt"], "changed": []} + + +def test_apply_text_patch_roundtrip(tmp_path: Path) -> None: + before = tmp_path / "before.txt" + after = tmp_path / "after.txt" + target = tmp_path / "live.txt" + before.write_text("one\ntwo\nthree\n", encoding="utf-8") + after.write_text("one\nTWO\nthree\nfour\n", encoding="utf-8") + target.write_text(before.read_text(encoding="utf-8"), encoding="utf-8") + patch = diff_text_files(before, after) + apply_text_patch(str(target), patch) + assert target.read_text(encoding="utf-8") == after.read_text(encoding="utf-8") + + +def test_apply_text_patch_detects_mismatch(tmp_path: Path) -> None: + before = tmp_path / "before.txt" + after = tmp_path / "after.txt" + target = tmp_path / "live.txt" + before.write_text("one\ntwo\n", encoding="utf-8") + after.write_text("one\ntwo\nthree\n", encoding="utf-8") + target.write_text("one\nDIVERGED\n", encoding="utf-8") + patch = diff_text_files(before, after) + with pytest.raises(DiffException): + apply_text_patch(str(target), patch) + + +def test_diff_ops_registered() -> None: + registry = build_default_registry() + for name in ("FA_diff_files", "FA_diff_dirs", "FA_apply_patch"): + assert name in registry diff --git 
a/tests/test_onedrive_ops.py b/tests/test_onedrive_ops.py new file mode 100644 index 0000000..09c1101 --- /dev/null +++ b/tests/test_onedrive_ops.py @@ -0,0 +1,192 @@ +"""OneDrive backend tests. + +Live Microsoft Graph endpoints are outside CI; these tests verify registry +wiring, guard clauses, and that ``graph_request`` surfaces non-2xx +responses as :class:`OneDriveException`. +""" + +# pylint: disable=protected-access # fakes inject _session / _access_token directly + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest +import requests + +from automation_file import ( + OneDriveClient, + OneDriveException, + build_default_registry, + onedrive_instance, + register_onedrive_ops, +) +from automation_file.core.action_registry import ActionRegistry +from automation_file.exceptions import FileNotExistsException +from automation_file.remote.onedrive import delete_ops, download_ops, list_ops, upload_ops + + +class _FakeResponse: + def __init__(self, status: int = 200, payload: Any = None, content: bytes = b"") -> None: + self.status_code = status + self.text = "" if status < 400 else "fake-error" + self.content = content + self._payload = payload or {} + self.ok = status < 400 + + def json(self) -> Any: + return self._payload + + +class _FakeSession: + def __init__(self, responder: Any) -> None: + self.headers: dict[str, str] = {} + self._responder = responder + self.calls: list[dict[str, Any]] = [] + + def request(self, method: str, url: str, **kwargs: Any) -> _FakeResponse: + self.calls.append({"method": method, "url": url, **kwargs}) + return self._responder(method, url, **kwargs) + + def close(self) -> None: + return None + + +@pytest.fixture(name="fake_client") +def _fake_client(monkeypatch: pytest.MonkeyPatch) -> OneDriveClient: + client = OneDriveClient() + + def responder(method: str, url: str, **_kwargs: Any) -> _FakeResponse: + if method == "GET" and "/children" in url: + return _FakeResponse( + 200, + 
payload={ + "value": [ + {"name": "a.txt", "size": 4}, + {"name": "dir", "folder": {}}, + ] + }, + ) + if method == "GET" and "/content" in url: + return _FakeResponse(200, content=b"hello") + if method == "PUT": + return _FakeResponse(201, payload={"id": "abc"}) + if method == "DELETE": + return _FakeResponse(204) + return _FakeResponse(404) + + session = _FakeSession(responder) + # Not a credential — placeholder marker used only by the fake session. + fake_token = "fake-token" # nosec B105 + client._session = session # test injection + client._access_token = fake_token + monkeypatch.setattr(onedrive_instance, "_session", session, raising=False) + monkeypatch.setattr(onedrive_instance, "_access_token", fake_token, raising=False) + return client + + +def test_require_session_raises_when_not_initialised() -> None: + client = OneDriveClient() + with pytest.raises(OneDriveException): + client.require_session() + + +def test_later_init_rejects_empty_token() -> None: + client = OneDriveClient() + with pytest.raises(OneDriveException): + client.later_init("") + + +def test_default_registry_contains_onedrive() -> None: + registry = build_default_registry() + assert "FA_onedrive_upload_file" in registry + assert "FA_onedrive_device_code_login" in registry + + +def test_register_onedrive_ops_adds_entries() -> None: + registry = ActionRegistry() + register_onedrive_ops(registry) + for name in ( + "FA_onedrive_later_init", + "FA_onedrive_upload_file", + "FA_onedrive_upload_dir", + "FA_onedrive_download_file", + "FA_onedrive_delete_item", + "FA_onedrive_list_folder", + "FA_onedrive_close", + ): + assert name in registry + + +def test_upload_rejects_missing_source(tmp_path: Path, fake_client: OneDriveClient) -> None: + del fake_client + with pytest.raises(FileNotExistsException): + upload_ops.onedrive_upload_file(str(tmp_path / "gone.txt"), "x.txt") + + +def test_upload_rejects_oversize(tmp_path: Path, fake_client: OneDriveClient) -> None: + del fake_client + big = tmp_path / 
"big.bin" + big.write_bytes(b"\0" * (4 * 1024 * 1024 + 1)) + with pytest.raises(OneDriveException): + upload_ops.onedrive_upload_file(str(big), "x.bin") + + +def test_upload_roundtrip(tmp_path: Path, fake_client: OneDriveClient) -> None: + src = tmp_path / "hello.txt" + src.write_text("hi", encoding="utf-8") + assert upload_ops.onedrive_upload_file(str(src), "dest/hello.txt") is True + last = fake_client._session.calls[-1] + assert last["method"] == "PUT" + assert last["data"] == b"hi" + + +def test_download_writes_target(tmp_path: Path, fake_client: OneDriveClient) -> None: + del fake_client + target = tmp_path / "out" / "hi.bin" + assert download_ops.onedrive_download_file("hi.bin", str(target)) is True + assert target.read_bytes() == b"hello" + + +def test_list_folder_returns_entries(fake_client: OneDriveClient) -> None: + del fake_client + entries = list_ops.onedrive_list_folder() + assert {entry["type"] for entry in entries} == {"file", "folder"} + + +def test_delete_hits_graph(fake_client: OneDriveClient) -> None: + assert delete_ops.onedrive_delete_item("dir/file.txt") is True + last = fake_client._session.calls[-1] + assert last["method"] == "DELETE" + + +def test_graph_request_raises_on_http_error(fake_client: OneDriveClient) -> None: + del fake_client + # Replace session with one that always returns 500. 
+ err_session = _FakeSession(lambda *_a, **_k: _FakeResponse(status=500)) + onedrive_instance._session = err_session + with pytest.raises(OneDriveException): + onedrive_instance.graph_request("GET", "/me/drive/root") + + +def test_graph_request_wraps_requests_exception( + monkeypatch: pytest.MonkeyPatch, fake_client: OneDriveClient +) -> None: + del fake_client + + def blow_up(*_a: Any, **_k: Any) -> None: + raise requests.ConnectionError("cannot reach host") + + err_session = _FakeSession(blow_up) + monkeypatch.setattr(onedrive_instance, "_session", err_session) + with pytest.raises(OneDriveException): + onedrive_instance.graph_request("GET", "/me/drive/root") + + +def test_close_tears_down() -> None: + client = OneDriveClient() + client.later_init("abc") + assert client.close() is True + with pytest.raises(OneDriveException): + client.require_session() diff --git a/tests/test_text_ops.py b/tests/test_text_ops.py new file mode 100644 index 0000000..f383508 --- /dev/null +++ b/tests/test_text_ops.py @@ -0,0 +1,146 @@ +"""Tests for automation_file.local.text_ops.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from automation_file import ( + TextOpsException, + build_default_registry, + encoding_convert, + file_merge, + file_split, + line_count, + sed_replace, +) +from automation_file.exceptions import FileNotExistsException + + +def test_file_split_produces_ordered_parts(tmp_path: Path) -> None: + source = tmp_path / "payload.bin" + source.write_bytes(b"0123456789abcdef") + parts = file_split(str(source), chunk_size=5) + assert [Path(p).name for p in parts] == [ + "payload.bin.part000", + "payload.bin.part001", + "payload.bin.part002", + "payload.bin.part003", + ] + assert Path(parts[0]).read_bytes() == b"01234" + assert Path(parts[-1]).read_bytes() == b"f" + + +def test_file_split_respects_output_dir(tmp_path: Path) -> None: + source = tmp_path / "a.txt" + source.write_bytes(b"hello world") + dest = tmp_path / "parts" + 
parts = file_split(str(source), chunk_size=4, output_dir=str(dest)) + assert all(Path(p).parent == dest for p in parts) + + +def test_file_split_rejects_non_positive_chunk(tmp_path: Path) -> None: + source = tmp_path / "empty.bin" + source.write_bytes(b"x") + with pytest.raises(TextOpsException): + file_split(str(source), chunk_size=0) + + +def test_file_split_rejects_missing_source(tmp_path: Path) -> None: + with pytest.raises(FileNotExistsException): + file_split(str(tmp_path / "missing"), chunk_size=10) + + +def test_file_merge_roundtrip(tmp_path: Path) -> None: + source = tmp_path / "big.bin" + source.write_bytes(b"abcdefghijklmno") + parts = file_split(str(source), chunk_size=3) + merged = tmp_path / "rebuilt.bin" + assert file_merge(parts, str(merged)) is True + assert merged.read_bytes() == source.read_bytes() + + +def test_file_merge_rejects_missing_part(tmp_path: Path) -> None: + with pytest.raises(FileNotExistsException): + file_merge([str(tmp_path / "gone.part000")], str(tmp_path / "out")) + + +def test_file_merge_rejects_empty_parts(tmp_path: Path) -> None: + with pytest.raises(TextOpsException): + file_merge([], str(tmp_path / "out")) + + +def test_encoding_convert_utf8_to_latin1(tmp_path: Path) -> None: + source = tmp_path / "a.txt" + source.write_text("café", encoding="utf-8") + target = tmp_path / "b.txt" + encoding_convert(str(source), str(target), "utf-8", "latin-1") + assert target.read_text(encoding="latin-1") == "café" + + +def test_encoding_convert_reports_bad_mapping(tmp_path: Path) -> None: + source = tmp_path / "a.txt" + source.write_text("hello", encoding="utf-8") + target = tmp_path / "b.txt" + with pytest.raises(TextOpsException): + encoding_convert(str(source), str(target), "not-a-real-codec", "utf-8") + + +def test_line_count_counts_newlines(tmp_path: Path) -> None: + source = tmp_path / "lines.txt" + source.write_text("one\ntwo\nthree\n", encoding="utf-8") + assert line_count(str(source)) == 3 + + +def 
test_line_count_handles_no_trailing_newline(tmp_path: Path) -> None: + source = tmp_path / "lines.txt" + source.write_text("one\ntwo", encoding="utf-8") + assert line_count(str(source)) == 2 + + +def test_sed_replace_literal(tmp_path: Path) -> None: + source = tmp_path / "t.txt" + source.write_text("aaa bbb aaa", encoding="utf-8") + assert sed_replace(str(source), "aaa", "XXX") == 2 + assert source.read_text(encoding="utf-8") == "XXX bbb XXX" + + +def test_sed_replace_regex_with_backref(tmp_path: Path) -> None: + source = tmp_path / "t.txt" + source.write_text("hello world", encoding="utf-8") + assert sed_replace(str(source), r"(\w+) (\w+)", r"\2 \1", regex=True) == 1 + assert source.read_text(encoding="utf-8") == "world hello" + + +def test_sed_replace_respects_count(tmp_path: Path) -> None: + source = tmp_path / "t.txt" + source.write_text("a a a a", encoding="utf-8") + assert sed_replace(str(source), "a", "b", count=2) == 2 + assert source.read_text(encoding="utf-8") == "b b a a" + + +def test_sed_replace_rejects_empty_literal_pattern(tmp_path: Path) -> None: + source = tmp_path / "t.txt" + source.write_text("abc", encoding="utf-8") + with pytest.raises(TextOpsException): + sed_replace(str(source), "", "x") + + +def test_sed_replace_rejects_bad_regex(tmp_path: Path) -> None: + source = tmp_path / "t.txt" + source.write_text("abc", encoding="utf-8") + with pytest.raises(TextOpsException): + sed_replace(str(source), "[unclosed", "x", regex=True) + + +def test_text_ops_registered() -> None: + registry = build_default_registry() + for name in ( + "FA_file_split", + "FA_file_merge", + "FA_encoding_convert", + "FA_line_count", + "FA_sed_replace", + ): + assert name in registry diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 0000000..674ebef --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,120 @@ +"""Tests for automation_file.core.tracing.""" + +# pylint: disable=protected-access # tests probe tracing._state / 
_shutdown_for_tests directly + +from __future__ import annotations + +from typing import Any + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult + +from automation_file import ( + action_span, + execute_action, + executor, + init_tracing, +) +from automation_file.core import tracing +from automation_file.exceptions import TracingException + + +class _CapturingExporter(SpanExporter): + """Collect spans in memory so tests can assert on them.""" + + def __init__(self) -> None: + self.spans: list[ReadableSpan] = [] + self._shutdown = False + + def export(self, spans: Any) -> Any: + if not self._shutdown: + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + self._shutdown = True + + +@pytest.fixture(name="exporter", scope="module") +def _exporter() -> Any: + """Initialise tracing once per module. + + OpenTelemetry's ``trace.set_tracer_provider`` is one-shot per process, + so we can't re-initialise between tests. A module-scoped fixture keeps + all tracing tests sharing the same exporter. + """ + # Force a clean state on entry in case another test suite touched tracing. 
+ tracing._shutdown_for_tests() + exporter = _CapturingExporter() + init_tracing("test-service", exporter=exporter) + yield exporter + provider = trace.get_tracer_provider() + shutdown = getattr(provider, "shutdown", None) + if callable(shutdown): + shutdown() # pylint: disable=not-callable # narrowed by callable() above + + +def _flush() -> None: + """Force any pending batch-exported spans out to the exporter.""" + provider = trace.get_tracer_provider() + force_flush = getattr(provider, "force_flush", None) + if callable(force_flush): + force_flush() # pylint: disable=not-callable # narrowed by callable() above + + +def test_is_initialised_true_after_fixture(exporter: _CapturingExporter) -> None: + del exporter # fixture side effect only + assert tracing.is_initialised() is True + + +def test_action_span_records_attributes(exporter: _CapturingExporter) -> None: + before = len(exporter.spans) + with action_span("probe", {"answer": 42}): + # Body is intentionally a trivial op — the test asserts on the span + # the context manager emits, not on any computation inside. 
+ _ = trace.get_current_span() + _flush() + new_spans = exporter.spans[before:] + assert any(s.name == "automation_file.action" for s in new_spans) + probe = next(s for s in new_spans if s.name == "automation_file.action") + assert probe.attributes is not None + assert probe.attributes["fa.action"] == "probe" + assert probe.attributes["answer"] == 42 + + +def test_init_tracing_returns_false_on_second_call(exporter: _CapturingExporter) -> None: + del exporter + assert init_tracing("svc") is False + + +def test_executor_wraps_actions_in_span(exporter: _CapturingExporter) -> None: + before = len(exporter.spans) + executor.registry.register("test_traced_echo", lambda value: value) + execute_action([["test_traced_echo", {"value": "hi"}]]) + _flush() + action_names = [ + span.attributes.get("fa.action") + for span in exporter.spans[before:] + if span.attributes is not None + ] + assert "test_traced_echo" in action_names + + +def test_action_span_noop_when_uninitialised() -> None: + # Capture current state, drop to uninitialised, verify no-op, restore. + previous = tracing._state["initialised"] + tracing._state["initialised"] = False + try: + with action_span("probe"): + # Must not raise — the context manager has to stay a cheap no-op + # when tracing is switched off. + assert tracing.is_initialised() is False + assert tracing.is_initialised() is False + finally: + tracing._state["initialised"] = previous + + +def test_tracing_exception_is_exported() -> None: + assert issubclass(TracingException, Exception) diff --git a/tests/test_ui_smoke.py b/tests/test_ui_smoke.py index c890152..bf5164e 100644 --- a/tests/test_ui_smoke.py +++ b/tests/test_ui_smoke.py @@ -52,6 +52,8 @@ def test_main_window_constructs(qt_app) -> None: "AzureBlobTab", "DropboxTab", "SFTPTab", + "OneDriveTab", + "BoxTab", "JSONEditorTab", "ServerTab", "TransferTab",