From 9faa75fd408a06c09f14d60e7cffa73ad3b1d381 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 19 Dec 2025 11:05:22 +0100
Subject: [PATCH 1/2] Update pre-commit hooks
---
.pre-commit-config.yaml | 6 ++--
pyproject.toml | 6 ++--
src/config.py | 4 +--
src/database/evaluations.py | 2 +-
src/database/flows.py | 4 +--
src/database/studies.py | 6 ++--
src/database/tasks.py | 8 ++---
src/routers/openml/tasks.py | 35 +++++++++++--------
src/routers/openml/tasktype.py | 4 +--
src/schemas/datasets/mldcat_ap.py | 9 ++---
.../openml/datasets_list_datasets_test.py | 8 ++---
tests/users.py | 8 ++---
12 files changed, 50 insertions(+), 50 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 09b9d47..2f3bac8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v5.0.0
+ rev: v6.0.0
hooks:
- id: check-ast
- id: check-toml
@@ -18,7 +18,7 @@ repos:
# - id: no-commit-to-branch
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: 'v1.11.2'
+ rev: 'v1.19.1'
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +26,7 @@ repos:
- pytest
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: 'v0.6.9'
+ rev: 'v0.14.10'
hooks:
- id: ruff
args: [--fix]
diff --git a/pyproject.toml b/pyproject.toml
index a7bff61..25f459b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,8 +54,6 @@ line-length = 100
# The D (doc) and DTZ (datetime zone) lint classes are currently heavily violated - fix later
select = ["ALL"]
ignore = [
- "ANN101", # style choice - no annotation for self
- "ANN102", # style choice - no annotation for cls
"CPY", # we do not require copyright in every file
"D", # todo: docstring linting
"D203",
@@ -63,8 +61,8 @@ ignore = [
"D213",
"DTZ", # To add
# Linter does not detect when types are used for Pydantic
- "TCH001",
- "TCH003",
+ "TC001",
+ "TC003",
]
[tool.ruff.lint.per-file-ignores]
diff --git a/src/config.py b/src/config.py
index 8a19f04..0712271 100644
--- a/src/config.py
+++ b/src/config.py
@@ -22,11 +22,11 @@ def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable:
@functools.cache
def _load_configuration(file: Path) -> TomlTable:
- return typing.cast(TomlTable, tomllib.loads(file.read_text()))
+ return tomllib.loads(file.read_text())
def load_routing_configuration(file: Path = CONFIG_PATH) -> TomlTable:
- return typing.cast(TomlTable, _load_configuration(file)["routing"])
+ return typing.cast("TomlTable", _load_configuration(file)["routing"])
@functools.cache
diff --git a/src/database/evaluations.py b/src/database/evaluations.py
index f98b15e..799d411 100644
--- a/src/database/evaluations.py
+++ b/src/database/evaluations.py
@@ -9,7 +9,7 @@
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
connection.execute(
text(
"""
diff --git a/src/database/flows.py b/src/database/flows.py
index 93fb219..3129e91 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -6,7 +6,7 @@
def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -36,7 +36,7 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:
def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
diff --git a/src/database/studies.py b/src/database/studies.py
index 848c034..35c1b79 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -43,7 +43,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
"""
if study.type_ == StudyType.TASK:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -56,7 +56,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
).all(),
)
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -103,7 +103,7 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int:
},
)
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
- return cast(int, study_id)
+ return cast("int", study_id)
def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 56a6718..97caef3 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -19,7 +19,7 @@ def get(id_: int, expdb: Connection) -> Row | None:
def get_task_types(expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -46,7 +46,7 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:
def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -62,7 +62,7 @@ def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Ro
def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
@@ -78,7 +78,7 @@ def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
return cast(
- Sequence[Row],
+ "Sequence[Row]",
expdb.execute(
text(
"""
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 96d0198..069e611 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -1,7 +1,7 @@
import json
import re
from http import HTTPStatus
-from typing import Annotated, Any
+from typing import Annotated, cast
import xmltodict
from fastapi import APIRouter, Depends, HTTPException
@@ -15,14 +15,16 @@
router = APIRouter(prefix="/tasks", tags=["tasks"])
+type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
-def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401
+
+def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
json_str = json.dumps(json_template)
# To account for the differences between PHP and Python conversions:
for py, php in [("@name", "name"), ("#text", "value"), ("@type", "type")]:
json_str = json_str.replace(py, php)
- return json.loads(json_str)
+ return cast("dict[str, JSON]", json.loads(json_str))
def fill_template(
@@ -30,7 +32,7 @@ def fill_template(
task: RowMapping,
task_inputs: dict[str, str],
connection: Connection,
-) -> Any: # noqa: ANN401
+) -> dict[str, JSON]:
"""Fill in the XML template as used for task descriptions and return the result,
converted to JSON.
@@ -79,22 +81,25 @@ def fill_template(
}
"""
json_template = convert_template_xml_to_json(template)
- return _fill_json_template(
- json_template,
- task,
- task_inputs,
- fetched_data={},
- connection=connection,
+ return cast(
+ "dict[str, JSON]",
+ _fill_json_template(
+ json_template,
+ task,
+ task_inputs,
+ fetched_data={},
+ connection=connection,
+ ),
)
def _fill_json_template(
- template: dict[str, Any],
+ template: JSON,
task: RowMapping,
task_inputs: dict[str, str],
- fetched_data: dict[str, Any],
+ fetched_data: dict[str, str],
connection: Connection,
-) -> dict[str, Any] | list[dict[str, Any]] | str:
+) -> JSON:
if isinstance(template, dict):
return {
k: _fill_json_template(v, task, task_inputs, fetched_data, connection)
@@ -158,7 +163,7 @@ def get_task(
)
task_inputs = {
- row.input: int(row.value) if row.value.isdigit() else row.value
+ row.input: str(int(row.value)) if row.value.isdigit() else row.value
for row in database.tasks.get_input_for_task(task_id, expdb)
}
ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)
@@ -176,7 +181,7 @@ def get_task(
tags = database.tasks.get_tags(task_id, expdb)
name = f"Task {task_id} ({task_type.name})"
dataset_id = task_inputs.get("source_data")
- if dataset_id and (dataset := database.datasets.get(dataset_id, expdb)):
+ if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)):
name = f"Task {task_id}: {dataset.name} ({task_type.name})"
return Task(
diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py
index dcc9b1c..5213f17 100644
--- a/src/routers/openml/tasktype.py
+++ b/src/routers/openml/tasktype.py
@@ -53,11 +53,11 @@ def get_task_type(
task_type = _normalize_task_type(task_type_record)
# Some names are quoted, or have typos in their comma-separation (e.g. 'A ,B')
task_type["creator"] = [
- creator.strip(' "') for creator in cast(str, task_type["creator"]).split(",")
+ creator.strip(' "') for creator in cast("str", task_type["creator"]).split(",")
]
if contributors := task_type.pop("contributors"):
task_type["contributor"] = [
- creator.strip(' "') for creator in cast(str, contributors).split(",")
+ creator.strip(' "') for creator in cast("str", contributors).split(",")
]
task_type["creation_date"] = task_type.pop("creationDate")
task_type_inputs = get_input_for_task_type(task_type_id, expdb)
diff --git a/src/schemas/datasets/mldcat_ap.py b/src/schemas/datasets/mldcat_ap.py
index d7e277f..ffbe644 100644
--- a/src/schemas/datasets/mldcat_ap.py
+++ b/src/schemas/datasets/mldcat_ap.py
@@ -10,7 +10,7 @@
from abc import ABC
from enum import StrEnum
-from typing import Generic, Literal, TypeVar
+from typing import Literal
from pydantic import BaseModel, Field, HttpUrl, field_serializer, model_serializer
@@ -41,10 +41,7 @@ class JsonLDObject(BaseModel, ABC):
}
-T = TypeVar("T", bound=JsonLDObject)
-
-
-class JsonLDObjectReference(BaseModel, Generic[T]):
+class JsonLDObjectReference[T: JsonLDObject](BaseModel):
id_: str = Field(serialization_alias="@id")
model_config = {"populate_by_name": True, "extra": "forbid"}
@@ -275,7 +272,7 @@ class DataService(JsonLDObject):
class JsonLDGraph(BaseModel):
- context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context") # type: ignore[arg-type]
+ context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context")
graph: list[Distribution | DataService | Dataset | Quality | Feature | Agent | MD5Checksum] = (
Field(default_factory=list, serialization_alias="@graph")
)
diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py
index bf0fbc8..6e0492e 100644
--- a/tests/routers/openml/datasets_list_datasets_test.py
+++ b/tests/routers/openml/datasets_list_datasets_test.py
@@ -224,12 +224,12 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
@pytest.mark.slow
-@hypothesis.settings(
+@hypothesis.settings( # type: ignore[untyped-decorator] # 108
max_examples=5000,
suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
deadline=None,
-) # type: ignore[misc] # https://github.com/openml/server-api/issues/108
-@given(
+)
+@given( # type: ignore[untyped-decorator] # 108
number_missing_values=st.sampled_from([None, "2", "2..10000"]),
number_features=st.sampled_from([None, "5", "2..100"]),
number_classes=st.sampled_from([None, "5", "2..100"]),
@@ -243,7 +243,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
tag=st.sampled_from([None, "study_14", "study_not_in_db"]),
# We don't test ADMIN user, as we fixed a bug which treated them as a regular user
api_key=st.sampled_from([None, ApiKey.SOME_USER, ApiKey.OWNER_USER]),
-) # type: ignore[misc] # https://github.com/openml/server-api/issues/108
+)
def test_list_data_identical(
py_api: TestClient,
php_api: httpx.Client,
diff --git a/tests/users.py b/tests/users.py
index ab92593..23bc325 100644
--- a/tests/users.py
+++ b/tests/users.py
@@ -9,7 +9,7 @@
class ApiKey(StrEnum):
- ADMIN: str = "AD000000000000000000000000000000"
- SOME_USER: str = "00000000000000000000000000000000"
- OWNER_USER: str = "DA1A0000000000000000000000000000"
- INVALID: str = "11111111111111111111111111111111"
+ ADMIN = "AD000000000000000000000000000000"
+ SOME_USER = "00000000000000000000000000000000"
+ OWNER_USER = "DA1A0000000000000000000000000000"
+ INVALID = "11111111111111111111111111111111"
From 0efcb7193b9451c0333017ccab3abced15b0b025 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 19 Dec 2025 11:29:52 +0100
Subject: [PATCH 2/2] Preserve integer typing for input parameter
---
src/routers/openml/tasks.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 069e611..8397f1d 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -30,7 +30,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
def fill_template(
template: str,
task: RowMapping,
- task_inputs: dict[str, str],
+ task_inputs: dict[str, str | int],
connection: Connection,
) -> dict[str, JSON]:
"""Fill in the XML template as used for task descriptions and return the result,
@@ -96,7 +96,7 @@ def fill_template(
def _fill_json_template(
template: JSON,
task: RowMapping,
- task_inputs: dict[str, str],
+ task_inputs: dict[str, str | int],
fetched_data: dict[str, str],
connection: Connection,
) -> JSON:
@@ -120,7 +120,7 @@ def _fill_json_template(
if match.string == template:
# How do we know the default value? probably ttype_io table?
return task_inputs.get(field, [])
- template = template.replace(match.group(), task_inputs[field])
+ template = template.replace(match.group(), str(task_inputs[field]))
if match := re.search(r"\[LOOKUP:(.*)]", template):
(field,) = match.groups()
if field not in fetched_data:
@@ -163,7 +163,7 @@ def get_task(
)
task_inputs = {
- row.input: str(int(row.value)) if row.value.isdigit() else row.value
+ row.input: int(row.value) if row.value.isdigit() else row.value
for row in database.tasks.get_input_for_task(task_id, expdb)
}
ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)