Commit 2922cfb

fix: prevent Null type for datetime columns when all values are null
When all jobs are RUNNING (no end_time), Polars infers the column type as Null instead of Datetime with null values, causing a SchemaError during concatenation.

Solution:
- Add a ProcessedJob.get_polars_schema() classmethod that derives explicit Polars types from the Pydantic model fields
- Pass the resulting schema to pl.DataFrame() to enforce column types
- Ensure datetime columns are always Datetime("us", "UTC") even when all values are null

This fixes the "type Datetime is incompatible with expected type Null" error that occurs when collecting data with only RUNNING jobs.
1 parent c13251e commit 2922cfb
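
The failure mode described above can be reproduced outside the collector. The sketch below is a minimal, hypothetical example (the column names are illustrative, not the real ProcessedJob fields): Polars infers a Null dtype when every value in a column is None, and passing an explicit schema keeps the dtype stable so concatenation succeeds.

import datetime as dt

import polars as pl

UTC_DATETIME = pl.Datetime("us", "UTC")

# A frame collected earlier, where at least one job has finished.
finished = pl.DataFrame(
    {"job_id": ["1"], "end_time": [dt.datetime(2024, 1, 1, tzinfo=dt.timezone.utc)]},
    schema={"job_id": pl.Utf8, "end_time": UTC_DATETIME},
)

# A batch where every job is still RUNNING: end_time is None everywhere,
# so without a schema Polars infers the column as Null rather than Datetime.
running_rows = [{"job_id": "2", "end_time": None}]
inferred = pl.DataFrame(running_rows)
print(inferred.schema["end_time"])  # Null

# Concatenating `finished` with `inferred` is what triggers the error quoted
# in the commit message ("type Datetime is incompatible with expected type Null").

# Passing an explicit schema pins the dtype even when all values are null.
explicit = pl.DataFrame(running_rows, schema={"job_id": pl.Utf8, "end_time": UTC_DATETIME})
print(explicit.schema["end_time"])  # Datetime(time_unit='us', time_zone='UTC')
combined = pl.concat([finished, explicit])  # concatenates cleanly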

File tree: 1 file changed, +36 -4 lines


slurm_usage.py

Lines changed: 36 additions & 4 deletions
@@ -24,6 +24,7 @@
 import os
 import re
 import subprocess
+import typing
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime, timedelta, timezone
@@ -672,6 +673,33 @@ def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary for DataFrame creation."""
         return self.model_dump()
 
+    @classmethod
+    def get_polars_schema(cls) -> dict[str, pl.DataType]:
+        """Get Polars schema derived from Pydantic model fields."""
+        mapping: dict[str, pl.DataType] = {
+            str: pl.Utf8,
+            int: pl.Int64,
+            float: pl.Float64,
+            bool: pl.Boolean,
+            # All datetime fields should be UTC
+            datetime: pl.Datetime("us", "UTC"),
+        }
+
+        schema = {}
+        for field_name, field_info in cls.model_fields.items():
+            annotation = field_info.annotation
+
+            # Handle Optional types (Union[T, None] or T | None)
+            origin = typing.get_origin(annotation)
+            if origin is typing.Union or origin is type(None | int):
+                args = typing.get_args(annotation)
+                # Get the non-None type
+                annotation = next((arg for arg in args if arg is not type(None)), str)
+            # Map Python types to Polars types
+            schema[field_name] = mapping.get(annotation, pl.Utf8)
+
+        return schema
+
 
 class DateCompletionTracker(BaseModel):
     """Tracks which dates have been fully processed and don't need re-collection."""
@@ -1334,10 +1362,14 @@ def _processed_jobs_to_dataframe(
         DataFrame with job data
 
     """
-    return pl.DataFrame(
-        [j.to_dict() for j in processed_jobs],
-        infer_schema_length=None,
-    )
+    # Create DataFrame with explicit schema to prevent Null type inference
+    schema = ProcessedJob.get_polars_schema()
+
+    if not processed_jobs:
+        return pl.DataFrame(schema=schema)
+
+    data_dicts = [j.to_dict() for j in processed_jobs]
+    return pl.DataFrame(data_dicts, schema=schema)
 
 
 def _save_processed_jobs_to_parquet(
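
For reference, the derivation pattern added by get_polars_schema() can be exercised in isolation. The sketch below is a standalone, hypothetical reconstruction (TinyJob and its fields are made up for illustration; the real model is ProcessedJob in slurm_usage.py) showing how an Optional annotation such as datetime | None still resolves to Datetime("us", "UTC").

import types
import typing
from datetime import datetime

import polars as pl
from pydantic import BaseModel


class TinyJob(BaseModel):
    """Hypothetical stand-in for the real ProcessedJob model."""

    job_id: str
    elapsed_seconds: int
    end_time: datetime | None = None  # stays None while the job is RUNNING


PYTHON_TO_POLARS: dict[type, pl.DataType] = {
    str: pl.Utf8,
    int: pl.Int64,
    float: pl.Float64,
    bool: pl.Boolean,
    datetime: pl.Datetime("us", "UTC"),
}


def polars_schema(model: type[BaseModel]) -> dict[str, pl.DataType]:
    """Map Pydantic field annotations to Polars dtypes, unwrapping Optionals."""
    schema: dict[str, pl.DataType] = {}
    for name, field in model.model_fields.items():
        annotation = field.annotation
        origin = typing.get_origin(annotation)
        # Unwrap Optional[T] / T | None down to the underlying T.
        if origin is typing.Union or origin is types.UnionType:
            annotation = next(
                (arg for arg in typing.get_args(annotation) if arg is not type(None)),
                str,
            )
        schema[name] = PYTHON_TO_POLARS.get(annotation, pl.Utf8)
    return schema


print(polars_schema(TinyJob))
# end_time maps to Datetime(time_unit='us', time_zone='UTC') even though the
# annotation is Optional, so an all-RUNNING batch keeps a Datetime column.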
