diff --git a/pipeline/_internal.py b/pipeline/_internal.py
index f643312..58c3c17 100644
--- a/pipeline/_internal.py
+++ b/pipeline/_internal.py
@@ -1,25 +1,29 @@
from dataclasses import dataclass
import polars as pl
+import dataframely as dy
+from .schema.report import (
+ AverageCarVolumeSchema,
+ PopularModelsSchema,
+ SafestModelsSchema,
+)
@dataclass
class Report:
- popularity: pl.DataFrame | pl.LazyFrame
- safety: pl.DataFrame | pl.LazyFrame
- volume: pl.DataFrame | pl.LazyFrame
+ popularity: dy.LazyFrame[PopularModelsSchema]
+ safety: dy.LazyFrame[SafestModelsSchema]
+ volume: dy.LazyFrame[AverageCarVolumeSchema]
def to_string(self) -> str:
"""
Create a pretty-printable representation of this report.
"""
- # Enforce laziness and collection here to ensure we are
- #
df_popularity, df_volume, df_safety = pl.collect_all(
[
- self.popularity.lazy(),
- self.volume.lazy().sort("age_of_car"),
- self.safety.lazy(),
+ self.popularity,
+ self.volume,
+ self.safety,
]
)
header = [
diff --git a/pipeline/data.py b/pipeline/data.py
index 9114d6c..34aa36b 100644
--- a/pipeline/data.py
+++ b/pipeline/data.py
@@ -1,15 +1,17 @@
from dataclasses import dataclass
+import dataframely as dy
-import polars as pl
+from .schema.preprocessed import PrepPoliciesSchema, PrepModelsSchema
+from .schema.raw import RawModelsSchema, RawPoliciesSchema
@dataclass
-class RawData[T: (pl.DataFrame | pl.LazyFrame)]:
- models: T
- policies: T
+class RawData:
+ models: dy.LazyFrame[RawModelsSchema]
+ policies: dy.LazyFrame[RawPoliciesSchema]
@dataclass
-class PreprocessedData[T: (pl.DataFrame | pl.LazyFrame)]:
- models: T
- policies: T
+class PreprocessedData:
+ models: dy.LazyFrame[PrepModelsSchema]
+ policies: dy.LazyFrame[PrepPoliciesSchema]
diff --git a/pipeline/preprocess.py b/pipeline/preprocess.py
index 042a15d..b3814ab 100644
--- a/pipeline/preprocess.py
+++ b/pipeline/preprocess.py
@@ -1,6 +1,8 @@
import polars as pl
-
+import dataframely as dy
from .data import PreprocessedData, RawData
+from .schema.preprocessed import PrepModelsSchema, PrepPoliciesSchema
+from .schema.raw import RawModelsSchema, RawPoliciesSchema
def preprocess(raw: RawData) -> PreprocessedData:
@@ -10,62 +12,48 @@ def preprocess(raw: RawData) -> PreprocessedData:
)
-def preprocess_policies[T: (pl.DataFrame, pl.LazyFrame)](policies: T) -> T:
+def preprocess_policies(
+ policies: dy.LazyFrame[RawPoliciesSchema],
+) -> dy.LazyFrame[PrepPoliciesSchema]:
"""Transform the raw policies for optimal representation."""
return policies.with_columns(
- # Categorical columns
- pl.col("model").cast(pl.Categorical),
- pl.col("area_cluster").cast(pl.Categorical),
- # Float columns often do not need full 64-bit precision
- # This depends on the domain we are working on
- pl.col("policy_tenure").cast(pl.Float32),
- pl.col("age_of_car").cast(pl.Float32),
- pl.col("age_of_policyholder").cast(pl.Float32),
- pl.col("population_density").cast(pl.Float32),
# Normalize ID
- pl.col("policy_id").str.strip_prefix("policy").cast(pl.UInt64),
- )
+ pl.col("policy_id").str.strip_prefix("policy"),
+ ).pipe(PrepPoliciesSchema.validate, cast=True, eager=False)
-def preprocess_models[T: (pl.DataFrame, pl.LazyFrame)](models: T) -> T:
+def preprocess_models(
+ models: dy.LazyFrame[RawModelsSchema],
+) -> dy.LazyFrame[PrepModelsSchema]:
"""Transform the raw models for optimal representation."""
+ # Drop exact duplicate rows, which we found while investigating primary key failures
+ df = models.unique()
+
# 1. Convert semantically boolean columns from pl.String to pl.Boolean
- df = models.with_columns(pl.col("^is_.*$") == "Yes")
+ df = df.with_columns(pl.col("^is_.*$") == "Yes")
# 2. Split max torque and power into components
torque_parts = pl.col("max_torque").str.split("@")
df = df.with_columns(
- max_torque_nm=torque_parts.list[0].str.strip_suffix("Nm").cast(pl.Float32),
- max_torque_rpm=torque_parts.list[1].str.strip_suffix("rpm").cast(pl.UInt16),
+ max_torque_nm=torque_parts.list[0].str.strip_suffix("Nm"),
+ max_torque_rpm=torque_parts.list[1].str.strip_suffix("rpm"),
)
power_parts = pl.col("max_power").str.split("@")
df = df.with_columns(
- max_power_bhp=power_parts.list[0].str.strip_suffix("bhp").cast(pl.Float16),
- max_power_rpm=power_parts.list[1].str.strip_suffix("rpm").cast(pl.UInt16),
+ max_power_bhp=power_parts.list[0].str.strip_suffix("bhp"),
+ max_power_rpm=power_parts.list[1].str.strip_suffix("rpm"),
)
- # Step 3: Use efficient data types
+ # 3. Ensure that length / width / height are in millimeters, not centimeters
+ def _ensure_mm(col: pl.Expr):
+ return pl.when(col < 1_000).then(col * 10).otherwise(col)
+
df = df.with_columns(
- # Some of the categorical columns are easily enumerated
- pl.col("steering_type").cast(pl.Enum(["Electric", "Manual", "Power"])),
- pl.col("fuel_type").cast(pl.Enum(["CNG", "Diesel", "Petrol"])),
- pl.col("rear_brakes_type").cast(pl.Enum(["Drum", "Disc"])),
- # For other categoricals, we may not be sure yet that we have seen all values
- # so we do not want to commit to an Enum, yet
- pl.col("engine_type").cast(pl.Categorical),
- pl.col("model").cast(pl.Categorical),
- pl.col("segment").cast(pl.Categorical),
- # Value-based dtypes
- pl.col("width").cast(pl.UInt16),
- pl.col("height").cast(pl.UInt16),
- pl.col("length").cast(pl.UInt16),
- pl.col("displacement").cast(pl.UInt16),
- pl.col("cylinder").cast(pl.UInt8),
- pl.col("gross_weight").cast(pl.UInt16),
- pl.col("gear_box").cast(pl.UInt8),
- pl.col("airbags").cast(pl.UInt8),
+ _ensure_mm(pl.col("length")),
+ _ensure_mm(pl.col("width")),
+ _ensure_mm(pl.col("height")),
)
- return df
+ return df.pipe(PrepModelsSchema.validate, cast=True, eager=False)
diff --git a/pipeline/report.py b/pipeline/report.py
index 1161270..d36173b 100644
--- a/pipeline/report.py
+++ b/pipeline/report.py
@@ -1,7 +1,14 @@
import polars as pl
+import dataframely as dy
from ._internal import Report
from .data import PreprocessedData
+from .schema.preprocessed import PrepModelsSchema, PrepPoliciesSchema
+from .schema.report import (
+ AverageCarVolumeSchema,
+ PopularModelsSchema,
+ SafestModelsSchema,
+)
def build_report(prep: PreprocessedData) -> Report:
@@ -12,9 +19,9 @@ def build_report(prep: PreprocessedData) -> Report:
)
-def find_three_most_popular_make_and_models[T: (pl.DataFrame, pl.LazyFrame)](
- models: T, policies: T
-) -> T:
+def find_three_most_popular_make_and_models(
+ models: dy.LazyFrame[PrepModelsSchema], policies: dy.LazyFrame[PrepPoliciesSchema]
+) -> dy.LazyFrame[PopularModelsSchema]:
"""Among all policies, compute the three make/model combinations that appears most often.
Returns:
@@ -26,10 +33,13 @@ def find_three_most_popular_make_and_models[T: (pl.DataFrame, pl.LazyFrame)](
.agg(count=pl.len())
.sort("count", descending=True)
.head(3)
+ .pipe(PopularModelsSchema.validate, cast=True, eager=False)
)
-def find_safest_models[T: (pl.DataFrame, pl.LazyFrame)](models: T) -> T:
+def find_safest_models(
+ models: dy.LazyFrame[PrepModelsSchema],
+) -> dy.LazyFrame[SafestModelsSchema]:
"""Among all models, find the safest ones as measured by the number of safety features.
Returns:
@@ -41,12 +51,13 @@ def find_safest_models[T: (pl.DataFrame, pl.LazyFrame)](models: T) -> T:
)
.sort("safety_score", descending=True)
.head(5)
+ .pipe(SafestModelsSchema.validate, cast=True, eager=False)
)
-def find_average_car_volume_by_age[T: (pl.DataFrame, pl.LazyFrame)](
- models: T, policies: T
-) -> T:
+def find_average_car_volume_by_age(
+ models: dy.LazyFrame[PrepModelsSchema], policies: dy.LazyFrame[PrepPoliciesSchema]
+) -> dy.LazyFrame[AverageCarVolumeSchema]:
"""Among all policies, find the mean physical car volume in 10-year blocks of car age.
This method should compute the volume of a car if interpreted as cuboid (i.e. box-shaped).
@@ -68,4 +79,6 @@ def find_average_car_volume_by_age[T: (pl.DataFrame, pl.LazyFrame)](
- 1
)
)
+ .sort("age_of_car")
+ .pipe(AverageCarVolumeSchema.validate, cast=True, eager=False)
)
diff --git a/pipeline/schema/preprocessed.py b/pipeline/schema/preprocessed.py
index e3a5058..9685b6e 100644
--- a/pipeline/schema/preprocessed.py
+++ b/pipeline/schema/preprocessed.py
@@ -1,4 +1,5 @@
import dataframely as dy
+import polars as pl
class PrepPoliciesSchema(dy.Schema):
@@ -51,3 +52,12 @@ class PrepModelsSchema(dy.Schema):
max_torque_rpm = dy.UInt16()
max_power_bhp = dy.Float32()
max_power_rpm = dy.UInt16()
+
+ @dy.rule()
+ def volume_is_realistic(cls) -> pl.Expr:
+ """Only allow reasonably sized cars"""
+ volume = cls.length.col.cast(pl.UInt64) * cls.width.col * cls.height.col
+
+ # Lengths are in millimeters and 1e9 mm^3 is 1 cubic meter
+ cubic_meter = 1e9
+ return volume.is_between(1 * cubic_meter, 20 * cubic_meter)
diff --git a/pipeline/schema/raw.py b/pipeline/schema/raw.py
index 56a3d9c..a83d089 100644
--- a/pipeline/schema/raw.py
+++ b/pipeline/schema/raw.py
@@ -16,7 +16,7 @@ class RawPoliciesSchema(dy.Schema):
class RawModelsSchema(dy.Schema):
"""Schema for the raw models table as provided by our data source"""
- model = dy.String(primary_key=True)
+ model = dy.String()
segment = dy.String()
fuel_type = dy.String()
airbags = dy.Int64()
diff --git a/pipeline/schema/report.py b/pipeline/schema/report.py
index 5adec38..8f54cfc 100644
--- a/pipeline/schema/report.py
+++ b/pipeline/schema/report.py
@@ -2,15 +2,18 @@
class PopularModelsSchema(dy.Schema):
- # TODO: Fill out
- ...
+ make = dy.String()
+ model = dy.String(primary_key=True)
+ count = dy.UInt32()
class SafestModelsSchema(dy.Schema):
- # TODO: Fill out
- ...
+ model = dy.Categorical(primary_key=True)
+ segment = dy.Categorical()
+ safety_score = dy.UInt16()
class AverageCarVolumeSchema(dy.Schema):
- # TODO: Fill out
- ...
+ age_of_car = dy.String(primary_key=True)
+ volume = dy.Float32()
+ change = dy.Float32(nullable=True)
diff --git a/tests/test_report.py b/tests/test_report.py
index d8b7623..87645ee 100644
--- a/tests/test_report.py
+++ b/tests/test_report.py
@@ -13,18 +13,22 @@ def test_find_average_car_volume_by_age():
{"model": "M2", "height": 2_000, "width": 2_000, "length": 2_000},
]
)
- # TODO: Use `.sample` to create a policies dataframe with two policies:
- # One with model "M1" and car age 4.5,
- # One with model "M2" and car age 14.5
- policies = PrepPoliciesSchema.sample(...)
+ policies = PrepPoliciesSchema.sample(
+ overrides=[
+ {"model": "M1", "age_of_car": 4.5},
+ {"model": "M2", "age_of_car": 14.5},
+ ]
+ )
volume_m1 = 1e-9 * 1_500 * 2_000 * 2_500
volume_m2 = 1e-9 * 2_000 * 2_000 * 2_000
change = 100 * (volume_m2 / volume_m1 - 1)
expected = AverageCarVolumeSchema.validate(
- # TODO: Add the second, missing row for the expected dataframe
pl.DataFrame(
- [{"age_of_car": "(-inf, 10]", "volume": volume_m1, "change": None}, ...]
+ [
+ {"age_of_car": "(-inf, 10]", "volume": volume_m1, "change": None},
+ {"age_of_car": "(10, 20]", "volume": volume_m2, "change": change},
+ ]
),
cast=True,
).lazy()
diff --git a/tutorial.ipynb b/tutorial.ipynb
index 5c3fe93..4bf2dc0 100644
--- a/tutorial.ipynb
+++ b/tutorial.ipynb
@@ -421,23 +421,35 @@
"outputs": [],
"source": [
"from pipeline.data import RawData\n",
- "from pipeline.preprocess import preprocess"
+ "from pipeline.preprocess import preprocess\n",
+ "from pipeline.schema.raw import RawModelsSchema, RawPoliciesSchema"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 27,
"id": "d29cb3ab",
"metadata": {},
"outputs": [],
"source": [
- "raw = RawData(models, policies)\n",
+ "raw = RawData(\n",
+ " models=RawModelsSchema.validate(models, cast=True).lazy(),\n",
+ " policies=RawPoliciesSchema.validate(policies, cast=True).lazy()\n",
+ ")\n",
"preprocessed = preprocess(raw)"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
+ "id": "05b24a9d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
"id": "6f3f01b1",
"metadata": {},
"outputs": [],
@@ -451,7 +463,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 29,
"id": "bf070282",
"metadata": {},
"outputs": [
@@ -466,15 +478,15 @@
"source": [
"print(\n",
" \"Policies: \"\n",
- " f\"{raw.policies.estimated_size('mb'):.2f} MB\",\n",
+ " f\"{raw.policies.collect().estimated_size('mb'):.2f} MB\",\n",
" \"->\",\n",
- " f\"{preprocessed.policies.estimated_size('mb'):.2f} MB\",\n",
+ " f\"{preprocessed.policies.collect().estimated_size('mb'):.2f} MB\",\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 30,
"id": "59e7a957",
"metadata": {},
"outputs": [
@@ -482,22 +494,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Models: 0.26 MB -> 0.14 MB\n"
+ "Models: 0.25 MB -> 0.08 MB\n"
]
}
],
"source": [
"print(\n",
" \"Models: \"\n",
- " f\"{raw.models.estimated_size('mb'):.2f} MB\",\n",
+ " f\"{raw.models.collect().estimated_size('mb'):.2f} MB\",\n",
" \"->\",\n",
- " f\"{preprocessed.models.estimated_size('mb'):.2f} MB\",\n",
+ " f\"{preprocessed.models.collect().estimated_size('mb'):.2f} MB\",\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 47,
"id": "f642d158",
"metadata": {},
"outputs": [
@@ -526,7 +538,7 @@
"└─────────┴──────────┘"
]
},
- "execution_count": 21,
+ "execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
@@ -536,12 +548,12 @@
"raw.models.select(\n",
" logical = pl.col(\"fuel_type\"),\n",
" physical = pl.col(\"fuel_type\").to_physical()\n",
- ").head(3)"
+ ").head(3).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 49,
"id": "cc991ed6",
"metadata": {},
"outputs": [
@@ -555,10 +567,10 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (3, 2)
| logical | physical |
|---|
| enum | u8 |
| "Diesel" | 1 |
| "CNG" | 0 |
| "Petrol" | 2 |
"
+ "shape: (10, 2)| logical | physical |
|---|
| enum | u8 |
| "Diesel" | 1 |
| "CNG" | 0 |
| "Diesel" | 1 |
| "Diesel" | 1 |
| "Petrol" | 2 |
| "Petrol" | 2 |
| "Diesel" | 1 |
| "CNG" | 0 |
| "Diesel" | 1 |
| "CNG" | 0 |
"
],
"text/plain": [
- "shape: (3, 2)\n",
+ "shape: (10, 2)\n",
"┌─────────┬──────────┐\n",
"│ logical ┆ physical │\n",
"│ --- ┆ --- │\n",
@@ -566,11 +578,18 @@
"╞═════════╪══════════╡\n",
"│ Diesel ┆ 1 │\n",
"│ CNG ┆ 0 │\n",
+ "│ Diesel ┆ 1 │\n",
+ "│ Diesel ┆ 1 │\n",
+ "│ Petrol ┆ 2 │\n",
"│ Petrol ┆ 2 │\n",
+ "│ Diesel ┆ 1 │\n",
+ "│ CNG ┆ 0 │\n",
+ "│ Diesel ┆ 1 │\n",
+ "│ CNG ┆ 0 │\n",
"└─────────┴──────────┘"
]
},
- "execution_count": 22,
+ "execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
@@ -582,7 +601,7 @@
"preprocessed.models.select(\n",
" logical = pl.col(\"fuel_type\"),\n",
" physical = pl.col(\"fuel_type\").to_physical()\n",
- ").head(3)"
+ ").head(10).collect()"
]
},
{
@@ -595,7 +614,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 33,
"id": "db661295",
"metadata": {},
"outputs": [],
@@ -605,7 +624,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 34,
"id": "e4f6939d",
"metadata": {},
"outputs": [],
@@ -619,7 +638,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 35,
"id": "15183652",
"metadata": {},
"outputs": [
@@ -638,7 +657,7 @@
"┌──────┬───────┬────────┐\n",
"│ make ┆ model ┆ count │\n",
"│ --- ┆ --- ┆ --- │\n",
- "│ u64 ┆ cat ┆ u32 │\n",
+ "│ str ┆ str ┆ u32 │\n",
"╞══════╪═══════╪════════╡\n",
"│ 10 ┆ M868 ┆ 263656 │\n",
"│ 6 ┆ M652 ┆ 131735 │\n",
@@ -650,13 +669,13 @@
"┌───────┬─────────┬──────────────┐\n",
"│ model ┆ segment ┆ safety_score │\n",
"│ --- ┆ --- ┆ --- │\n",
- "│ cat ┆ cat ┆ u32 │\n",
+ "│ str ┆ cat ┆ u16 │\n",
"╞═══════╪═════════╪══════════════╡\n",
- "│ M282 ┆ A ┆ 15 │\n",
"│ M626 ┆ C1 ┆ 15 │\n",
+ "│ M282 ┆ A ┆ 15 │\n",
+ "│ M301 ┆ Utility ┆ 15 │\n",
"│ M919 ┆ A ┆ 15 │\n",
"│ M998 ┆ Utility ┆ 15 │\n",
- "│ M301 ┆ Utility ┆ 15 │\n",
"└───────┴─────────┴──────────────┘\n",
"\n",
"\n",
@@ -666,13 +685,13 @@
"┌────────────┬──────────┬────────────┐\n",
"│ age_of_car ┆ volume ┆ change │\n",
"│ --- ┆ --- ┆ --- │\n",
- "│ cat ┆ f64 ┆ f64 │\n",
+ "│ str ┆ f32 ┆ f32 │\n",
"╞════════════╪══════════╪════════════╡\n",
- "│ (-inf, 10] ┆ 8.854143 ┆ null │\n",
- "│ (10, 20] ┆ 1.411613 ┆ -84.057029 │\n",
- "│ (20, 30] ┆ 0.007873 ┆ -99.442247 │\n",
- "│ (30, 40] ┆ 0.00739 ┆ -6.137952 │\n",
- "│ (40, 50] ┆ 0.007501 ┆ 1.503317 │\n",
+ "│ (-inf, 10] ┆ 9.980824 ┆ null │\n",
+ "│ (10, 20] ┆ 8.299253 ┆ -16.848015 │\n",
+ "│ (20, 30] ┆ 7.874096 ┆ -5.122838 │\n",
+ "│ (30, 40] ┆ 7.391098 ┆ -6.134008 │\n",
+ "│ (40, 50] ┆ 7.501146 ┆ 1.488916 │\n",
"└────────────┴──────────┴────────────┘\n",
"\n",
"============================================================\n"
@@ -693,7 +712,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 36,
"id": "5be4e3ba",
"metadata": {},
"outputs": [],
@@ -708,7 +727,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 37,
"id": "8f301f7b",
"metadata": {},
"outputs": [],
@@ -718,7 +737,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 38,
"id": "fd2ff44c",
"metadata": {},
"outputs": [],
@@ -731,7 +750,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 39,
"id": "8206dc4b",
"metadata": {},
"outputs": [],
@@ -742,148 +761,328 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 40,
"id": "755fea10",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
- "