From 80f89862ae344813f5d6afed462f500e46732486 Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Thu, 29 Jan 2026 10:00:01 +0100 Subject: [PATCH 1/2] Allow `Discriminator` for `discriminator` in `Field` --- sqlmodel/__init__.py | 4 ++++ sqlmodel/main.py | 10 ++++----- tests/test_pydantic/test_field.py | 37 +++++++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index d4210b06d8..585c928594 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -1,5 +1,9 @@ __version__ = "0.0.31" +# Re-export from Pydantic +from pydantic import Discriminator as Discriminator +from pydantic import Tag as Tag + # Re-export from SQLAlchemy from sqlalchemy.engine import create_engine as create_engine from sqlalchemy.engine import create_mock_engine as create_mock_engine diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 84478f24cf..cd6dcb1679 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -22,7 +22,7 @@ overload, ) -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel, Discriminator, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( Boolean, @@ -228,7 +228,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, - discriminator: Optional[str] = None, + discriminator: Union[str, Discriminator, None] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, @@ -271,7 +271,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, - discriminator: Optional[str] = None, + discriminator: Union[str, Discriminator, None] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: str, @@ -323,7 +323,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, - discriminator: Optional[str] = None, + discriminator: Union[str, Discriminator, None] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, schema_extra: Optional[dict[str, Any]] = None, @@ -356,7 +356,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, - discriminator: Optional[str] = None, + discriminator: Union[str, Discriminator, None] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, foreign_key: Any = Undefined, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 140b02fd9b..f52a1cc1b0 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -1,9 +1,9 @@ from decimal import Decimal -from typing import Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union import pytest from pydantic import ValidationError -from sqlmodel import Field, SQLModel +from sqlmodel import Discriminator, Field, SQLModel, Tag def test_decimal(): @@ -47,6 +47,39 @@ class Model(SQLModel): Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type] +def test_discriminator_callable(): + # Example adapted from + # [Pydantic docs](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator): + + class Pie(SQLModel): + pass + + class ApplePie(Pie): + fruit: Literal["apple"] = "apple" + + class PumpkinPie(Pie): + filling: Literal["pumpkin"] = "pumpkin" + + def get_discriminator_value(v: Any) -> str: + if isinstance(v, dict): + return v.get("fruit", v.get("filling")) + return getattr(v, "fruit", getattr(v, "filling", None)) + + class ThanksgivingDinner(SQLModel): + dessert: Union[ + Annotated[ApplePie, Tag("apple")], + Annotated[PumpkinPie, Tag("pumpkin")], + ] = Field( + discriminator=Discriminator(get_discriminator_value), + ) + + apple_pie = ThanksgivingDinner.model_validate({"dessert": {"fruit": "apple"}}) + assert isinstance(apple_pie.dessert, ApplePie) + + pumpkin_pie = ThanksgivingDinner.model_validate({"dessert": {"filling": "pumpkin"}}) + assert isinstance(pumpkin_pie.dessert, PumpkinPie) + + def test_repr(): class Model(SQLModel): id: Optional[int] = Field(primary_key=True) From 8bd4163579837d68b852650fedeab9b2c57c002b Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Fri, 20 Feb 2026 22:37:06 +0100 Subject: [PATCH 2/2] Update test to Python 3.10+ syntax --- tests/test_pydantic/test_field.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index f52a1cc1b0..1956f55bf2 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal import pytest from pydantic import ValidationError @@ -38,7 +38,7 @@ class Lizard(SQLModel): scales: bool class Model(SQLModel): - pet: Union[Cat, Dog, Lizard] = Field(..., discriminator="pet_type") + pet: Cat | Dog | Lizard = Field(..., discriminator="pet_type") n: int Model(pet={"pet_type": "dog", "barks": 3.14}, n=1) # type: ignore[arg-type] @@ -66,10 +66,9 @@ def get_discriminator_value(v: Any) -> str: return getattr(v, "fruit", getattr(v, "filling", None)) class ThanksgivingDinner(SQLModel): - dessert: Union[ - Annotated[ApplePie, Tag("apple")], - Annotated[PumpkinPie, Tag("pumpkin")], - ] = Field( + dessert: ( + Annotated[ApplePie, Tag("apple")] | Annotated[PumpkinPie, Tag("pumpkin")] + ) = Field( discriminator=Discriminator(get_discriminator_value), ) @@ -82,7 +81,7 @@ class ThanksgivingDinner(SQLModel): def test_repr(): class Model(SQLModel): - id: Optional[int] = Field(primary_key=True) + id: int | None = Field(primary_key=True) foo: str = Field(repr=False) instance = Model(id=123, foo="bar")