114 changes: 114 additions & 0 deletions src/cdm_data_loader_utils/audit/checkpoint.py
@@ -0,0 +1,114 @@
"""Checkpoint audit table functions: adding and updating information on data import pipeline execution."""

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

from cdm_data_loader_utils.audit.schema import (
AUDIT_SCHEMA,
CHECKPOINT,
LAST_ENTRY_ID,
PIPELINE,
RECORDS_PROCESSED,
RUN_ID,
SOURCE,
STATUS,
STATUS_RUNNING,
UPDATED,
current_run_expr,
)
from cdm_data_loader_utils.core.pipeline_run import PipelineRun
from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger


# Checkpoint table-related functions
def upsert_checkpoint(
spark: SparkSession,
run: PipelineRun,
last_entry_id: str,
records_processed: int,
) -> None:
"""Add or update checkpoint records for the current pipeline ingest.

:param spark: Spark session
:type spark: SparkSession
:param run: pipeline run
:type run: PipelineRun
:param last_entry_id: ID of the last entry processed
:type last_entry_id: str
:param records_processed: number of entries processed
:type records_processed: int
"""
delta = DeltaTable.forName(spark, f"{run.namespace}.{CHECKPOINT}")

df = spark.range(1).select(
sf.lit(run.run_id).alias(RUN_ID),
sf.lit(run.pipeline).alias(PIPELINE),
sf.lit(run.source_path).alias(SOURCE),
sf.lit(STATUS_RUNNING).alias(STATUS),
sf.lit(records_processed).alias(RECORDS_PROCESSED),
sf.lit(last_entry_id).alias(LAST_ENTRY_ID),
sf.current_timestamp().alias(UPDATED),
)
updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[CHECKPOINT])

(
delta.alias("t")
.merge(updates.alias("s"), current_run_expr())
.whenMatchedUpdate(set={val: f"s.{val}" for val in [STATUS, RECORDS_PROCESSED, LAST_ENTRY_ID, UPDATED]})
.whenNotMatchedInsertAll()
.execute()
)
get_cdm_logger().info("%s %s: checkpoint created/updated", run.pipeline, run.run_id)


def update_checkpoint_status(spark: SparkSession, run: PipelineRun, status: str) -> None:
"""Update checkpoint status and timestamp.

:param spark: Spark session
:type spark: SparkSession
:param run: pipeline run
:type run: PipelineRun
:param status: pipeline status
:type status: str
"""
delta = DeltaTable.forName(spark, f"{run.namespace}.{CHECKPOINT}")
delta.update(
" AND ".join([f"{p} = '{getattr(run, p)}'" for p in [RUN_ID, PIPELINE, SOURCE]]),
{STATUS: sf.lit(status), UPDATED: sf.current_timestamp()},
)
# check whether rows were updated by looking in the delta log
# N.b. this may not work correctly if another process updates the table in the interim
metrics = delta.history(1).select("operationMetrics").collect()[0][0]
if int(metrics.get("numUpdatedRows", 0)) == 0:
get_cdm_logger().warning(
"%s %s: cannot update '%s' to status %s because no record exists.",
run.pipeline,
run.run_id,
CHECKPOINT,
status,
)
else:
get_cdm_logger().info("%s %s: checkpoint successfully updated to status %s", run.pipeline, run.run_id, status)


def load_checkpoint(spark: SparkSession, run: PipelineRun) -> str | None:
"""Load any existing checkpoint data, filtered by current pipeline and data source path.

:param spark: spark sesh
:type spark: SparkSession
:param run: pipeline run
:type run: PipelineRun
:return: ID of the last entry successfully processed and saved, or None if no checkpoint exists
:rtype: str | None
"""
rows = (
spark.table(f"{run.namespace}.{CHECKPOINT}")
.filter(sf.col(RUN_ID) == run.run_id)
.filter(sf.col(PIPELINE) == run.pipeline)
.filter(sf.col(SOURCE) == run.source_path)
.select(LAST_ENTRY_ID)
.limit(1)
.collect()
)
return rows[0][LAST_ENTRY_ID] if rows else None
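
A minimal usage sketch of the checkpoint lifecycle, assuming an active SparkSession (spark) and a PipelineRun instance (run) as in the functions above; the entry ID and the "COMPLETE" status string are hypothetical placeholders, not values defined in this module:

from cdm_data_loader_utils.audit.checkpoint import (
    load_checkpoint,
    update_checkpoint_status,
    upsert_checkpoint,
)

# Resume from a previous checkpoint for this run/pipeline/source, if one exists.
last_id = load_checkpoint(spark, run)

# After processing a batch, record progress: the merge inserts a row on the
# first call and updates the existing row on subsequent calls.
upsert_checkpoint(spark, run, last_entry_id="entry-00042", records_processed=1000)

# Once the ingest finishes, flip the checkpoint status. "COMPLETE" is a
# placeholder -- use the status constant defined in audit.schema.
update_checkpoint_status(spark, run, status="COMPLETE")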
97 changes: 97 additions & 0 deletions src/cdm_data_loader_utils/audit/metrics.py
@@ -0,0 +1,97 @@
"""Audit table for recording metrics."""

from delta.tables import DeltaTable
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as sf

from cdm_data_loader_utils.audit.schema import (
AUDIT_SCHEMA,
METRICS,
N_INVALID,
N_READ,
N_VALID,
PIPELINE,
ROW_ERRORS,
RUN_ID,
SOURCE,
UPDATED,
VALIDATION_ERRORS,
current_run_expr,
)
from cdm_data_loader_utils.core.pipeline_run import PipelineRun
from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger

logger = get_cdm_logger()


def write_metrics(
spark: SparkSession,
annotated_df: DataFrame,
run: PipelineRun,
) -> Row:
"""Write metrics for the current batch of imports to disk.

:param spark: Spark session
:type spark: SparkSession
:param annotated_df: dataframe for the current batch, annotated with a ROW_ERRORS column listing any errors found per row
:type annotated_df: DataFrame
:param run: current pipeline run
:type run: PipelineRun
:return: row with validation metrics: records read, valid, invalid, and the distinct validation errors for the batch
:rtype: Row
"""
if annotated_df.count() == 0:
# nothing to do here
logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, METRICS)
return Row(records_read=0, records_valid=0, records_invalid=0, validation_errors=[])

invalid_df = annotated_df.filter(sf.size(ROW_ERRORS) > 0)

validation_errors = sorted(
[r.reason for r in invalid_df.select(sf.explode(ROW_ERRORS).alias("reason")).distinct().collect()]
)

metrics_df = annotated_df.agg(
sf.count("*").alias(N_READ),
sf.sum(sf.when(sf.size(ROW_ERRORS) == 0, 1).otherwise(0)).alias(N_VALID),
sf.sum(sf.when(sf.size(ROW_ERRORS) > 0, 1).otherwise(0)).alias(N_INVALID),
sf.lit(validation_errors).alias(VALIDATION_ERRORS),
)
metrics = metrics_df.collect()[0]

df = spark.range(1).select(
sf.lit(run.run_id).alias(RUN_ID),
sf.lit(run.pipeline).alias(PIPELINE),
sf.lit(run.source_path).alias(SOURCE),
sf.lit(metrics.records_read).alias(N_READ),
sf.lit(metrics.records_valid).alias(N_VALID),
sf.lit(metrics.records_invalid).alias(N_INVALID),
sf.lit(metrics.validation_errors).alias(VALIDATION_ERRORS),
sf.current_timestamp().alias(UPDATED),
)
updates = spark.createDataFrame(df.rdd, schema=AUDIT_SCHEMA[METRICS])

target = DeltaTable.forName(
spark,
f"{run.namespace}.{METRICS}",
)

(
target.alias("t")
.merge(
updates.alias("s"),
current_run_expr(),
)
.whenMatchedUpdate(set={k: f"s.{k}" for k in [N_READ, N_VALID, N_INVALID, VALIDATION_ERRORS, UPDATED]})
.whenNotMatchedInsertAll()
.execute()
)
logger.info("%s %s: ingest metrics written to '%s' table.", run.pipeline, run.run_id, METRICS)

return metrics
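
A minimal sketch of recording metrics for a batch, assuming spark and run as above and an annotated_df that carries the ROW_ERRORS array column added during validation:

from cdm_data_loader_utils.audit.metrics import write_metrics

# Aggregates counts and distinct error reasons, upserts them into the metrics
# audit table for this run, and returns them as a Row.
metrics = write_metrics(spark, annotated_df, run)

# The returned Row exposes the aggregated values.
print(metrics.records_read, metrics.records_valid, metrics.records_invalid)
print(metrics.validation_errors)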
84 changes: 84 additions & 0 deletions src/cdm_data_loader_utils/audit/rejects.py
@@ -0,0 +1,84 @@
"""Audit table for recording data rejected as invalid during ingest."""

import pyspark.sql.functions as sf
from pyspark.sql import DataFrame
from pyspark.sql.types import StructField

from cdm_data_loader_utils.audit.schema import (
AUDIT_SCHEMA,
PARSED_ROW,
PIPELINE,
RAW_ROW,
REJECTS,
ROW_ERRORS,
RUN_ID,
SOURCE,
TIMESTAMP,
)
from cdm_data_loader_utils.core.pipeline_run import PipelineRun
from cdm_data_loader_utils.utils.cdm_logger import get_cdm_logger

logger = get_cdm_logger()


def write_rejects(
run: PipelineRun,
annotated_df: DataFrame,
schema_fields: list[StructField],
invalid_col: str,
) -> None:
"""Write rejected data to the rejects audit table.

This plugs in directly to readers such as the Spark CSV reader, which puts any non-compliant data into
a single column when run in PERMISSIVE mode (the default for the cdm_data_loader_utils readers).

It is expected that the dataframe will contain a column called ROW_ERRORS, which contains a list of strings
describing the errors found in the rows.

:param run: current pipeline run
:type run: PipelineRun
:param annotated_df: dataframe with errors to be written out
:type annotated_df: DataFrame
:param schema_fields: fields of the dataframe schema -- should be the original version without a column for invalid data
:type schema_fields: list[StructField]
:param invalid_col: name of the column with the invalid data in it
:type invalid_col: str
"""
if ROW_ERRORS not in annotated_df.columns:
err_msg = f"{run.pipeline} {run.run_id}: '{ROW_ERRORS}' column not present in dataframe; cannot record rejects."
logger.error(err_msg)
raise RuntimeError(err_msg)

if annotated_df.count() == 0:
logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS)
return

# add in a dummy column so that spark doesn't optimise away everything except the error col
invalid_df: DataFrame = annotated_df.withColumn("_dummy", sf.lit(1)).filter(
(sf.size(ROW_ERRORS) > 0) & (sf.col("_dummy") == 1)
)
if not invalid_df.select("_dummy").head(1):
# nothing to do here
logger.info("%s %s: nothing to write to '%s' audit table.", run.pipeline, run.run_id, REJECTS)
return

data_fields = [f.name for f in schema_fields]
# drop the dummy column
invalid_df = invalid_df.drop("_dummy")
rejects_df = invalid_df.select(
sf.lit(run.run_id).alias(RUN_ID),
sf.lit(run.pipeline).alias(PIPELINE),
sf.lit(run.source_path).alias(SOURCE),
sf.col(invalid_col).alias(RAW_ROW),
sf.to_json(sf.struct(*[sf.col(c) for c in data_fields])).alias(PARSED_ROW),
sf.col(ROW_ERRORS),
sf.current_timestamp().alias(TIMESTAMP),
)
# ensure that it conforms to the schema
rejects_df = rejects_df.select(
*[sf.col(f.name).cast(f.dataType).alias(f.name) for f in AUDIT_SCHEMA[REJECTS].fields]
)
# write to disk
rejects_df.write.format("delta").mode("append").saveAsTable(f"{run.namespace}.{REJECTS}")

logger.info("%s %s: invalid rows written to '%s' audit table.", run.pipeline, run.run_id, REJECTS)
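
A minimal sketch of writing rejects after a permissive read, assuming run and annotated_df as above; the field names and the "_corrupt_record" invalid-data column are assumptions about the reader configuration, not values defined here:

from pyspark.sql.types import StringType, StructField, StructType

from cdm_data_loader_utils.audit.rejects import write_rejects

# Schema of the source data as originally defined, without the column that
# holds unparseable input.
source_schema = StructType([
    StructField("id", StringType()),
    StructField("name", StringType()),
])

# Appends any rows with a non-empty ROW_ERRORS list to the rejects audit table.
write_rejects(
    run,
    annotated_df,
    schema_fields=source_schema.fields,
    invalid_col="_corrupt_record",  # whichever column the reader used for bad rows
)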