diff --git a/src/core/errors.py b/src/core/errors.py index dea1f50..69d7e0c 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -7,6 +7,7 @@ from http import HTTPStatus from fastapi import Request +from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse # ============================================================================= @@ -89,6 +90,27 @@ def problem_detail_exception_handler( ) +def validation_exception_handler( + request: Request, # noqa: ARG001 + exc: RequestValidationError, +) -> JSONResponse: + """FastAPI exception handler for RequestValidationError. + + Returns a RFC 9457 compliant response for input validation failures. + """ + return JSONResponse( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + content={ + "type": "https://openml.org/problems/validation-error", + "title": "Validation Error", + "status": HTTPStatus.UNPROCESSABLE_ENTITY, + "detail": "Input validation failed.", + "errors": exc.errors(), + }, + media_type="application/problem+json", + ) + + # ============================================================================= # Dataset Errors # ============================================================================= diff --git a/src/main.py b/src/main.py index 3be2c5c..46cd79c 100644 --- a/src/main.py +++ b/src/main.py @@ -7,10 +7,15 @@ import uvicorn from fastapi import FastAPI +from fastapi.exceptions import RequestValidationError from loguru import logger from config import load_configuration -from core.errors import ProblemDetailError, problem_detail_exception_handler +from core.errors import ( + ProblemDetailError, + problem_detail_exception_handler, + validation_exception_handler, +) from core.logging import ( add_request_context_to_log, log_request_duration, @@ -87,6 +92,7 @@ def create_api(configuration_file: Path | None = None) -> FastAPI: app.middleware("http")(add_request_context_to_log) app.add_exception_handler(ProblemDetailError, problem_detail_exception_handler) # type: ignore[arg-type] + app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] logger.info("Adding routers to app") app.include_router(datasets_router) diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py index cddd0d8..d11fc96 100644 --- a/tests/routers/openml/dataset_tag_test.py +++ b/tests/routers/openml/dataset_tag_test.py @@ -42,7 +42,7 @@ async def test_dataset_tag_invalid_tag_is_rejected( ) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY - assert response.json()["detail"][0]["loc"] == ["body", "tag"] + assert response.json()["errors"][0]["loc"] == ["body", "tag"] # ── Direct call tests: tag_dataset ── diff --git a/tests/routers/openml/task_list_test.py b/tests/routers/openml/task_list_test.py index 4667967..67d8539 100644 --- a/tests/routers/openml/task_list_test.py +++ b/tests/routers/openml/task_list_test.py @@ -135,9 +135,9 @@ async def test_list_tasks_invalid_pagination_type( ) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY # Verify that the error points to the correct field - detail = response.json()["detail"][0] - assert detail["loc"][-2:] == ["pagination", expected_field] - assert detail["type"] in {"type_error.integer", "int_parsing", "int_type"} + error = response.json()["errors"][0] + assert error["loc"][-2:] == ["pagination", expected_field] + assert error["type"] in {"type_error.integer", "int_parsing", "int_type"} @pytest.mark.parametrize( @@ -150,8 +150,8 @@ async def test_list_tasks_invalid_range(value: str, py_api: httpx.AsyncClient) - response = await py_api.post("/tasks/list", json={"number_instances": value}) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY # Verify the error is for the correct field - detail = response.json()["detail"][0] - assert detail["loc"][-1] == "number_instances" + error = response.json()["errors"][0] + assert error["loc"][-1] == "number_instances" @pytest.mark.parametrize( @@ -171,9 +171,9 @@ async def test_list_tasks_invalid_inputs( response = await py_api.post("/tasks/list", json=payload) assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY # Ensure we are failing for the field we provided - detail = response.json()["detail"][0] + error = response.json()["errors"][0] expected_field = next(iter(payload)) - assert detail["loc"][-1] == expected_field + assert error["loc"][-1] == expected_field async def test_list_tasks_no_results_api_mapping(py_api: httpx.AsyncClient) -> None: