diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09b9d47..2f3bac8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-ast - id: check-toml @@ -18,7 +18,7 @@ repos: # - id: no-commit-to-branch - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.11.2' + rev: 'v1.19.1' hooks: - id: mypy additional_dependencies: @@ -26,7 +26,7 @@ repos: - pytest - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.6.9' + rev: 'v0.14.10' hooks: - id: ruff args: [--fix] diff --git a/pyproject.toml b/pyproject.toml index a7bff61..25f459b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,8 +54,6 @@ line-length = 100 # The D (doc) and DTZ (datetime zone) lint classes current heavily violated - fix later select = ["ALL"] ignore = [ - "ANN101", # style choice - no annotation for self - "ANN102", # style choice - no annotation for cls "CPY", # we do not require copyright in every file "D", # todo: docstring linting "D203", @@ -63,8 +61,8 @@ ignore = [ "D213", "DTZ", # To add # Linter does not detect when types are used for Pydantic - "TCH001", - "TCH003", + "TC001", + "TC003", ] [tool.ruff.lint.per-file-ignores] diff --git a/src/config.py b/src/config.py index 8a19f04..0712271 100644 --- a/src/config.py +++ b/src/config.py @@ -22,11 +22,11 @@ def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable: @functools.cache def _load_configuration(file: Path) -> TomlTable: - return typing.cast(TomlTable, tomllib.loads(file.read_text())) + return tomllib.loads(file.read_text()) def load_routing_configuration(file: Path = CONFIG_PATH) -> TomlTable: - return typing.cast(TomlTable, _load_configuration(file)["routing"]) + return typing.cast("TomlTable", _load_configuration(file)["routing"]) @functools.cache diff --git a/src/database/evaluations.py b/src/database/evaluations.py index f98b15e..799d411 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -9,7 +9,7 @@ def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", connection.execute( text( """ diff --git a/src/database/flows.py b/src/database/flows.py index 93fb219..3129e91 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -6,7 +6,7 @@ def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -36,7 +36,7 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]: def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ diff --git a/src/database/studies.py b/src/database/studies.py index 848c034..35c1b79 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -43,7 +43,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: """ if study.type_ == StudyType.TASK: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -56,7 +56,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]: ).all(), ) return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -103,7 +103,7 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int: }, ) (study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one() - return cast(int, study_id) + return cast("int", study_id) def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None: diff --git a/src/database/tasks.py b/src/database/tasks.py index 56a6718..97caef3 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -19,7 +19,7 @@ def get(id_: int, expdb: Connection) -> Row | None: def get_task_types(expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -46,7 +46,7 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None: def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -62,7 +62,7 @@ def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Ro def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ @@ -78,7 +78,7 @@ def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]: def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]: return cast( - Sequence[Row], + "Sequence[Row]", expdb.execute( text( """ diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 96d0198..8397f1d 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -1,7 +1,7 @@ import json import re from http import HTTPStatus -from typing import Annotated, Any +from typing import Annotated, cast import xmltodict from fastapi import APIRouter, Depends, HTTPException @@ -15,22 +15,24 @@ router = APIRouter(prefix="/tasks", tags=["tasks"]) +type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None -def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401 + +def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: json_template = xmltodict.parse(xml_template.replace("oml:", "")) json_str = json.dumps(json_template) # To account for the differences between PHP and Python conversions: for py, php in [("@name", "name"), ("#text", "value"), ("@type", "type")]: json_str = json_str.replace(py, php) - return json.loads(json_str) + return cast("dict[str, JSON]", json.loads(json_str)) def fill_template( template: str, task: RowMapping, - task_inputs: dict[str, str], + task_inputs: dict[str, str | int], connection: Connection, -) -> Any: # noqa: ANN401 +) -> dict[str, JSON]: """Fill in the XML template as used for task descriptions and return the result, converted to JSON. @@ -79,22 +81,25 @@ def fill_template( } """ json_template = convert_template_xml_to_json(template) - return _fill_json_template( - json_template, - task, - task_inputs, - fetched_data={}, - connection=connection, + return cast( + "dict[str, JSON]", + _fill_json_template( + json_template, + task, + task_inputs, + fetched_data={}, + connection=connection, + ), ) def _fill_json_template( - template: dict[str, Any], + template: JSON, task: RowMapping, - task_inputs: dict[str, str], - fetched_data: dict[str, Any], + task_inputs: dict[str, str | int], + fetched_data: dict[str, str], connection: Connection, -) -> dict[str, Any] | list[dict[str, Any]] | str: +) -> JSON: if isinstance(template, dict): return { k: _fill_json_template(v, task, task_inputs, fetched_data, connection) @@ -115,7 +120,7 @@ def _fill_json_template( if match.string == template: # How do we know the default value? probably ttype_io table? return task_inputs.get(field, []) - template = template.replace(match.group(), task_inputs[field]) + template = template.replace(match.group(), str(task_inputs[field])) if match := re.search(r"\[LOOKUP:(.*)]", template): (field,) = match.groups() if field not in fetched_data: @@ -176,7 +181,7 @@ def get_task( tags = database.tasks.get_tags(task_id, expdb) name = f"Task {task_id} ({task_type.name})" dataset_id = task_inputs.get("source_data") - if dataset_id and (dataset := database.datasets.get(dataset_id, expdb)): + if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)): name = f"Task {task_id}: {dataset.name} ({task_type.name})" return Task( diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index dcc9b1c..5213f17 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -53,11 +53,11 @@ def get_task_type( task_type = _normalize_task_type(task_type_record) # Some names are quoted, or have typos in their comma-separation (e.g. 'A ,B') task_type["creator"] = [ - creator.strip(' "') for creator in cast(str, task_type["creator"]).split(",") + creator.strip(' "') for creator in cast("str", task_type["creator"]).split(",") ] if contributors := task_type.pop("contributors"): task_type["contributor"] = [ - creator.strip(' "') for creator in cast(str, contributors).split(",") + creator.strip(' "') for creator in cast("str", contributors).split(",") ] task_type["creation_date"] = task_type.pop("creationDate") task_type_inputs = get_input_for_task_type(task_type_id, expdb) diff --git a/src/schemas/datasets/mldcat_ap.py b/src/schemas/datasets/mldcat_ap.py index d7e277f..ffbe644 100644 --- a/src/schemas/datasets/mldcat_ap.py +++ b/src/schemas/datasets/mldcat_ap.py @@ -10,7 +10,7 @@ from abc import ABC from enum import StrEnum -from typing import Generic, Literal, TypeVar +from typing import Literal from pydantic import BaseModel, Field, HttpUrl, field_serializer, model_serializer @@ -41,10 +41,7 @@ class JsonLDObject(BaseModel, ABC): } -T = TypeVar("T", bound=JsonLDObject) - - -class JsonLDObjectReference(BaseModel, Generic[T]): +class JsonLDObjectReference[T: JsonLDObject](BaseModel): id_: str = Field(serialization_alias="@id") model_config = {"populate_by_name": True, "extra": "forbid"} @@ -275,7 +272,7 @@ class DataService(JsonLDObject): class JsonLDGraph(BaseModel): - context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context") # type: ignore[arg-type] + context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context") graph: list[Distribution | DataService | Dataset | Quality | Feature | Agent | MD5Checksum] = ( Field(default_factory=list, serialization_alias="@graph") ) diff --git a/tests/routers/openml/datasets_list_datasets_test.py b/tests/routers/openml/datasets_list_datasets_test.py index bf0fbc8..6e0492e 100644 --- a/tests/routers/openml/datasets_list_datasets_test.py +++ b/tests/routers/openml/datasets_list_datasets_test.py @@ -224,12 +224,12 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl @pytest.mark.slow -@hypothesis.settings( +@hypothesis.settings( # type: ignore[untyped-decorator] # 108 max_examples=5000, suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None, -) # type: ignore[misc] # https://github.com/openml/server-api/issues/108 -@given( +) +@given( # type: ignore[untyped-decorator] # 108 number_missing_values=st.sampled_from([None, "2", "2..10000"]), number_features=st.sampled_from([None, "5", "2..100"]), number_classes=st.sampled_from([None, "5", "2..100"]), @@ -243,7 +243,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl tag=st.sampled_from([None, "study_14", "study_not_in_db"]), # We don't test ADMIN user, as we fixed a bug which treated them as a regular user api_key=st.sampled_from([None, ApiKey.SOME_USER, ApiKey.OWNER_USER]), -) # type: ignore[misc] # https://github.com/openml/server-api/issues/108 +) def test_list_data_identical( py_api: TestClient, php_api: httpx.Client, diff --git a/tests/users.py b/tests/users.py index ab92593..23bc325 100644 --- a/tests/users.py +++ b/tests/users.py @@ -9,7 +9,7 @@ class ApiKey(StrEnum): - ADMIN: str = "AD000000000000000000000000000000" - SOME_USER: str = "00000000000000000000000000000000" - OWNER_USER: str = "DA1A0000000000000000000000000000" - INVALID: str = "11111111111111111111111111111111" + ADMIN = "AD000000000000000000000000000000" + SOME_USER = "00000000000000000000000000000000" + OWNER_USER = "DA1A0000000000000000000000000000" + INVALID = "11111111111111111111111111111111"