Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions src/sqlmesh_openlineage/console.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""OpenLineage Console wrapper for SQLMesh."""
from __future__ import annotations

import logging
import uuid
import typing as t

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
from sqlmesh.core.console import Console
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike
Expand Down Expand Up @@ -72,10 +75,14 @@ def start_snapshot_evaluation_progress(
# Store snapshot for later reference
self._current_snapshots[snapshot.name] = snapshot

self._emitter.emit_snapshot_start(
snapshot=snapshot,
run_id=run_id,
)
try:
self._emitter.emit_snapshot_start(
snapshot=snapshot,
run_id=run_id,
snapshots=self._current_snapshots,
)
except Exception:
logger.warning("Failed to emit START event for %s", snapshot.name, exc_info=True)

# Delegate to wrapped console
self._wrapped.start_snapshot_evaluation_progress(snapshot, audit_only)
Expand All @@ -96,20 +103,24 @@ def update_snapshot_evaluation_progress(
run_id = self._active_runs.pop(snapshot.name, None)

if run_id:
if num_audits_failed > 0:
self._emitter.emit_snapshot_fail(
snapshot=snapshot,
run_id=run_id,
error=f"Audit failed: {num_audits_failed} audit(s) failed",
)
else:
self._emitter.emit_snapshot_complete(
snapshot=snapshot,
run_id=run_id,
interval=interval,
duration_ms=duration_ms,
execution_stats=execution_stats,
)
try:
if num_audits_failed > 0:
self._emitter.emit_snapshot_fail(
snapshot=snapshot,
run_id=run_id,
error=f"Audit failed: {num_audits_failed} audit(s) failed",
)
else:
self._emitter.emit_snapshot_complete(
snapshot=snapshot,
run_id=run_id,
interval=interval,
duration_ms=duration_ms,
execution_stats=execution_stats,
snapshots=self._current_snapshots,
)
except Exception:
logger.warning("Failed to emit event for %s", snapshot.name, exc_info=True)

# Delegate to wrapped console
self._wrapped.update_snapshot_evaluation_progress(
Expand All @@ -130,11 +141,14 @@ def stop_evaluation_progress(self, success: bool = True) -> None:
for snapshot_name, run_id in list(self._active_runs.items()):
snapshot = self._current_snapshots.get(snapshot_name)
if snapshot and run_id:
self._emitter.emit_snapshot_fail(
snapshot=snapshot,
run_id=run_id,
error="Evaluation interrupted" if not success else "Unknown error",
)
try:
self._emitter.emit_snapshot_fail(
snapshot=snapshot,
run_id=run_id,
error="Evaluation interrupted" if not success else "Unknown error",
)
except Exception:
logger.warning("Failed to emit FAIL event for %s", snapshot_name, exc_info=True)

# Clear tracking state
self._active_runs.clear()
Expand Down
30 changes: 23 additions & 7 deletions src/sqlmesh_openlineage/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from __future__ import annotations

import typing as t
from collections import defaultdict

if t.TYPE_CHECKING:
from sqlmesh.core.snapshot import Snapshot
from sqlmesh.core.model import Model
from openlineage.client.event_v2 import InputDataset, OutputDataset


Expand Down Expand Up @@ -90,19 +88,28 @@ def snapshot_to_column_lineage_facet(
# Get column name
source_col = exp.to_column(lineage_node.name).name

# Determine transformation type based on whether
# output column name matches source column name
is_identity = col_name == source_col
transformations = [
column_lineage_dataset.Transformation(
type="DIRECT",
subtype="IDENTITY" if is_identity else "TRANSFORMATION",
)
]

input_fields.append(
column_lineage_dataset.InputField(
namespace=namespace,
name=table_name,
field=source_col,
transformations=transformations,
)
)

if input_fields:
fields[col_name] = column_lineage_dataset.Fields(
inputFields=input_fields,
transformationType="",
transformationDescription="",
)

except Exception:
Expand Down Expand Up @@ -156,19 +163,28 @@ def snapshot_to_output_dataset(
def snapshot_to_input_datasets(
snapshot: "Snapshot",
namespace: str,
snapshots: t.Optional[t.Dict[str, "Snapshot"]] = None,
) -> t.List["InputDataset"]:
"""Get upstream dependencies as input datasets."""
"""Get upstream dependencies as input datasets.

When a snapshots dict is provided, parent snapshots are looked up to
produce fully qualified table names consistent with output datasets.
"""
from openlineage.client.event_v2 import InputDataset

inputs: t.List["InputDataset"] = []

# Get parent snapshot IDs
for parent_id in snapshot.parents:
# Parent ID contains the name we need
# Try to resolve fully qualified name via the snapshots dict
parent_name = parent_id.name
if snapshots and parent_name in snapshots:
parent_name = snapshot_to_table_name(snapshots[parent_name])

inputs.append(
InputDataset(
namespace=namespace,
name=parent_id.name,
name=parent_name,
)
)

Expand Down
119 changes: 107 additions & 12 deletions src/sqlmesh_openlineage/emitter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""OpenLineage event emitter for SQLMesh."""
from __future__ import annotations

import logging
import typing as t
from datetime import datetime, timezone

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
from sqlmesh.core.snapshot import Snapshot
from sqlmesh.core.snapshot.definition import Interval
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats

PRODUCER = "https://github.com/sidequery/sqlmesh-openlineage"


class OpenLineageEmitter:
"""Emits OpenLineage events for SQLMesh operations."""
Expand All @@ -20,6 +25,7 @@ def __init__(
api_key: t.Optional[str] = None,
):
from openlineage.client import OpenLineageClient
from openlineage.client.client import OpenLineageClientOptions

self.namespace = namespace
self.url = url
Expand All @@ -32,15 +38,82 @@ def __init__(
elif api_key:
self.client = OpenLineageClient(
url=url,
options={"api_key": api_key},
options=OpenLineageClientOptions(api_key=api_key),
)
else:
self.client = OpenLineageClient(url=url)

def _build_job_facets(self, snapshot: "Snapshot") -> t.Dict[str, t.Any]:
"""Build job facets including SQL, job type, and source code location."""
from openlineage.client.facet_v2 import job_type_job, sql_job, source_code_location_job

facets: t.Dict[str, t.Any] = {}

# JobTypeJobFacet: identify as SQLMesh batch job
facets["jobType"] = job_type_job.JobTypeJobFacet(
processingType="BATCH",
integration="SQLMESH",
jobType="MODEL",
)

# SQLJobFacet: include the model SQL query
try:
if snapshot.is_model and snapshot.model:
query = snapshot.model.query
if query is not None:
sql_str = str(query)
if sql_str:
facets["sql"] = sql_job.SQLJobFacet(query=sql_str)
except Exception:
pass

# SourceCodeLocationJobFacet: include model file path
try:
if snapshot.is_model and snapshot.model:
model_path = getattr(snapshot.model, "_path", None)
if model_path is not None:
path_str = str(model_path)
if path_str:
facets["sourceCodeLocation"] = (
source_code_location_job.SourceCodeLocationJobFacet(
type="file",
url=f"file://{path_str}",
)
)
except Exception:
pass

return facets

def _build_processing_engine_facet(self) -> t.Dict[str, t.Any]:
"""Build run facets for processing engine info."""
from openlineage.client.facet_v2 import processing_engine_run

facets: t.Dict[str, t.Any] = {}

try:
from sqlmesh import __version__ as sqlmesh_version
except ImportError:
sqlmesh_version = "unknown"

try:
from sqlmesh_openlineage import __version__ as adapter_version
except ImportError:
adapter_version = "unknown"

facets["processing_engine"] = processing_engine_run.ProcessingEngineRunFacet(
version=sqlmesh_version,
name="SQLMesh",
openlineageAdapterVersion=adapter_version,
)

return facets

def emit_snapshot_start(
self,
snapshot: "Snapshot",
run_id: str,
snapshots: t.Optional[t.Dict[str, "Snapshot"]] = None,
) -> None:
"""Emit a START event for snapshot evaluation."""
from openlineage.client.event_v2 import RunEvent, RunState, Run, Job
Expand All @@ -50,19 +123,25 @@ def emit_snapshot_start(
snapshot_to_input_datasets,
)

inputs = snapshot_to_input_datasets(snapshot, self.namespace)
inputs = snapshot_to_input_datasets(snapshot, self.namespace, snapshots=snapshots)
output = snapshot_to_output_dataset(snapshot, self.namespace)

job_facets = self._build_job_facets(snapshot)
run_facets = self._build_processing_engine_facet()

event = RunEvent(
eventType=RunState.START,
eventTime=datetime.now(timezone.utc).isoformat(),
run=Run(runId=run_id),
job=Job(namespace=self.namespace, name=snapshot.name),
run=Run(runId=run_id, facets=run_facets),
job=Job(namespace=self.namespace, name=snapshot.name, facets=job_facets),
inputs=inputs,
outputs=[output] if output else [],
producer="sqlmesh-openlineage",
producer=PRODUCER,
)
self.client.emit(event)
try:
self.client.emit(event)
except Exception:
logger.warning("Failed to emit %s event for %s", event.eventType, snapshot.name, exc_info=True)

def emit_snapshot_complete(
self,
Expand All @@ -71,33 +150,46 @@ def emit_snapshot_complete(
interval: t.Optional["Interval"] = None,
duration_ms: t.Optional[int] = None,
execution_stats: t.Optional["QueryExecutionStats"] = None,
snapshots: t.Optional[t.Dict[str, "Snapshot"]] = None,
) -> None:
"""Emit a COMPLETE event for snapshot evaluation."""
from openlineage.client.event_v2 import RunEvent, RunState, Run, Job

from sqlmesh_openlineage.datasets import snapshot_to_output_dataset
from sqlmesh_openlineage.datasets import (
snapshot_to_output_dataset,
snapshot_to_input_datasets,
)
from sqlmesh_openlineage.facets import build_run_facets, build_output_facets

run_facets = build_run_facets(
duration_ms=duration_ms,
execution_stats=execution_stats,
)
run_facets.update(self._build_processing_engine_facet())

output = snapshot_to_output_dataset(
snapshot,
self.namespace,
facets=build_output_facets(execution_stats),
)

inputs = snapshot_to_input_datasets(snapshot, self.namespace, snapshots=snapshots)

job_facets = self._build_job_facets(snapshot)

event = RunEvent(
eventType=RunState.COMPLETE,
eventTime=datetime.now(timezone.utc).isoformat(),
run=Run(runId=run_id, facets=run_facets),
job=Job(namespace=self.namespace, name=snapshot.name),
job=Job(namespace=self.namespace, name=snapshot.name, facets=job_facets),
inputs=inputs,
outputs=[output] if output else [],
producer="sqlmesh-openlineage",
producer=PRODUCER,
)
self.client.emit(event)
try:
self.client.emit(event)
except Exception:
logger.warning("Failed to emit %s event for %s", event.eventType, snapshot.name, exc_info=True)

def emit_snapshot_fail(
self,
Expand All @@ -124,6 +216,9 @@ def emit_snapshot_fail(
},
),
job=Job(namespace=self.namespace, name=snapshot.name),
producer="sqlmesh-openlineage",
producer=PRODUCER,
)
self.client.emit(event)
try:
self.client.emit(event)
except Exception:
logger.warning("Failed to emit %s event for %s", event.eventType, snapshot.name, exc_info=True)
4 changes: 2 additions & 2 deletions src/sqlmesh_openlineage/facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def build_run_facets(
# Add custom SQLMesh facet with execution info
if duration_ms is not None or execution_stats is not None:
sqlmesh_facet = {
"_producer": "sqlmesh-openlineage",
"_schemaURL": "https://openlineage.io/spec/facets/1-0-0/SQLMeshExecutionFacet.json",
"_producer": "https://github.com/sidequery/sqlmesh-openlineage",
"_schemaURL": "https://github.com/sidequery/sqlmesh-openlineage#SQLMeshExecutionFacet",
}

if duration_ms is not None:
Expand Down
Loading