Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-ast
- id: check-toml
Expand All @@ -18,15 +18,15 @@ repos:
# - id: no-commit-to-branch

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.11.2'
rev: 'v1.19.1'
hooks:
- id: mypy
additional_dependencies:
- fastapi
- pytest

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.6.9'
rev: 'v0.14.10'
hooks:
- id: ruff
args: [--fix]
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,15 @@ line-length = 100
# The D (doc) and DTZ (datetime zone) lint classes are currently heavily violated - fix later
select = ["ALL"]
ignore = [
"ANN101", # style choice - no annotation for self
"ANN102", # style choice - no annotation for cls
"CPY", # we do not require copyright in every file
"D", # todo: docstring linting
"D203",
"D204",
"D213",
"DTZ", # To add
# Linter does not detect when types are used for Pydantic
"TCH001",
"TCH003",
"TC001",
"TC003",
]

[tool.ruff.lint.per-file-ignores]
Expand Down
4 changes: 2 additions & 2 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def _apply_defaults_to_siblings(configuration: TomlTable) -> TomlTable:

@functools.cache
def _load_configuration(file: Path) -> TomlTable:
return typing.cast(TomlTable, tomllib.loads(file.read_text()))
return tomllib.loads(file.read_text())


def load_routing_configuration(file: Path = CONFIG_PATH) -> TomlTable:
return typing.cast(TomlTable, _load_configuration(file)["routing"])
return typing.cast("TomlTable", _load_configuration(file)["routing"])


@functools.cache
Expand Down
2 changes: 1 addition & 1 deletion src/database/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
connection.execute(
text(
"""
Expand Down
4 changes: 2 additions & 2 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand Down Expand Up @@ -36,7 +36,7 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:

def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand Down
6 changes: 3 additions & 3 deletions src/database/studies.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
"""
if study.type_ == StudyType.TASK:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand All @@ -56,7 +56,7 @@ def get_study_data(study: Row, expdb: Connection) -> Sequence[Row]:
).all(),
)
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand Down Expand Up @@ -103,7 +103,7 @@ def create(study: CreateStudy, user: User, expdb: Connection) -> int:
},
)
(study_id,) = expdb.execute(text("""SELECT LAST_INSERT_ID();""")).one()
return cast(int, study_id)
return cast("int", study_id)


def attach_task(task_id: int, study_id: int, user: User, expdb: Connection) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get(id_: int, expdb: Connection) -> Row | None:

def get_task_types(expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: Using string literals as the first argument to typing.cast does not undermine static type checking.

Type checkers interpret a string passed to cast as a forward reference, so cast("Sequence[Row]", ...) conveys exactly the same type information as cast(Sequence[Row], ...). The quoted form is what Ruff's TC006 (runtime-cast-value) rule enforces: it avoids evaluating the type expression at runtime and allows the referenced types to be imported only under `if TYPE_CHECKING:`. If the project prefers unquoted casts, disable TC006 in the Ruff configuration rather than reverting each call site.

"Sequence[Row]",
expdb.execute(
text(
"""
Expand All @@ -46,7 +46,7 @@ def get_task_type(task_type_id: int, expdb: Connection) -> Row | None:

def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand All @@ -62,7 +62,7 @@ def get_input_for_task_type(task_type_id: int, expdb: Connection) -> Sequence[Ro

def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand All @@ -78,7 +78,7 @@ def get_input_for_task(id_: int, expdb: Connection) -> Sequence[Row]:

def get_task_type_inout_with_template(task_type: int, expdb: Connection) -> Sequence[Row]:
return cast(
Sequence[Row],
"Sequence[Row]",
expdb.execute(
text(
"""
Expand Down
39 changes: 22 additions & 17 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import re
from http import HTTPStatus
from typing import Annotated, Any
from typing import Annotated, cast

import xmltodict
from fastapi import APIRouter, Depends, HTTPException
Expand All @@ -15,22 +15,24 @@

router = APIRouter(prefix="/tasks", tags=["tasks"])

type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None

def convert_template_xml_to_json(xml_template: str) -> Any: # noqa: ANN401

def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
json_str = json.dumps(json_template)
# To account for the differences between PHP and Python conversions:
for py, php in [("@name", "name"), ("#text", "value"), ("@type", "type")]:
json_str = json_str.replace(py, php)
return json.loads(json_str)
return cast("dict[str, JSON]", json.loads(json_str))


def fill_template(
template: str,
task: RowMapping,
task_inputs: dict[str, str],
task_inputs: dict[str, str | int],
connection: Connection,
) -> Any: # noqa: ANN401
) -> dict[str, JSON]:
"""Fill in the XML template as used for task descriptions and return the result,
converted to JSON.

Expand Down Expand Up @@ -79,22 +81,25 @@ def fill_template(
}
"""
json_template = convert_template_xml_to_json(template)
return _fill_json_template(
json_template,
task,
task_inputs,
fetched_data={},
connection=connection,
return cast(
"dict[str, JSON]",
_fill_json_template(
json_template,
task,
task_inputs,
fetched_data={},
connection=connection,
),
)


def _fill_json_template(
template: dict[str, Any],
template: JSON,
task: RowMapping,
task_inputs: dict[str, str],
fetched_data: dict[str, Any],
task_inputs: dict[str, str | int],
fetched_data: dict[str, str],
connection: Connection,
) -> dict[str, Any] | list[dict[str, Any]] | str:
) -> JSON:
if isinstance(template, dict):
return {
k: _fill_json_template(v, task, task_inputs, fetched_data, connection)
Expand All @@ -115,7 +120,7 @@ def _fill_json_template(
if match.string == template:
# How do we know the default value? probably ttype_io table?
return task_inputs.get(field, [])
template = template.replace(match.group(), task_inputs[field])
template = template.replace(match.group(), str(task_inputs[field]))
if match := re.search(r"\[LOOKUP:(.*)]", template):
(field,) = match.groups()
if field not in fetched_data:
Expand Down Expand Up @@ -176,7 +181,7 @@ def get_task(
tags = database.tasks.get_tags(task_id, expdb)
name = f"Task {task_id} ({task_type.name})"
dataset_id = task_inputs.get("source_data")
if dataset_id and (dataset := database.datasets.get(dataset_id, expdb)):
if isinstance(dataset_id, int) and (dataset := database.datasets.get(dataset_id, expdb)):
name = f"Task {task_id}: {dataset.name} ({task_type.name})"

return Task(
Expand Down
4 changes: 2 additions & 2 deletions src/routers/openml/tasktype.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def get_task_type(
task_type = _normalize_task_type(task_type_record)
# Some names are quoted, or have typos in their comma-separation (e.g. 'A ,B')
task_type["creator"] = [
creator.strip(' "') for creator in cast(str, task_type["creator"]).split(",")
creator.strip(' "') for creator in cast("str", task_type["creator"]).split(",")
]
if contributors := task_type.pop("contributors"):
task_type["contributor"] = [
creator.strip(' "') for creator in cast(str, contributors).split(",")
creator.strip(' "') for creator in cast("str", contributors).split(",")
]
task_type["creation_date"] = task_type.pop("creationDate")
task_type_inputs = get_input_for_task_type(task_type_id, expdb)
Expand Down
9 changes: 3 additions & 6 deletions src/schemas/datasets/mldcat_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from abc import ABC
from enum import StrEnum
from typing import Generic, Literal, TypeVar
from typing import Literal

from pydantic import BaseModel, Field, HttpUrl, field_serializer, model_serializer

Expand Down Expand Up @@ -41,10 +41,7 @@ class JsonLDObject(BaseModel, ABC):
}


T = TypeVar("T", bound=JsonLDObject)


class JsonLDObjectReference(BaseModel, Generic[T]):
class JsonLDObjectReference[T: JsonLDObject](BaseModel):
id_: str = Field(serialization_alias="@id")

model_config = {"populate_by_name": True, "extra": "forbid"}
Expand Down Expand Up @@ -275,7 +272,7 @@ class DataService(JsonLDObject):


class JsonLDGraph(BaseModel):
context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context") # type: ignore[arg-type]
context: str | dict[str, HttpUrl] = Field(default_factory=dict, serialization_alias="@context")
graph: list[Distribution | DataService | Dataset | Quality | Feature | Agent | MD5Checksum] = (
Field(default_factory=list, serialization_alias="@graph")
)
Expand Down
8 changes: 4 additions & 4 deletions tests/routers/openml/datasets_list_datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl


@pytest.mark.slow
@hypothesis.settings(
@hypothesis.settings( # type: ignore[untyped-decorator] # 108
max_examples=5000,
suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture],
deadline=None,
) # type: ignore[misc] # https://github.com/openml/server-api/issues/108
@given(
)
@given( # type: ignore[untyped-decorator] # 108
number_missing_values=st.sampled_from([None, "2", "2..10000"]),
number_features=st.sampled_from([None, "5", "2..100"]),
number_classes=st.sampled_from([None, "5", "2..100"]),
Expand All @@ -243,7 +243,7 @@ def test_list_data_quality(quality: str, range_: str, count: int, py_api: TestCl
tag=st.sampled_from([None, "study_14", "study_not_in_db"]),
# We don't test ADMIN user, as we fixed a bug which treated them as a regular user
api_key=st.sampled_from([None, ApiKey.SOME_USER, ApiKey.OWNER_USER]),
) # type: ignore[misc] # https://github.com/openml/server-api/issues/108
)
def test_list_data_identical(
py_api: TestClient,
php_api: httpx.Client,
Expand Down
8 changes: 4 additions & 4 deletions tests/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class ApiKey(StrEnum):
ADMIN: str = "AD000000000000000000000000000000"
SOME_USER: str = "00000000000000000000000000000000"
OWNER_USER: str = "DA1A0000000000000000000000000000"
INVALID: str = "11111111111111111111111111111111"
ADMIN = "AD000000000000000000000000000000"
SOME_USER = "00000000000000000000000000000000"
OWNER_USER = "DA1A0000000000000000000000000000"
INVALID = "11111111111111111111111111111111"