From c2ec0d638aa0f51a664365830783105c67a27492 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Wed, 21 Jan 2026 08:59:10 -0800 Subject: [PATCH 1/5] add audit and validator components --- src/cdm_data_loader_utils/audit/checkpoint.py | 114 ++++++++++++ src/cdm_data_loader_utils/audit/metrics.py | 97 ++++++++++ src/cdm_data_loader_utils/audit/rejects.py | 80 +++++++++ src/cdm_data_loader_utils/audit/run.py | 117 ++++++++++++ .../validation/__init__.py | 0 .../validation/dataframe_validator.py | 88 ++++++++++ .../validation/df_nullable_fields.py | 61 +++++++ .../validation/validation_result.py | 16 ++ tests/audit/__init__.py | 0 tests/audit/conftest.py | 118 +++++++++++++ tests/audit/test_checkpoint.py | 166 ++++++++++++++++++ tests/audit/test_metrics.py | 95 ++++++++++ tests/audit/test_rejects.py | 156 ++++++++++++++++ tests/audit/test_run.py | 161 +++++++++++++++++ tests/audit/test_schema.py | 32 ++++ .../readers/test_dsv_read_with_validation.py | 144 +++++++++++++++ tests/validation/__init__.py | 0 tests/validation/test_dataframe_validator.py | 82 +++++++++ tests/validation/test_df_nullable_fields.py | 71 ++++++++ 19 files changed, 1598 insertions(+) create mode 100644 src/cdm_data_loader_utils/audit/checkpoint.py create mode 100644 src/cdm_data_loader_utils/audit/metrics.py create mode 100644 src/cdm_data_loader_utils/audit/rejects.py create mode 100644 src/cdm_data_loader_utils/audit/run.py create mode 100644 src/cdm_data_loader_utils/validation/__init__.py create mode 100644 src/cdm_data_loader_utils/validation/dataframe_validator.py create mode 100644 src/cdm_data_loader_utils/validation/df_nullable_fields.py create mode 100644 src/cdm_data_loader_utils/validation/validation_result.py create mode 100644 tests/audit/__init__.py create mode 100644 tests/audit/conftest.py create mode 100644 tests/audit/test_checkpoint.py create mode 100644 tests/audit/test_metrics.py create mode 100644 tests/audit/test_rejects.py create mode 100644 tests/audit/test_run.py create mode 100644 tests/audit/test_schema.py create mode 100644 tests/readers/test_dsv_read_with_validation.py create mode 100644 tests/validation/__init__.py create mode 100644 tests/validation/test_dataframe_validator.py create mode 100644 tests/validation/test_df_nullable_fields.py diff --git a/src/cdm_data_loader_utils/audit/checkpoint.py b/src/cdm_data_loader_utils/audit/checkpoint.py new file mode 100644 index 0000000..ddc00f3 --- /dev/null +++ b/src/cdm_data_loader_utils/audit/checkpoint.py @@ -0,0 +1,114 @@ +"""Checkpoint audit table functions: adding and updating information on data import pipeline execution.""" + +from delta.tables import DeltaTable +from pyspark.sql import SparkSession +from pyspark.sql import functions as sf + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + CHECKPOINT, + LAST_ENTRY_ID, + PIPELINE, + RECORDS_PROCESSED, + RUN_ID, + SOURCE, + STATUS, + STATUS_RUNNING, + UPDATED, + current_run_expr, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger + + +# Checkpoint table-related functions +def upsert_checkpoint( + spark: SparkSession, + run: PipelineRun, + last_entry_id: str, + records_processed: int, +) -> None: + """Add or update checkpoint records for the current pipeline ingest. 
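+
+    A minimal usage sketch (argument values are illustrative only)::
+
+        upsert_checkpoint(spark, run, last_entry_id="entry-42", records_processed=100)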
+ + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + :param last_entry_id: ID of the last entry processed + :type last_entry_id: str + :param records_processed: number of entries processed + :type records_processed: int + """ + delta = DeltaTable.forName(spark, f"{run.namespace}.{CHECKPOINT}") + + df = spark.range(1).select( + sf.lit(run.run_id).alias(RUN_ID), + sf.lit(run.pipeline).alias(PIPELINE), + sf.lit(run.source_path).alias(SOURCE), + sf.lit(STATUS_RUNNING).alias(STATUS), + sf.lit(records_processed).alias(RECORDS_PROCESSED), + sf.lit(last_entry_id).alias(LAST_ENTRY_ID), + sf.current_timestamp().alias(UPDATED), + ) + updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[CHECKPOINT]) + + ( + delta.alias("t") + .merge(updates.alias("s"), current_run_expr()) + .whenMatchedUpdate(set={val: f"s.{val}" for val in [STATUS, RECORDS_PROCESSED, LAST_ENTRY_ID, UPDATED]}) + .whenNotMatchedInsertAll() + .execute() + ) + get_cdm_logger().info("%s %s: checkpoint created/updated", run.pipeline, run.run_id) + + +def update_checkpoint_status(spark: SparkSession, run: PipelineRun, status: str) -> None: + """Update checkpoint status and timestamp. + + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + :param status: pipeline status + :type status: str + """ + delta = DeltaTable.forName(spark, f"{run.namespace}.{CHECKPOINT}") + delta.update( + " AND ".join([f"{p} = '{getattr(run, p)}'" for p in [RUN_ID, PIPELINE, SOURCE]]), + {STATUS: sf.lit(status), UPDATED: sf.current_timestamp()}, + ) + # check whether rows were updated by looking in the delta log + # N.b. this may not work correctly if another process updates the table in the interim + metrics = delta.history(1).select("operationMetrics").collect()[0][0] + if int(metrics.get("numUpdatedRows", 0)) == 0: + get_cdm_logger().warning( + "%s %s: cannot update '%s' to status %s because no record exists.", + run.pipeline, + run.run_id, + CHECKPOINT, + status, + ) + else: + get_cdm_logger().info("%s %s: checkpoint successfully updated to status %s", run.pipeline, run.run_id, status) + + +def load_checkpoint(spark: SparkSession, run: PipelineRun) -> str | None: + """Load any existing checkpoint data, filtered by current pipeline and data source path. 
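+
+    A minimal usage sketch (illustrative only)::
+
+        last_id = load_checkpoint(spark, run)
+        if last_id is not None:
+            ...  # resume processing from the entry after last_id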
+ + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + :return: either the last entry successfully saved and dumped to disk, or None + :rtype: str | None + """ + rows = ( + spark.table(f"{run.namespace}.{CHECKPOINT}") + .filter(sf.col(RUN_ID) == run.run_id) + .filter(sf.col(PIPELINE) == run.pipeline) + .filter(sf.col(SOURCE) == run.source_path) + .select(LAST_ENTRY_ID) + .limit(1) + .collect() + ) + return rows[0][LAST_ENTRY_ID] if rows else None diff --git a/src/cdm_data_loader_utils/audit/metrics.py b/src/cdm_data_loader_utils/audit/metrics.py new file mode 100644 index 0000000..ef6b35e --- /dev/null +++ b/src/cdm_data_loader_utils/audit/metrics.py @@ -0,0 +1,97 @@ +"""Audit table for recording metrics.""" + +from delta.tables import DeltaTable +from pyspark.sql import DataFrame, Row, SparkSession +from pyspark.sql import functions as sf + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + METRICS, + N_INVALID, + N_READ, + N_VALID, + PIPELINE, + ROW_ERRORS, + RUN_ID, + SOURCE, + UPDATED, + VALIDATION_ERRORS, + current_run_expr, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger + +logger = get_cdm_logger() + + +def write_metrics( + spark: SparkSession, + annotated_df: DataFrame, + run: PipelineRun, +) -> Row: + """Write metrics for the current batch of imports to disk. + + :param spark: spark sesh + :type spark: SparkSession + :param run: current pipeline run + :type run: PipelineRun + :param records_read: total number of records read + :type records_read: int + :param records_valid: number of valid records + :type records_valid: int + :param records_invalid: number of invalid records + :type records_invalid: int + :param validation_errors: list of validation errors encountered in the batch + :type validation_errors: list[str] + :return: row of a dataframe with validation metrics + :rtype: Row + """ + if annotated_df.count() == 0: + # nothing to do here + logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, METRICS) + return Row(records_read=0, records_valid=0, records_invalid=0, validation_errors=[]) + + invalid_df = annotated_df.filter(sf.size(ROW_ERRORS) > 0) + + validation_errors = sorted( + [r.reason for r in invalid_df.select(sf.explode(ROW_ERRORS).alias("reason")).distinct().collect()] + ) + + metrics_df = annotated_df.agg( + sf.count("*").alias(N_READ), + sf.sum(sf.when(sf.size(ROW_ERRORS) == 0, 1).otherwise(0)).alias(N_VALID), + sf.sum(sf.when(sf.size(ROW_ERRORS) > 0, 1).otherwise(0)).alias(N_INVALID), + sf.lit(validation_errors).alias(VALIDATION_ERRORS), + ) + metrics = metrics_df.collect()[0] + + df = spark.range(1).select( + sf.lit(run.run_id).alias(RUN_ID), + sf.lit(run.pipeline).alias(PIPELINE), + sf.lit(run.source_path).alias(SOURCE), + sf.lit(metrics.records_read).alias(N_READ), + sf.lit(metrics.records_valid).alias(N_VALID), + sf.lit(metrics.records_invalid).alias(N_INVALID), + sf.lit(metrics.validation_errors).alias(VALIDATION_ERRORS), + sf.current_timestamp().alias(UPDATED), + ) + updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[METRICS]) + + target = DeltaTable.forName( + spark, + f"{run.namespace}.{METRICS}", + ) + + ( + target.alias("t") + .merge( + updates.alias("s"), + current_run_expr(), + ) + .whenMatchedUpdate(set={k: f"s.{k}" for k in [N_READ, N_VALID, N_INVALID, VALIDATION_ERRORS, UPDATED]}) + .whenNotMatchedInsertAll() + .execute() + ) + 
get_cdm_logger().info("%s %s: ingest metrics written to '%s' table.", run.pipeline, run.run_id, METRICS) + + return metrics diff --git a/src/cdm_data_loader_utils/audit/rejects.py b/src/cdm_data_loader_utils/audit/rejects.py new file mode 100644 index 0000000..3a42100 --- /dev/null +++ b/src/cdm_data_loader_utils/audit/rejects.py @@ -0,0 +1,80 @@ +"""Audit table for recording data rejected as invalid during ingest.""" + +import pyspark.sql.functions as sf +from pyspark.sql import DataFrame +from pyspark.sql.types import StructField + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + PARSED_ROW, + PIPELINE, + RAW_ROW, + REJECTS, + ROW_ERRORS, + RUN_ID, + SOURCE, + TIMESTAMP, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger + +logger = get_cdm_logger() + + +def write_rejects( + run: PipelineRun, + annotated_df: DataFrame, + schema_fields: list[StructField], + invalid_col: str, +) -> None: + """Write rejected data to the rejects audit table. + + This should plug in directly to readers like the spark CSV reader, which put any non-compliant data into + a single column when run in PERMISSIVE mode (default for the cdm_data_loader_utils readers). + + It is expected that the dataframe will contain a column called ROW_ERRORS, which contains a list of strings + describing the errors found in the rows. + + :param run: current pipeline run + :type run: PipelineRun + :param annotated_df: dataframe with errors to be written out + :type annotated_df: DataFrame + :param schema: schema of the dataframe -- s1hould be the unamended version without a column for invalid data + :type schema: StructType + :param invalid_col: name of the column with the invalid data in it + :type invalid_col: str + """ + if ROW_ERRORS not in annotated_df.columns: + err_msg = f"{run.pipeline} {run.run_id}: '{ROW_ERRORS}' column not present in dataframe; cannot record rejects." 
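+        # Without the row-errors column there is no way to tell rejects from valid
+        # rows, so log the problem and raise rather than write a misleading audit record.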
+ logger.error(err_msg) + raise RuntimeError(err_msg) + + if annotated_df.count() == 0: + logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) + return + + invalid_df: DataFrame = annotated_df.filter(sf.size(ROW_ERRORS) > 0) + if invalid_df.count() == 0: + # nothing to do here + logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) + return + + data_fields = [f.name for f in schema_fields] + + rejects_df = invalid_df.select( + sf.lit(run.run_id).alias(RUN_ID), + sf.lit(run.pipeline).alias(PIPELINE), + sf.lit(run.source_path).alias(SOURCE), + sf.col(invalid_col).alias(RAW_ROW), + sf.to_json(sf.struct(*[sf.col(c) for c in data_fields])).alias(PARSED_ROW), + sf.col(ROW_ERRORS), + sf.current_timestamp().alias(TIMESTAMP), + ) + # ensure that it conforms to the schema + rejects_df = rejects_df.select( + *[sf.col(f.name).cast(f.dataType).alias(f.name) for f in AUDIT_SCHEMA[REJECTS].fields] + ) + # write to disk + rejects_df.write.format("delta").mode("append").saveAsTable(f"{run.namespace}.{REJECTS}") + + get_cdm_logger().info("%s %s: invalid rows written to '%s' audit table.", run.pipeline, run.run_id, REJECTS) diff --git a/src/cdm_data_loader_utils/audit/run.py b/src/cdm_data_loader_utils/audit/run.py new file mode 100644 index 0000000..4b7007f --- /dev/null +++ b/src/cdm_data_loader_utils/audit/run.py @@ -0,0 +1,117 @@ +"""Run audit table functions: additions and updates to the run table, which tracks overall run status for a pipeline.""" + +from delta.tables import DeltaTable +from pyspark.sql import SparkSession +from pyspark.sql import functions as sf + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + END_TIME, + ERROR, + PIPELINE, + RECORDS_PROCESSED, + RUN, + RUN_ID, + SOURCE, + START_TIME, + STATUS, + STATUS_ERROR, + STATUS_RUNNING, + STATUS_SUCCESS, + match_run, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger + + +def _table_not_updated(delta: DeltaTable) -> bool: + metrics = delta.history(1).select("operationMetrics").collect()[0][0] + # check the number of updated rows, returning True if it is more than 0 + return int(metrics.get("numUpdatedRows", 0)) == 0 + + +def start_run(spark: SparkSession, run: PipelineRun) -> None: + """Write to the RUN table to indicate that a new ingestion run has started. + + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + """ + df = spark.range(1).select( + sf.lit(run.run_id).alias(RUN_ID), + sf.lit(run.pipeline).alias(PIPELINE), + sf.lit(run.source_path).alias(SOURCE), + sf.lit(STATUS_RUNNING).alias(STATUS), + sf.lit(sf.lit(0).cast("long")).alias(RECORDS_PROCESSED), + sf.lit(sf.current_timestamp()).alias(START_TIME), + sf.lit(sf.lit(None).cast("timestamp")).alias(END_TIME), + sf.lit(None).cast("string").alias(ERROR), + ) + + spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[RUN]).write.format("delta").mode("append").saveAsTable( + f"{run.namespace}.{RUN}" + ) + + +def complete_run(spark: SparkSession, run: PipelineRun, records_processed: int) -> None: + """Write to the RUN table to indicate that the ingestion run has completed. 
+ + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + :param records_processed: number of records parsed and saved to disk + :type records_processed: int + """ + delta: DeltaTable = DeltaTable.forName(spark, f"{run.namespace}.{RUN}") + + delta.update( + match_run(run), + { + STATUS: sf.lit(STATUS_SUCCESS), + END_TIME: sf.current_timestamp(), + RECORDS_PROCESSED: sf.lit(records_processed).cast("long"), + }, + ) + # check whether rows were updated by looking in the delta log + if _table_not_updated(delta): + get_cdm_logger().warning( + "%s %s: cannot update '%s' to status %s because no record exists.", + run.pipeline, + run.run_id, + RUN, + STATUS_SUCCESS, + ) + get_cdm_logger().info("%s %s: run completed", run.pipeline, run.run_id) + + +def fail_run(spark: SparkSession, run: PipelineRun, error: Exception) -> None: + """Write to the RUN table to indicate that an error has occurred. + + :param spark: spark sesh + :type spark: SparkSession + :param run: pipeline run + :type run: PipelineRun + :param error: error object thrown during ingestion. + :type error: Exception + """ + delta = DeltaTable.forName(spark, f"{run.namespace}.{RUN}") + delta.update( + match_run(run), + set={ + END_TIME: sf.current_timestamp(), + STATUS: sf.lit(STATUS_ERROR), + ERROR: sf.lit(str(error)[:1000]), + }, + ) + # check whether rows were updated by looking in the delta log + if _table_not_updated(delta): + get_cdm_logger().warning( + "%s %s: cannot update '%s' to status %s because no record exists.", + run.pipeline, + run.run_id, + RUN, + STATUS_ERROR, + ) + get_cdm_logger().error("%s %s: run failed with %s", run.pipeline, run.run_id, error.__repr__()) diff --git a/src/cdm_data_loader_utils/validation/__init__.py b/src/cdm_data_loader_utils/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cdm_data_loader_utils/validation/dataframe_validator.py b/src/cdm_data_loader_utils/validation/dataframe_validator.py new file mode 100644 index 0000000..5e0243b --- /dev/null +++ b/src/cdm_data_loader_utils/validation/dataframe_validator.py @@ -0,0 +1,88 @@ +"""Class providing an interface to validation with error and metrics auditing.""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import functions as sf +from pyspark.sql.types import StructField + +from cdm_data_loader_utils.audit.metrics import write_metrics +from cdm_data_loader_utils.audit.rejects import write_rejects +from cdm_data_loader_utils.audit.schema import ROW_ERRORS +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger +from cdm_data_loader_utils.validation.validation_result import ValidationResult + +logger = get_cdm_logger() + + +@dataclass +class Validator: + """Base validator dataclass.""" + + validation_fn: Callable + args: dict[str, Any] + + +class DataFrameValidator: + """Class for validating data.""" + + def __init__(self, spark: SparkSession) -> None: + """Instantiate an IngestValidator. + + :param spark: spark sesh + :type spark: SparkSession + """ + self.spark = spark + + def validate_dataframe( + self, + data_to_validate: DataFrame, + schema: list[StructField], + run: PipelineRun, + validator: Validator, + invalid_col: str, + ) -> ValidationResult: + """Validate a dataframe, outputting a ValidationResult. 
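+
+        A minimal usage sketch (``my_validator_fn``, ``df``, ``schema_fields``, ``run`` and the
+        ``"_corrupt"`` column name are illustrative stand-ins, not names defined by this package)::
+
+            validator = Validator(my_validator_fn, {"invalid_col": "_corrupt"})
+            result = DataFrameValidator(spark).validate_dataframe(
+                data_to_validate=df,
+                schema=schema_fields,
+                run=run,
+                validator=validator,
+                invalid_col="_corrupt",
+            )
+            clean_df = result.valid_df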
+ + :param data_to_validate: dataframe to be validated + :type data_to_validate: DataFrame + :param schema: schema of the fields in the dataframe, expressed as a list of StructFields + :type schema: list[StructField] + :param run: the current pipeline run + :type run: PipelineRun + :param validation_fn: function for validating the dataframe + :type validation_fn: Callable + :return: data validation metrics and the valid data from the input dataframe + :rtype: ValidationResult + """ + if data_to_validate.count() == 0: + logger.warning("%s %s: no data found to validate. Aborting.") + return ValidationResult( + valid_df=data_to_validate, + records_read=0, + records_valid=0, + records_invalid=0, + validation_errors=[], + ) + + # running the validator should produce a dataframe with a column called ROW_ERRORS + annotated_df: DataFrame = validator.validation_fn(data_to_validate, schema, **validator.args) + valid_df: DataFrame = annotated_df.filter(sf.size(ROW_ERRORS) == 0) + write_rejects( + run=run, + annotated_df=annotated_df, + schema_fields=schema, + invalid_col=invalid_col, + ) + metrics = write_metrics(self.spark, annotated_df, run) + + return ValidationResult( + valid_df=valid_df.drop(ROW_ERRORS).drop(invalid_col), + records_read=metrics.records_read, + records_valid=metrics.records_valid, + records_invalid=metrics.records_invalid, + validation_errors=metrics.validation_errors, + ) diff --git a/src/cdm_data_loader_utils/validation/df_nullable_fields.py b/src/cdm_data_loader_utils/validation/df_nullable_fields.py new file mode 100644 index 0000000..b885a7f --- /dev/null +++ b/src/cdm_data_loader_utils/validation/df_nullable_fields.py @@ -0,0 +1,61 @@ +"""Simple validator for checking for null columns in a dataframe.""" + +from pyspark.sql import DataFrame +from pyspark.sql import functions as sf +from pyspark.sql.types import StructField + +from cdm_data_loader_utils.audit.schema import ROW_ERRORS + +COLLECTED_ERRORS = "collected_errors" + + +def validate( + df: DataFrame, + schema_fields: list[StructField], + invalid_col: str, +) -> DataFrame: + """Validation function that ensures that nullability constraints in the schema are enforced. + + As of Jan 2026, Spark automagically converts the `nullable=False` attribute on StructFields to + True when a schema is applied to a dataframe. This function should be supplied with the original + schema, which it will use to check that null constraints are actually enforced! 
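+
+    A sketch of the expected annotations (column names are illustrative): a row with a
+    null in a non-nullable ``col1`` gets ``["missing_required: col1"]`` in the
+    ``ROW_ERRORS`` column, while a row with content in ``invalid_col`` gets
+    ``["parse_error"]``.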
+ + :param df: dataframe to test + :type df: DataFrame + :param schema_fields: schema, in the form of a list of SchemaFields + :type schema_fields: list[StructField] + :param invalid_col: name of the column where invalid data is stored + :type invalid_col: str + :return: df containing rejected data + :rtype: DataFrame + """ + # if a column is null but the schema says that it should not be, add the notation "missing_required:{col.name}" + missing_fields = [ + sf.when( + sf.col(col.name).isNull(), + sf.lit(f"missing_required: {col.name}"), + ) + for col in schema_fields + # required columns only + if not col.nullable + ] + + return ( + df.withColumn(COLLECTED_ERRORS, sf.array(*missing_fields)) + .withColumn( + ROW_ERRORS, + sf.when( + # from spark data ingestion: if the incoming row does not match the supplied schema, + # incorrect data will go in invalid_col + sf.col(invalid_col).isNotNull(), + # mark this as a parse error + sf.array(sf.lit("parse_error")), + ).otherwise( + sf.filter( + sf.col(COLLECTED_ERRORS), + lambda x: x.isNotNull(), + ) + ), + ) + .drop(COLLECTED_ERRORS) + ) diff --git a/src/cdm_data_loader_utils/validation/validation_result.py b/src/cdm_data_loader_utils/validation/validation_result.py new file mode 100644 index 0000000..250a5ce --- /dev/null +++ b/src/cdm_data_loader_utils/validation/validation_result.py @@ -0,0 +1,16 @@ +"""Dataclass for capturing validation results.""" + +from dataclasses import dataclass + +from pyspark.sql import DataFrame + + +@dataclass(frozen=True) +class ValidationResult: + """Dataclass for capturing validation results.""" + + valid_df: DataFrame + records_read: int + records_valid: int + records_invalid: int + validation_errors: list[str] diff --git a/tests/audit/__init__.py b/tests/audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/audit/conftest.py b/tests/audit/conftest.py new file mode 100644 index 0000000..6461f21 --- /dev/null +++ b/tests/audit/conftest.py @@ -0,0 +1,118 @@ +"""Tests for the checkpoint and run audit table functions.""" + +import datetime +from typing import Any + +from pyspark.sql import SparkSession + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + CHECKPOINT, + END_TIME, + ERROR, + LAST_ENTRY_ID, + METRICS, + N_INVALID, + N_READ, + N_VALID, + RECORDS_PROCESSED, + REJECTS, + RUN, + START_TIME, + STATUS, + TIMESTAMP, + UPDATED, + VALIDATION_ERRORS, +) +from tests.conftest import PIPELINE_RUN, TEST_NS + +DEFAULT_DATA = { + CHECKPOINT: { + **PIPELINE_RUN, + STATUS: "RUNNING", + RECORDS_PROCESSED: 1, + LAST_ENTRY_ID: "e1", + # UPDATED + }, + RUN: { + **PIPELINE_RUN, + STATUS: "RUNNING", + RECORDS_PROCESSED: 0, + END_TIME: None, + ERROR: None, + # START_TIME + }, + METRICS: { + **PIPELINE_RUN, + N_READ: 10, + N_VALID: 8, + N_INVALID: 2, + VALIDATION_ERRORS: ["a mess o' trouble", "another fine mess"], + # UPDATED + }, + REJECTS: { + **PIPELINE_RUN, + }, +} + +INIT_TIMESTAMP_FIELDS = {CHECKPOINT: [UPDATED], RUN: [START_TIME], METRICS: [UPDATED], REJECTS: [TIMESTAMP]} +END_TIMESTAMP_FIELDS = {RUN: [END_TIME]} + + +def create_table(spark: SparkSession, table_name: str, add_default_data: bool = False) -> list[dict[str, Any]]: # noqa: FBT001, FBT002 + """Helper for creating audit tables for tests.""" + # ensure the db exists + if table_name not in AUDIT_SCHEMA: + msg = f"Invalid audit table: {table_name}" + raise ValueError(msg) + + spark.sql(f"CREATE DATABASE IF NOT EXISTS {TEST_NS}") + + # add in default data, if there is any + starter_data = [] + if add_default_data and 
table_name in (CHECKPOINT, METRICS, RUN): + starter_dict = {**DEFAULT_DATA[table_name]} + # add in any fields that are set when the table is added + for f in INIT_TIMESTAMP_FIELDS.get(table_name, []): + starter_dict[f] = datetime.datetime.now(tz=datetime.UTC) + starter_data = [starter_dict] + + spark.createDataFrame(starter_data, AUDIT_SCHEMA[table_name]).write.format("delta").mode("error").saveAsTable( + f"{TEST_NS}.{table_name}" + ) + # validate + df = spark.table(f"{TEST_NS}.{table_name}") + if add_default_data: + assert df.count() == 1 + check_saved_data(spark, table_name, DEFAULT_DATA[table_name]) + else: + assert df.count() == 0 + + return [r.asDict() for r in df.collect()] + + +def check_saved_data(spark: SparkSession, table_name: str, expected: dict[str, Any]) -> list[dict[str, Any]]: + """Ensure the appropriate checkpoint data has been saved. + + :param spark: spark sesh + :type spark: SparkSession + :param table_name: the name of the table to check + :type table_name: str + :param expected: dictionary of expected data + :type expected: dict[str, Any] + :return: checkpoint data obj from the db + :rtype: dict[str, Any] + """ + df = spark.table(f"{TEST_NS}.{table_name}") + assert len(df.collect()) == 1 + df_row = df.collect()[0].asDict() + # check and then remove the timestamp fields + for f in INIT_TIMESTAMP_FIELDS.get(table_name, []): + assert isinstance(df_row[f], datetime.datetime) + del df_row[f] + # special case: end time may or may not be populated + # if table_name == RUN and + # the remaining fields should be checkable + assert df_row == expected + + return [r.asDict() for r in df.collect()] diff --git a/tests/audit/test_checkpoint.py b/tests/audit/test_checkpoint.py new file mode 100644 index 0000000..393c721 --- /dev/null +++ b/tests/audit/test_checkpoint.py @@ -0,0 +1,166 @@ +"""Tests for the checkpoint audit table functions.""" + +import datetime +import logging + +import pytest +from pyspark.sql import SparkSession + +from cdm_data_loader_utils.audit.checkpoint import load_checkpoint, update_checkpoint_status, upsert_checkpoint +from cdm_data_loader_utils.audit.schema import ( + CHECKPOINT, + LAST_ENTRY_ID, + PIPELINE, + RECORDS_PROCESSED, + RUN_ID, + SOURCE, + STATUS, + STATUS_ERROR, + STATUS_RUNNING, + UPDATED, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from tests.audit.conftest import DEFAULT_DATA, check_saved_data, create_table +from tests.conftest import PIPELINE_RUN, TEST_NS + + +@pytest.mark.requires_spark +def test_upsert_checkpoint_first_entry(spark: SparkSession, pipeline_run: PipelineRun) -> None: + """Add an entry to the checkpoint table.""" + create_table(spark, CHECKPOINT) + + df = spark.table(f"{TEST_NS}.{CHECKPOINT}") + assert len(df.collect()) == 0 + + # insert an entry + upsert_checkpoint( + spark, + pipeline_run, + last_entry_id=DEFAULT_DATA[CHECKPOINT][LAST_ENTRY_ID], + records_processed=DEFAULT_DATA[CHECKPOINT][RECORDS_PROCESSED], + ) + check_saved_data(spark, CHECKPOINT, DEFAULT_DATA[CHECKPOINT]) + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "some other pipeline"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/some/random/dir"]) +def test_upsert_checkpoint_update(spark: SparkSession, pipeline: str, source: str) -> None: + """Add an entry to the checkpoint table and then update it.""" + df_rows = create_table(spark, CHECKPOINT, add_default_data=True) + is_saved_table = pipeline == PIPELINE_RUN[PIPELINE] and source == PIPELINE_RUN[SOURCE] + pipeline_run = 
PipelineRun(PIPELINE_RUN[RUN_ID], pipeline, source, TEST_NS) + + n_recs = 5 + upsert_checkpoint( + spark, + pipeline_run, + last_entry_id="e2", + records_processed=n_recs, + ) + df = spark.table(f"{TEST_NS}.{CHECKPOINT}") + new_df_data = [r.asDict() for r in df.collect()] + + if is_saved_table: + # update to the current pipeline: checkpoint row should be updated + assert len(new_df_data) == 1 + row = new_df_data[0] + else: + # different pipeline and source: new checkpoint row. + assert len(new_df_data) == 2 + assert df_rows[0] in new_df_data + row = next(r for r in new_df_data if r[LAST_ENTRY_ID] == "e2") + + # check the new data + assert isinstance(row[UPDATED], datetime.datetime) + assert row[UPDATED] > df_rows[0][UPDATED] + assert row[LAST_ENTRY_ID] == "e2" + assert row[RECORDS_PROCESSED] == n_recs + assert row[STATUS] == STATUS_RUNNING + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "some other pipeline"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/path/to/dir"]) +@pytest.mark.parametrize("add_default_data", [True, False]) +def test_load_checkpoint(spark: SparkSession, pipeline: str, source: str, add_default_data: bool) -> None: # noqa: FBT001 + """Test loading a previously-saved checkpoint data.""" + is_saved_table = ( + add_default_data + and pipeline == DEFAULT_DATA[CHECKPOINT][PIPELINE] + and source == DEFAULT_DATA[CHECKPOINT][SOURCE] + ) + create_table(spark, CHECKPOINT, add_default_data=add_default_data) + pipeline_run = PipelineRun(PIPELINE_RUN[RUN_ID], pipeline, source, TEST_NS) + + last_entry = load_checkpoint(spark, pipeline_run) + if is_saved_table: + assert last_entry == DEFAULT_DATA[CHECKPOINT][LAST_ENTRY_ID] + else: + assert last_entry is None + + # add in a valid entry for the pipeline/data source of interest + upsert_checkpoint( + spark, + pipeline_run, + last_entry_id="some_entry", + records_processed=500, + ) + last_entry = load_checkpoint(spark, pipeline_run) + assert last_entry == "some_entry" + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "pipe_2"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/src/two"]) +@pytest.mark.parametrize("add_default_data", [True, False]) +def test_update_checkpoint_status( + spark: SparkSession, + pipeline: str, + source: str, + caplog: pytest.LogCaptureFixture, + add_default_data: bool, # noqa: FBT001 +) -> None: + """Test updating checkpoint status, with and without existing table data.""" + caplog.set_level(logging.INFO) + is_saved_table = ( + add_default_data + and pipeline == DEFAULT_DATA[CHECKPOINT][PIPELINE] + and source == DEFAULT_DATA[CHECKPOINT][SOURCE] + ) + create_table(spark, CHECKPOINT, add_default_data=add_default_data) + pipeline_run = PipelineRun(PIPELINE_RUN[RUN_ID], pipeline, source, TEST_NS) + + update_checkpoint_status( + spark, + pipeline_run, + status=STATUS_ERROR, + ) + + table_rows = spark.table(f"{TEST_NS}.{CHECKPOINT}").collect() + + # only if the pipeline run refers to the default data will the status have been updated + if is_saved_table: + assert len(table_rows) == 1 + assert table_rows[0][STATUS] == STATUS_ERROR + elif add_default_data: + # this row is from the default data; it won't have been updated + assert len(table_rows) == 1 + assert table_rows[0][STATUS] == DEFAULT_DATA[CHECKPOINT][STATUS] + else: + assert len(table_rows) == 0 + + # check log messages + assert len(caplog.records) == 1 + if is_saved_table: + assert caplog.records[-1].levelno == logging.INFO + assert ( + 
caplog.records[-1].message + == f"{pipeline} 1234-5678-90: checkpoint successfully updated to status {STATUS_ERROR}" + ) + else: + assert caplog.records[-1].levelno == logging.WARNING + assert ( + caplog.records[-1].message + == f"{pipeline} 1234-5678-90: cannot update 'checkpoint' to status ERROR because no record exists." + ) diff --git a/tests/audit/test_metrics.py b/tests/audit/test_metrics.py new file mode 100644 index 0000000..670d26d --- /dev/null +++ b/tests/audit/test_metrics.py @@ -0,0 +1,95 @@ +"""Tests for the metrics audit table functions.""" + +import datetime +import logging +from typing import Any + +import pytest +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StructType + +from cdm_data_loader_utils.audit.metrics import write_metrics +from cdm_data_loader_utils.audit.schema import ( + METRICS, + N_INVALID, + N_READ, + N_VALID, + PIPELINE, + RUN_ID, + SOURCE, + VALIDATION_ERRORS, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from tests.audit.conftest import ( + DEFAULT_DATA, + INIT_TIMESTAMP_FIELDS, + check_saved_data, + create_table, +) +from tests.conftest import NAMESPACE, PIPELINE_RUN, TEST_NS + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "pipe_2"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/src/two"]) +@pytest.mark.parametrize("add_default_data", [True, False]) +def test_write_metrics( # noqa: PLR0913 + spark: SparkSession, + pipeline: str, + source: str, + add_default_data: bool, # noqa: FBT001 + caplog: pytest.LogCaptureFixture, + annotated_df_data: list[dict[str, Any]], + annotated_df_schema: StructType, + annotated_df_errors: set[str], +) -> None: + """Calculate the metrics for a given set of results and add it to the metrics table.""" + create_table(spark, METRICS, add_default_data=add_default_data) + is_saved_table = add_default_data and pipeline == PIPELINE_RUN[PIPELINE] and source == PIPELINE_RUN[SOURCE] + pipeline_args = {RUN_ID: PIPELINE_RUN[RUN_ID], PIPELINE: pipeline, SOURCE: source} + pipeline_run = PipelineRun(**{**pipeline_args, NAMESPACE: TEST_NS}) + + annotated_df = spark.createDataFrame(annotated_df_data, schema=annotated_df_schema) + metrics_out = {N_READ: 20, N_VALID: 4, N_INVALID: 16, VALIDATION_ERRORS: sorted(annotated_df_errors)} + + metrics = write_metrics(spark, annotated_df, pipeline_run) + assert [metrics.asDict()] == [metrics_out] + + df = spark.table(f"{pipeline_run.namespace}.{METRICS}") + if is_saved_table or not add_default_data: + # if we've either updated an existing row or added a row afresh, expect one data point + assert len(df.collect()) == 1 + check_saved_data(spark, METRICS, {**pipeline_args, **metrics_out}) + elif add_default_data: + # expect two rows + assert len(df.collect()) == 2 + row_data = [r.asDict() for r in df.collect()] + for row in row_data: + # filter out timestamp rows + for f in INIT_TIMESTAMP_FIELDS.get(METRICS, []): + assert isinstance(row[f], datetime.datetime) + del row[f] + assert DEFAULT_DATA[METRICS] in row_data + assert {**pipeline_args, **metrics_out} in row_data + + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.INFO + assert ( + caplog.records[-1].message + == f"{pipeline_run.pipeline} {pipeline_run.run_id}: ingest metrics written to '{METRICS}' table." 
+ ) + + +@pytest.mark.requires_spark +def test_write_metrics_empty_df( + spark: SparkSession, pipeline_run: PipelineRun, caplog: pytest.LogCaptureFixture, empty_df: DataFrame +) -> None: + """Ensure that the writer behaves sensibly with an empty table as input.""" + metrics = write_metrics(spark, empty_df, pipeline_run) + assert metrics.asDict() == {N_READ: 0, N_INVALID: 0, N_VALID: 0, VALIDATION_ERRORS: []} + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.INFO + assert ( + caplog.records[-1].message + == f"{pipeline_run.pipeline} {pipeline_run.run_id}: nothing to write to '{METRICS}' audit table." + ) diff --git a/tests/audit/test_rejects.py b/tests/audit/test_rejects.py new file mode 100644 index 0000000..6a80e94 --- /dev/null +++ b/tests/audit/test_rejects.py @@ -0,0 +1,156 @@ +"""Tests for the rejects audit table functions.""" + +import datetime +import json +import logging +from typing import Any + +import pyspark.sql.functions as sf +import pytest +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StringType, StructField, StructType + +from cdm_data_loader_utils.audit.rejects import write_rejects +from cdm_data_loader_utils.audit.schema import ( + PARSED_ROW, + PIPELINE, + RAW_ROW, + REJECTS, + ROW_ERRORS, + RUN_ID, + SOURCE, + TIMESTAMP, +) +from cdm_data_loader_utils.core.constants import INVALID_DATA_FIELD_NAME +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from tests.audit.conftest import create_table + +n_rows_per_file = 4 + + +@pytest.mark.requires_spark +def test_write_rejects( # noqa: PLR0913 + spark: SparkSession, + pipeline_run: PipelineRun, + caplog: pytest.LogCaptureFixture, + csv_schema: list[StructField], + annotated_df_data: list[dict[str, Any]], + annotated_df_schema: StructType, +) -> None: + """Write out the rejects in the annotated_df dataframe.""" + create_table(spark, REJECTS) + annotated_df = spark.createDataFrame(annotated_df_data, schema=annotated_df_schema) + invalid_df = annotated_df.filter(sf.size(ROW_ERRORS) > 0) + assert invalid_df.count() == n_rows_per_file * 4 + # filter out rows + write_rejects(pipeline_run, annotated_df, csv_schema, INVALID_DATA_FIELD_NAME) + + rejects_df = spark.table(f"{pipeline_run.namespace}.{REJECTS}") + rejects_rows = [r.asDict() for r in rejects_df.collect()] + assert len(rejects_rows) == invalid_df.count() # 4 lots of 4 invalid entries + timestamp = rejects_rows[0][TIMESTAMP] + for r in rejects_rows: + # all should have the same timestamp + assert isinstance(r[TIMESTAMP], datetime.datetime) + assert r[TIMESTAMP] == timestamp + assert r[RUN_ID] == pipeline_run.run_id + assert r[PIPELINE] == pipeline_run.pipeline + assert r[SOURCE] == pipeline_run.source_path + # shit test + assert r[ROW_ERRORS] in ( + ["parse_error"], + ["missing_required: col2", "missing_required: col3", "missing_required: col4"], + ["missing_required: col1", "missing_required: col2"], + ["missing_required: col4", "missing_required: col5"], + [ + "missing_required: col1", + "missing_required: col2", + "missing_required: col3", + "missing_required: col4", + "missing_required: col5", + ], + ) + # even worse tests + assert isinstance(json.loads(r[PARSED_ROW]), dict) + if r[ROW_ERRORS] != ["parse_error"]: + assert r[RAW_ROW] is None + else: + assert len(r[RAW_ROW]) >= 1 + + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.INFO + assert ( + caplog.records[-1].message + == f"{pipeline_run.pipeline} {pipeline_run.run_id}: invalid rows written to '{REJECTS}' 
audit table." + ) + + +@pytest.mark.requires_spark +def test_write_rejects_no_rejects( # noqa: PLR0913 + spark: SparkSession, + pipeline_run: PipelineRun, + caplog: pytest.LogCaptureFixture, + csv_schema: list[StructField], + annotated_df_data: list[dict[str, Any]], + annotated_df_schema: StructType, +) -> None: + """Submit the annotated df to the write_rejects function for reject writing... but there are no rejects!""" + create_table(spark, REJECTS) + # filter out all invalid rows + annotated_df = spark.createDataFrame(annotated_df_data, schema=annotated_df_schema) + valid_df = annotated_df.filter(sf.size(ROW_ERRORS) == 0) + assert valid_df.count() == n_rows_per_file + write_rejects( + pipeline_run, + valid_df, + csv_schema, + INVALID_DATA_FIELD_NAME, + ) + + df = spark.table(f"{pipeline_run.namespace}.{REJECTS}") + assert df.count() == 0 + + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.INFO + assert ( + caplog.records[-1].message + == f"{pipeline_run.pipeline} {pipeline_run.run_id}: nothing to write to '{REJECTS}' audit table." + ) + + +@pytest.mark.requires_spark +def test_write_rejects_no_row_errors( + empty_df: DataFrame, + empty_df_schema: list[StructField], + pipeline_run: PipelineRun, + caplog: pytest.LogCaptureFixture, +) -> None: + """Ensure that the writer behaves sensibly with an empty table as input.""" + err_msg = f"{pipeline_run.pipeline} {pipeline_run.run_id}: '{ROW_ERRORS}' column not present in dataframe; cannot record rejects." + with pytest.raises( + RuntimeError, + match=err_msg, + ): + write_rejects(pipeline_run, empty_df, empty_df_schema, INVALID_DATA_FIELD_NAME) + + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.ERROR + assert caplog.records[-1].message == err_msg + + +@pytest.mark.requires_spark +def test_write_rejects_empty_df( + spark: SparkSession, + empty_df_schema: list[StructField], + pipeline_run: PipelineRun, + caplog: pytest.LogCaptureFixture, +) -> None: + """Ensure that the writer behaves sensibly with an empty table as input.""" + empty_df = spark.createDataFrame([], schema=StructType([*empty_df_schema, StructField(ROW_ERRORS, StringType())])) + write_rejects(pipeline_run, empty_df, empty_df_schema, INVALID_DATA_FIELD_NAME) + assert len(caplog.records) == 1 + assert caplog.records[-1].levelno == logging.INFO + assert ( + caplog.records[-1].message + == f"{pipeline_run.pipeline} {pipeline_run.run_id}: nothing to write to '{REJECTS}' audit table." + ) diff --git a/tests/audit/test_run.py b/tests/audit/test_run.py new file mode 100644 index 0000000..7a77c92 --- /dev/null +++ b/tests/audit/test_run.py @@ -0,0 +1,161 @@ +"""Tests for the run audit table functions.""" + +import datetime +import logging + +import pytest +from pyspark.sql import SparkSession + +from cdm_data_loader_utils.audit.run import complete_run, fail_run, start_run +from cdm_data_loader_utils.audit.schema import ( + END_TIME, + ERROR, + PIPELINE, + RECORDS_PROCESSED, + RUN, + RUN_ID, + SOURCE, + START_TIME, + STATUS, + STATUS_ERROR, + STATUS_SUCCESS, +) +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from tests.audit.conftest import DEFAULT_DATA, check_saved_data, create_table +from tests.conftest import PIPELINE_RUN, TEST_NS + + +def check_no_record_log_message( + run: PipelineRun, table_name: str, status: str, caplog_message: logging.LogRecord +) -> None: + """Checker for the log message if an update is attempted to an entry that doesn't exist. 
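+
+    The expected message has the form (placeholders shown for illustration)::
+
+        <pipeline> <run_id>: cannot update '<table_name>' to status <status> because no record exists.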
+ + :param run: pipeline run + :type run: PipelineRun + :param table_name: table name + :type table_name: str + :param status: status + :type status: str + :param caplog_message: the log entry to test + :type caplog_message: + """ + assert caplog_message.levelno == logging.WARNING + assert ( + caplog_message.message + == f"{run.pipeline} {run.run_id}: cannot update '{table_name}' to status {status} because no record exists." + ) + + +# Tests for the run-related functions +@pytest.mark.requires_spark +def test_start_run(spark: SparkSession, pipeline_run: PipelineRun) -> None: + """Test that starting a run creates a new entry in the RUN audit table.""" + create_table(spark, RUN) + df = spark.table(f"{TEST_NS}.{RUN}") + assert len(df.collect()) == 0 + + start_run(spark, pipeline_run) + check_saved_data(spark, RUN, DEFAULT_DATA[RUN]) + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "pipe_2"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/src/two"]) +@pytest.mark.parametrize("add_default_data", [True, False]) +def test_complete_run( + spark: SparkSession, + pipeline: str, + source: str, + add_default_data: bool, # noqa: FBT001 + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that completing a run updates the appropriate run table.""" + create_table(spark, RUN, add_default_data=add_default_data) + is_saved_table = add_default_data and pipeline == PIPELINE_RUN[PIPELINE] and source == PIPELINE_RUN[SOURCE] + pipeline_run = PipelineRun(PIPELINE_RUN[RUN_ID], pipeline, source, TEST_NS) + n_records = 123_456_789_012_345 + + complete_run( + spark, + pipeline_run, + records_processed=n_records, + ) + + rows = [r.asDict() for r in spark.table(f"{TEST_NS}.{RUN}").collect()] + if is_saved_table: + assert len(rows) == 1 + row = rows[0] + assert row[STATUS] == STATUS_SUCCESS + assert row[RECORDS_PROCESSED] == n_records + assert isinstance(row[END_TIME], datetime.datetime) + assert row[END_TIME] > row[START_TIME] + assert len(caplog.records) == 1 + + else: + if add_default_data: + # only row will be the default data + assert len(rows) == 1 + check_saved_data(spark, RUN, DEFAULT_DATA[RUN]) + else: + # no data successfully added + assert len(rows) == 0 + # two log messages: invalid row, completion + assert len(caplog.records) == 2 + check_no_record_log_message(pipeline_run, RUN, STATUS_SUCCESS, caplog.records[0]) + + assert caplog.records[-1].levelno == logging.INFO + assert caplog.records[-1].message == f"{pipeline} {PIPELINE_RUN[RUN_ID]}: run completed" + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("pipeline", [PIPELINE_RUN[PIPELINE], "pipe_2"]) +@pytest.mark.parametrize("source", [PIPELINE_RUN[SOURCE], "/src/two"]) +@pytest.mark.parametrize("add_default_data", [True, False]) +def test_fail_run( + spark: SparkSession, + pipeline: str, + source: str, + add_default_data: bool, # noqa: FBT001 + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that failing a run does the appropriate fun stuff.""" + create_table(spark, RUN, add_default_data=add_default_data) + is_saved_table = add_default_data and pipeline == PIPELINE_RUN[PIPELINE] and source == PIPELINE_RUN[SOURCE] + pipeline_run = PipelineRun(PIPELINE_RUN[RUN_ID], pipeline, source, TEST_NS) + error = RuntimeError("ZOMGWTELF! 
" * 200) + + fail_run( + spark, + pipeline_run, + error=error, + ) + rows = [r.asDict() for r in spark.table(f"{TEST_NS}.{RUN}").collect()] + + if is_saved_table: + assert len(rows) == 1 + row = rows[0] + assert row[RECORDS_PROCESSED] == 0 + assert isinstance(row[END_TIME], datetime.datetime) + assert row[END_TIME] > row[START_TIME] + + assert row[STATUS] == STATUS_ERROR + assert row[END_TIME] is not None + assert len(row[ERROR]) == 1000 + assert len(caplog.records) == 1 + + else: + if add_default_data: + # only row will be the default data + assert len(rows) == 1 + check_saved_data(spark, RUN, DEFAULT_DATA[RUN]) + else: + assert len(rows) == 0 + + # two log messages: invalid row, completion + assert len(caplog.records) == 2 + check_no_record_log_message(pipeline_run, RUN, STATUS_ERROR, caplog.records[0]) + + assert caplog.records[-1].levelno == logging.ERROR + assert caplog.records[-1].message.startswith( + f"{pipeline} {PIPELINE_RUN[RUN_ID]}: run failed with RuntimeError('ZOMGWTELF! ZOMGWTELF! " + ) diff --git a/tests/audit/test_schema.py b/tests/audit/test_schema.py new file mode 100644 index 0000000..1bf179a --- /dev/null +++ b/tests/audit/test_schema.py @@ -0,0 +1,32 @@ +"""Tests for the schema package in the audit section.""" + +import pytest + +from cdm_data_loader_utils.audit.schema import PIPELINE, RUN_ID, SOURCE, current_run_expr, match_run +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from tests.audit.conftest import PIPELINE_RUN + + +def test_current_run_expr_no_args() -> None: + """Test generation of the current run expression without any arguments.""" + assert current_run_expr() == f"t.{RUN_ID} = s.{RUN_ID} AND t.{SOURCE} = s.{SOURCE} AND t.{PIPELINE} = s.{PIPELINE}" + + +@pytest.mark.parametrize("target", ["", "some_target"]) +@pytest.mark.parametrize("source", ["", "some_source"]) +def test_current_run_expr(target: str, source: str) -> None: + """Test generation of the current run expression with some arguments.""" + t_str = target if target else "t" + s_str = source if source else "s" + assert ( + current_run_expr(target, source) + == f"{t_str}.{RUN_ID} = {s_str}.{RUN_ID} AND {t_str}.{SOURCE} = {s_str}.{SOURCE} AND {t_str}.{PIPELINE} = {s_str}.{PIPELINE}" + ) + + +def test_match_run(pipeline_run: PipelineRun) -> None: + """Test generation of the current run expression without any arguments.""" + assert ( + match_run(pipeline_run) + == f"{RUN_ID} = '{PIPELINE_RUN[RUN_ID]}' AND {PIPELINE} = '{PIPELINE_RUN[PIPELINE]}' AND {SOURCE} = '{PIPELINE_RUN[SOURCE]}'" + ) diff --git a/tests/readers/test_dsv_read_with_validation.py b/tests/readers/test_dsv_read_with_validation.py new file mode 100644 index 0000000..1c50ff5 --- /dev/null +++ b/tests/readers/test_dsv_read_with_validation.py @@ -0,0 +1,144 @@ +"""Tests for parser error handling, schema compliance, and so on.""" + +from typing import Any + +import pytest +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StringType, StructField + +from cdm_data_loader_utils.audit.schema import ( + AUDIT_SCHEMA, + METRICS, + N_INVALID, + N_READ, + N_VALID, + VALIDATION_ERRORS, +) +from cdm_data_loader_utils.core.constants import INVALID_DATA_FIELD_NAME +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.readers.dsv import read +from cdm_data_loader_utils.validation.dataframe_validator import DataFrameValidator, Validator +from cdm_data_loader_utils.validation.df_nullable_fields import validate as nullable_fields 
+from tests.conftest import ALL_LINES, MISSING_REQUIRED, TEST_NS, TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, VALID +from tests.helpers import create_empty_delta_table + +PERMISSIVE = "PERMISSIVE" +DROP = "DROPMALFORMED" +FAILFAST = "FAILFAST" + +INGEST_MODES = [PERMISSIVE, DROP, FAILFAST] + +TEST_SCHEMA_FIELD = StructField("what", StringType(), nullable=False) + + +def read_with_validation( + spark: SparkSession, + run: PipelineRun, + path: str, + schema_fields: list[StructField], + options: dict[str, Any], +) -> DataFrame: + """Read in a delimiter-separated file, performing some minimal validation checks. + + :param spark: spark sesh + :type spark: SparkSession + :param run: current run info + :type run: PipelineRun + :param path: location of the file to parse + :type path: str + :param schema_fields: list of StructFields describing the expected input + :type schema_fields: list[StructField] + :return: just the valid rows from the dataframe + :rtype: DataFrame + """ + parsed_df = read(spark, path, schema_fields, options) + + # validate the output + validator = DataFrameValidator(spark) + result = validator.validate_dataframe( + data_to_validate=parsed_df, + schema=schema_fields, + run=run, + validator=Validator(nullable_fields, {"invalid_col": INVALID_DATA_FIELD_NAME}), + invalid_col=INVALID_DATA_FIELD_NAME, + ) + + return result.valid_df + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("mode", INGEST_MODES) +@pytest.mark.parametrize("csv_lines", [VALID, MISSING_REQUIRED, TYPE_MISMATCH, TOO_FEW_COLS, TOO_MANY_COLS, ALL_LINES]) +def test_csv_read_with_validation_errors( # noqa: PLR0913 + spark: SparkSession, + mode: str, + csv_lines: str, + csv_schema: list[StructField], + request: pytest.FixtureRequest, + pipeline_run: PipelineRun, +) -> None: + """Test ingestion of valid and invalid CSV data.""" + n_rows = 4 + csv_lines_path = request.getfixturevalue(csv_lines) + + read_options = { + "delimiter": ",", + "header": False, + "comment": "#", + "dateFormat": "yyyyMMdd", + "ignoreLeadingWhiteSpace": True, + "ignoreTrailingWhiteSpace": True, + "mode": mode, + } + + # DROPMALFORMED won't be run + if mode == DROP: + with pytest.raises(ValueError, match="The only permitted read modes are PERMISSIVE and FAILFAST"): + read_with_validation(spark, pipeline_run, str(csv_lines_path), csv_schema, options=read_options) + return + + if mode == FAILFAST and csv_lines in (TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, ALL_LINES): + with pytest.raises(Py4JJavaError, match="An error occurred while calling "): + read_with_validation(spark, pipeline_run, str(csv_lines_path), csv_schema, options=read_options) + return + + # prep the appropriate tables + for t in AUDIT_SCHEMA: + create_empty_delta_table(spark, TEST_NS, t, AUDIT_SCHEMA[t]) + + test_df = read_with_validation(spark, pipeline_run, str(csv_lines_path), csv_schema, options=read_options) + data_rows = [r.asDict() for r in test_df.collect()] + fields = {f.name for f in csv_schema} + assert set(test_df.schema.fieldNames()) == fields + + metrics = spark.table(f"{pipeline_run.namespace}.{METRICS}").collect() + assert len(metrics) == 1 + result = metrics[0] + assert len(data_rows) == result[N_VALID] + + # all modes should correctly parse the valid data + if csv_lines == VALID: + assert result[N_VALID] == n_rows + assert result[N_READ] == n_rows + assert result[N_INVALID] == 0 + + if csv_lines in (TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH): + # permissive will pull in all the data + assert result[N_VALID] == 0 + assert result[N_READ] == n_rows + 
assert result[N_INVALID] == n_rows + + # none of the modes GAF about missing required values, so all will be read + # they will be removed by the validator + if csv_lines == MISSING_REQUIRED: + assert result[N_VALID] == 0 + assert result[N_READ] == n_rows + assert result[N_INVALID] == n_rows + for n in range(1, 6): + assert f"missing_required: col{n}" in result[VALIDATION_ERRORS] + + if csv_lines == ALL_LINES: + assert result[N_VALID] == n_rows + assert result[N_READ] == n_rows * 5 if mode == PERMISSIVE else n_rows + assert result[N_INVALID] == n_rows * 4 if mode == PERMISSIVE else n_rows diff --git a/tests/validation/__init__.py b/tests/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/validation/test_dataframe_validator.py b/tests/validation/test_dataframe_validator.py new file mode 100644 index 0000000..63a316c --- /dev/null +++ b/tests/validation/test_dataframe_validator.py @@ -0,0 +1,82 @@ +"""Tests for parser error handling, schema compliance, and so on.""" + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import StructField, StructType + +from cdm_data_loader_utils.audit.schema import METRICS, REJECTS, ROW_ERRORS +from cdm_data_loader_utils.core.constants import INVALID_DATA_FIELD_NAME +from cdm_data_loader_utils.core.pipeline_run import PipelineRun +from cdm_data_loader_utils.validation.dataframe_validator import DataFrameValidator, Validator +from tests.audit.conftest import create_table + + +@pytest.mark.requires_spark +def test_validate_dataframe_empty_df(pipeline_run: PipelineRun, empty_df: DataFrame) -> None: + """Assert that an empty dataframe does not perform any validation.""" + dfv = DataFrameValidator(MagicMock()) + validation_fn = MagicMock() + assert empty_df.count() == 0 + + validator = Validator(validation_fn, {"invalid_col": INVALID_DATA_FIELD_NAME}) + output = dfv.validate_dataframe( + data_to_validate=empty_df, + schema=empty_df.schema.fields, + run=pipeline_run, + validator=validator, + invalid_col=INVALID_DATA_FIELD_NAME, + ) + assert output.records_read == 0 + assert output.records_invalid == 0 + assert output.records_valid == 0 + assert output.validation_errors == [] + assert output.valid_df.count() == 0 + validation_fn.assert_not_called() + + +@pytest.mark.requires_spark +def test_validate_dataframe_no_validation( # noqa: PLR0913 + spark: SparkSession, + csv_schema: list[StructField], + annotated_df_data: list[dict[str, Any]], + annotated_df_schema: StructType, + annotated_df_errors: set[str], + pipeline_run: PipelineRun, +) -> None: + """Test the dataframe validator writes to the appropriate tables. + + The validator is mocked for test purposes. 
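+    The mocked validation function simply returns the pre-annotated dataframe, so this
+    test exercises only the metrics and rejects bookkeeping around it.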
+ """ + for t in [METRICS, REJECTS]: + create_table(spark, t, add_default_data=False) + + dfv = DataFrameValidator(spark) + annotated_df = spark.createDataFrame(annotated_df_data, schema=annotated_df_schema) + validation_fn = MagicMock(return_value=annotated_df) + + output = dfv.validate_dataframe( + data_to_validate=annotated_df, + schema=csv_schema, + run=pipeline_run, + validator=Validator(validation_fn, {}), + invalid_col=INVALID_DATA_FIELD_NAME, + ) + n_rows = 4 # 4 rows per csv file + assert output.records_invalid == n_rows * 4 + assert output.records_valid == n_rows + assert output.records_read == n_rows * 5 + + assert set(output.validation_errors) == annotated_df_errors + # the first 4 rows are valid; everything else triggers validation errors + assert [r.asDict() for r in output.valid_df.collect()] == [ + {k: v for k, v in r.items() if k not in (INVALID_DATA_FIELD_NAME, ROW_ERRORS)} for r in annotated_df_data[:4] + ] + + # check that metrics and rejects have been written + metrics = spark.table(f"{pipeline_run.namespace}.{METRICS}") + assert metrics.count() == 1 + rejects = spark.table(f"{pipeline_run.namespace}.{REJECTS}") + assert rejects.count() == output.records_invalid diff --git a/tests/validation/test_df_nullable_fields.py b/tests/validation/test_df_nullable_fields.py new file mode 100644 index 0000000..e008e80 --- /dev/null +++ b/tests/validation/test_df_nullable_fields.py @@ -0,0 +1,71 @@ +"""Tests for the nullable fields validator.""" + +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import StructField, StructType + +from cdm_data_loader_utils.audit.schema import ROW_ERRORS +from cdm_data_loader_utils.core.constants import INVALID_DATA_FIELD_NAME +from cdm_data_loader_utils.readers.dsv import INVALID_DATA_FIELD +from cdm_data_loader_utils.validation.df_nullable_fields import validate +from tests.conftest import ALL_LINES, MISSING_REQUIRED, TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, VALID + + +@pytest.mark.requires_spark +@pytest.mark.parametrize("csv_lines", [VALID, MISSING_REQUIRED, TYPE_MISMATCH, TOO_FEW_COLS, TOO_MANY_COLS, ALL_LINES]) +def test_csv_read_errors( + spark: SparkSession, + csv_lines: str, + request: pytest.FixtureRequest, + csv_schema: list[StructField], + invalid_csv_missing_required_annots: list[list[str]], +) -> None: + """Test ingestion of valid and invalid CSV data.""" + n_rows = 4 + csv_lines_path = request.getfixturevalue(csv_lines) + + read_options = { + "inferSchema": False, + "enforceSchema": True, + "delimiter": ",", + "header": False, + "comment": "#", + "dateFormat": "yyyyMMdd", + "ignoreLeadingWhiteSpace": True, + "ignoreTrailingWhiteSpace": True, + "mode": "PERMISSIVE", + "columnNameOfCorruptRecord": INVALID_DATA_FIELD_NAME, + } + # this mimics what the dsv reader does + dsv_schema = StructType([*csv_schema, INVALID_DATA_FIELD]) + df = spark.read.options(**read_options).csv(str(csv_lines_path), schema=dsv_schema) + + # any data in invalid_col will show up as a parse_error in annotated_df + original_df_rows = [r.asDict() for r in df.collect()] + n_invalid_rows = 0 + for r in original_df_rows: + if r[INVALID_DATA_FIELD_NAME]: + n_invalid_rows += 1 + + # missing required cols does not show up as an error (thanks, Spark schema validation) + if csv_lines in (VALID, MISSING_REQUIRED): + assert n_invalid_rows == 0 + elif csv_lines == ALL_LINES: + assert n_invalid_rows == n_rows * 3 + else: + # the other inputs all generate parse errors for each row due to incorrect number of cols or type mismatches + assert 
n_invalid_rows == n_rows + + # run it through the validator + annotated_df = validate(df, csv_schema, INVALID_DATA_FIELD_NAME) + annotated_rows = [r.asDict() for r in annotated_df.collect()] + row_errors = [r[ROW_ERRORS] for r in annotated_rows] + # any data in invalid_col will show up as a parse_error in annotated_df + assert n_invalid_rows == row_errors.count(["parse_error"]) + + # get the specific errors for the missing data + if csv_lines in (MISSING_REQUIRED, ALL_LINES): + if csv_lines == MISSING_REQUIRED: + assert row_errors == invalid_csv_missing_required_annots + else: + assert row_errors[4:8] == invalid_csv_missing_required_annots From 3b8ffbb92ff0a0744431ae8dd3c34f7c53e997c4 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Wed, 21 Jan 2026 16:57:46 -0800 Subject: [PATCH 2/5] fix failfast not failing fast enough --- src/cdm_data_loader_utils/audit/rejects.py | 8 +++++--- src/cdm_data_loader_utils/readers/dsv.py | 5 +++-- tests/readers/test_dsv.py | 19 +++++-------------- .../readers/test_dsv_read_with_validation.py | 13 ++++--------- 4 files changed, 17 insertions(+), 28 deletions(-) diff --git a/src/cdm_data_loader_utils/audit/rejects.py b/src/cdm_data_loader_utils/audit/rejects.py index 3a42100..3f3b37c 100644 --- a/src/cdm_data_loader_utils/audit/rejects.py +++ b/src/cdm_data_loader_utils/audit/rejects.py @@ -53,14 +53,16 @@ def write_rejects( logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) return - invalid_df: DataFrame = annotated_df.filter(sf.size(ROW_ERRORS) > 0) - if invalid_df.count() == 0: + # add in a dummy column so that spark doesn't optimise away everything except the error col + invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter(sf.size(ROW_ERRORS) > 0) + if not invalid_df.select("_dummy").head(1): # nothing to do here logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) return data_fields = [f.name for f in schema_fields] - + # drop the dummy column + invalid_df = invalid_df.drop("_dummy") rejects_df = invalid_df.select( sf.lit(run.run_id).alias(RUN_ID), sf.lit(run.pipeline).alias(PIPELINE), diff --git a/src/cdm_data_loader_utils/readers/dsv.py b/src/cdm_data_loader_utils/readers/dsv.py index bcd2db0..14ed12f 100644 --- a/src/cdm_data_loader_utils/readers/dsv.py +++ b/src/cdm_data_loader_utils/readers/dsv.py @@ -68,14 +68,15 @@ def read( **options, } - if dsv_options["mode"] not in (PERMISSIVE, FAILFAST): - msg = "The only permitted read modes are PERMISSIVE and FAILFAST." + if dsv_options["mode"] != PERMISSIVE: + msg = "The only permitted read mode is PERMISSIVE." 
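        # Why PERMISSIVE is the only mode left (comments sketch the assumed rationale;
        # sf, INVALID_DATA_FIELD_NAME and dsv_schema are taken from the surrounding code
        # and may be named differently in the real module): FAILFAST only raises once an
        # action runs, so a malformed file could pass read() and fail much later, and
        # DROPMALFORMED silently discards bad rows. PERMISSIVE with
        # columnNameOfCorruptRecord keeps malformed rows visible so they can be counted
        # and routed to the rejects table, roughly:
        #
        #   df = spark.read.options(**dsv_options).csv(path, schema=dsv_schema)
        #   df.head(1)  # trigger the read so parse failures surface here, not downstream
        #   bad_rows = df.filter(sf.col(INVALID_DATA_FIELD_NAME).isNotNull())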
log_and_die(msg, ValueError) format_name = get_format_name(options.get("delimiter", options.get("sep"))) try: df = spark.read.options(**dsv_options).csv(path, schema=dsv_schema) + df.head(1) # force spark to read NOW instead of being lazy except Exception: # Log the full stack trace and re-raise to be handled by the caller logger.exception("Failed to load %s from %s", format_name, path) diff --git a/tests/readers/test_dsv.py b/tests/readers/test_dsv.py index 987f7c3..d2bb1c8 100644 --- a/tests/readers/test_dsv.py +++ b/tests/readers/test_dsv.py @@ -85,8 +85,8 @@ def test_csv_read_modes( # noqa: PLR0913 "mode": mode, } - if mode == DROP: - with pytest.raises(ValueError, match="The only permitted read modes are PERMISSIVE and FAILFAST"): + if mode in (DROP, FAILFAST): + with pytest.raises(ValueError, match="The only permitted read mode is PERMISSIVE"): read(spark, str(csv_lines_path), csv_schema, options=read_options) return @@ -101,11 +101,6 @@ def test_csv_read_modes( # noqa: PLR0913 == f"Loaded {n_rows * 5 if csv_lines == ALL_LINES else n_rows} CSV records from {csv_lines_path!s}" ) - if mode == FAILFAST and csv_lines in (TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, ALL_LINES): - with pytest.raises(Py4JJavaError, match="An error occurred while calling "): - test_df.collect() - return - read(spark, str(csv_lines_path), csv_schema, options=read_options) data_rows = [r.asDict() for r in test_df.collect()] @@ -117,13 +112,9 @@ def test_csv_read_modes( # noqa: PLR0913 return if csv_lines in (TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH): - if mode == DROP: - # dropmalformed will not parse any content from these files as all lines are invalid - assert len(data_rows) == 0 - else: - # permissive will pull in all the data - assert len(data_rows) == n_rows + # permissive will pull in all the data + assert len(data_rows) == n_rows return # ALL_LINES: permissive will pull in all, DROP will just pull in the VALID + MISSING_REQUIRED lines - assert len(data_rows) == n_rows * 5 if mode == PERMISSIVE else n_rows * 2 + assert len(data_rows) == n_rows * 5 diff --git a/tests/readers/test_dsv_read_with_validation.py b/tests/readers/test_dsv_read_with_validation.py index 1c50ff5..aed5599 100644 --- a/tests/readers/test_dsv_read_with_validation.py +++ b/tests/readers/test_dsv_read_with_validation.py @@ -93,13 +93,8 @@ def test_csv_read_with_validation_errors( # noqa: PLR0913 } # DROPMALFORMED won't be run - if mode == DROP: - with pytest.raises(ValueError, match="The only permitted read modes are PERMISSIVE and FAILFAST"): - read_with_validation(spark, pipeline_run, str(csv_lines_path), csv_schema, options=read_options) - return - - if mode == FAILFAST and csv_lines in (TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, ALL_LINES): - with pytest.raises(Py4JJavaError, match="An error occurred while calling "): + if mode in (FAILFAST, DROP): + with pytest.raises(ValueError, match="The only permitted read mode is PERMISSIVE"): read_with_validation(spark, pipeline_run, str(csv_lines_path), csv_schema, options=read_options) return @@ -140,5 +135,5 @@ def test_csv_read_with_validation_errors( # noqa: PLR0913 if csv_lines == ALL_LINES: assert result[N_VALID] == n_rows - assert result[N_READ] == n_rows * 5 if mode == PERMISSIVE else n_rows - assert result[N_INVALID] == n_rows * 4 if mode == PERMISSIVE else n_rows + assert result[N_READ] == n_rows * 5 + assert result[N_INVALID] == n_rows * 4 From 839930c1a0010bee6d6beb3ef683dd2d5fcb796c Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Wed, 21 Jan 2026 19:33:36 -0800 Subject: 
[PATCH 3/5] spark sucks --- src/cdm_data_loader_utils/audit/rejects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cdm_data_loader_utils/audit/rejects.py b/src/cdm_data_loader_utils/audit/rejects.py index 3f3b37c..e63d02b 100644 --- a/src/cdm_data_loader_utils/audit/rejects.py +++ b/src/cdm_data_loader_utils/audit/rejects.py @@ -54,7 +54,7 @@ def write_rejects( return # add in a dummy column so that spark doesn't optimise away everything except the error col - invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter(sf.size(ROW_ERRORS) > 0) + invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter((sf.size(ROW_ERRORS) > 0) & (sf.col("_dummy") == 1)) if not invalid_df.select("_dummy").head(1): # nothing to do here logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) From 77cbede0ac1776833d7d8471759fa743f38d98c9 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Wed, 21 Jan 2026 19:36:20 -0800 Subject: [PATCH 4/5] spark sucks --- src/cdm_data_loader_utils/audit/rejects.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/cdm_data_loader_utils/audit/rejects.py b/src/cdm_data_loader_utils/audit/rejects.py index e63d02b..2605ea3 100644 --- a/src/cdm_data_loader_utils/audit/rejects.py +++ b/src/cdm_data_loader_utils/audit/rejects.py @@ -54,7 +54,9 @@ def write_rejects( return # add in a dummy column so that spark doesn't optimise away everything except the error col - invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter((sf.size(ROW_ERRORS) > 0) & (sf.col("_dummy") == 1)) + invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter( + (sf.size(ROW_ERRORS) > 0) & (sf.col("_dummy") == 1) + ) if not invalid_df.select("_dummy").head(1): # nothing to do here logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS) From 5636d49592af0c03682bb376c2501b353ca33e45 Mon Sep 17 00:00:00 2001 From: ialarmedalien Date: Thu, 22 Jan 2026 06:36:03 -0800 Subject: [PATCH 5/5] Adding in tsv/csv parse test --- src/cdm_data_loader_utils/readers/dsv.py | 4 +-- tests/conftest.py | 10 ++++-- tests/data/dsv/all_lines.tsv | 34 +++++++++++++++++++ tests/readers/test_dsv.py | 32 +++++++++++++++-- .../readers/test_dsv_read_with_validation.py | 1 - 5 files changed, 74 insertions(+), 7 deletions(-) create mode 100644 tests/data/dsv/all_lines.tsv diff --git a/src/cdm_data_loader_utils/readers/dsv.py b/src/cdm_data_loader_utils/readers/dsv.py index 14ed12f..0715d79 100644 --- a/src/cdm_data_loader_utils/readers/dsv.py +++ b/src/cdm_data_loader_utils/readers/dsv.py @@ -105,7 +105,7 @@ def read_tsv( """ if not options: options = {} - options["separator"] = "\t" + options["delimiter"] = "\t" return read(spark, path, schema_fields, options) @@ -127,5 +127,5 @@ def read_csv( """ if not options: options = {} - options["separator"] = "," + options["delimiter"] = "," return read(spark, path, schema_fields, options) diff --git a/tests/conftest.py b/tests/conftest.py index a232a83..049dbb9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,7 +98,7 @@ def empty_df(spark: SparkSession, empty_df_schema: list[StructField]) -> Generat ALL_LINES = "all_lines" -@pytest.fixture(scope="session") +@pytest.fixture def csv_schema() -> list[StructField]: """List of fields for parsing the various CSV snippets.""" return [ @@ -139,7 +139,7 @@ def invalid_csv_missing_required(test_data_dir: Path) -> Path: return test_data_dir / "dsv" / 
"missing_required.csv" -@pytest.fixture(scope="session") +@pytest.fixture def invalid_csv_missing_required_annots() -> list[list[str]]: """Generate the expected error annotations for the lines in invalid_csv_missing_required. @@ -201,6 +201,12 @@ def all_lines(test_data_dir: Path) -> Path: @pytest.fixture(scope="session") +def all_lines_tsv(test_data_dir: Path) -> Path: + """All the CSV lines in a single fixture!""" + return test_data_dir / "dsv" / "all_lines.tsv" + + +@pytest.fixture def annotated_df_schema(csv_schema: list[StructField]) -> StructType: """The schema for the annotated dataframe produced by validating one of the CSV files above.""" actual_csv_schema = list(csv_schema) diff --git a/tests/data/dsv/all_lines.tsv b/tests/data/dsv/all_lines.tsv new file mode 100644 index 0000000..dd7d269 --- /dev/null +++ b/tests/data/dsv/all_lines.tsv @@ -0,0 +1,34 @@ +# all valid +1 20250301 1.2345 true EcoCyc:EG10986-MONOMER +2 20250201 0.2 false MetaCyc:EG10986-MONOMER +3 20250801 23 True 4261555 +4 00010101 .1234 False col5 +# correct number of cols cols empty +1 col5 +# missing leading cols + 2.345 True col5 +# missing trailing cols +3 20250531 23.45 +# all missing + +# too many cols - all have 6 cols + +2 20250710 col3 True col6 +# empty trailing +3 20250101 +# empty leading + col6 +# too few cols +1 +2 20250502 0.2345 +3 23.56 False + +# CSV with incorrect types +# Y N Y N Y +1 2 3 4 5 +# N N Y N Y +1.234 2.3456 3.45 4.5 5 +# N N N Y Y +true false true false true +# Y Y N N Y +00200202 00200202 00200202 00200202 00200202 diff --git a/tests/readers/test_dsv.py b/tests/readers/test_dsv.py index d2bb1c8..2533ff2 100644 --- a/tests/readers/test_dsv.py +++ b/tests/readers/test_dsv.py @@ -1,16 +1,17 @@ """Tests for parser error handling, schema compliance, and so on.""" import logging +from pathlib import Path from typing import Any from unittest.mock import MagicMock import pytest -from py4j.protocol import Py4JJavaError from pyspark.errors import AnalysisException from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from pyspark.testing import assertDataFrameEqual, assertSchemaEqual -from cdm_data_loader_utils.readers.dsv import read +from cdm_data_loader_utils.readers.dsv import read, read_csv, read_tsv from tests.conftest import ALL_LINES, MISSING_REQUIRED, TOO_FEW_COLS, TOO_MANY_COLS, TYPE_MISMATCH, VALID PERMISSIVE = "PERMISSIVE" @@ -60,6 +61,33 @@ def test_read_errors(spark: SparkSession, delimiter: str | None, fmt: str, caplo assert caplog.records[0].message == f"Failed to load {fmt} from /path/to/nowhere" +@pytest.mark.requires_spark +def test_read_tsv_csv(spark: SparkSession, csv_schema: list[StructField], all_lines: Path, all_lines_tsv: Path) -> None: + """Ensure that the TSV and CSV reader shortcuts work as expected.""" + read_options = { + "header": False, + "comment": "#", + "dateFormat": "yyyyMMdd", + "ignoreLeadingWhiteSpace": False, + "ignoreTrailingWhiteSpace": False, + } + csv_options = {"delimiter": ","} + tsv_options = {"delimiter": "\t"} + + test_df_csv = read(spark, str(all_lines), csv_schema, options={**read_options, **csv_options}) + test_df_tsv = read(spark, str(all_lines_tsv), csv_schema, options={**read_options, **tsv_options}) + + csv_df = read_csv(spark, str(all_lines), csv_schema, read_options) + tsv_df = read_tsv(spark, str(all_lines_tsv), csv_schema, read_options) + + for df in [test_df_tsv, csv_df, tsv_df]: + assertSchemaEqual(test_df_csv.schema, df.schema) + + 
assertDataFrameEqual(test_df_tsv, tsv_df) + assertDataFrameEqual(test_df_csv, csv_df) + # TODO: compare the TSV and CSV versions? + + @pytest.mark.requires_spark @pytest.mark.parametrize("mode", INGEST_MODES) @pytest.mark.parametrize("csv_lines", [VALID, MISSING_REQUIRED, TYPE_MISMATCH, TOO_FEW_COLS, TOO_MANY_COLS, ALL_LINES]) diff --git a/tests/readers/test_dsv_read_with_validation.py b/tests/readers/test_dsv_read_with_validation.py index aed5599..c6383b2 100644 --- a/tests/readers/test_dsv_read_with_validation.py +++ b/tests/readers/test_dsv_read_with_validation.py @@ -3,7 +3,6 @@ from typing import Any import pytest -from py4j.protocol import Py4JJavaError from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import StringType, StructField
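# The rejects-path change in patches 2-4 comes down to the quirk noted in rejects.py:
# filtering on the error column alone lets Spark optimise away everything except that
# column, so emptiness is checked via a throwaway literal column and head(1) rather
# than count(). A minimal sketch of the pattern (annotated_df, ROW_ERRORS and sf as
# used in rejects.py):
#
#   invalid_df = annotated_df.withColumn("_dummy", sf.lit(1)).filter(
#       (sf.size(ROW_ERRORS) > 0) & (sf.col("_dummy") == 1)
#   )
#   if not invalid_df.select("_dummy").head(1):
#       return  # nothing to write to the rejects table
#   rejects_df = invalid_df.drop("_dummy")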