diff --git a/pyproject.toml b/pyproject.toml
index a7bff61..8960f1a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "mysqlclient",
     "python_dotenv",
     "xmltodict",
+    "python-multipart",
 ]
 
 [project.optional-dependencies]
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index dda2511..0b15475 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -2,9 +2,10 @@
 from datetime import datetime
 from enum import StrEnum
 from http import HTTPStatus
+from pathlib import Path
 from typing import Annotated, Any, Literal, NamedTuple
 
-from fastapi import APIRouter, Body, Depends, HTTPException
+from fastapi import APIRouter, Body, Depends, File, HTTPException, UploadFile
 from sqlalchemy import Connection, text
 from sqlalchemy.engine import Row
 
@@ -21,7 +22,13 @@
 from database.users import User, UserGroup
 from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
 from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
-from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
+from schemas.datasets.openml import (
+    DatasetMetadata,
+    DatasetMetadataView,
+    DatasetStatus,
+    Feature,
+    FeatureType,
+)
 
 router = APIRouter(prefix="/datasets", tags=["datasets"])
 
@@ -370,6 +377,26 @@ def update_dataset_status(
     return {"dataset_id": dataset_id, "status": status}
 
 
+@router.post(path="")
+def upload_data(
+    file: Annotated[UploadFile, File(description="A pyarrow parquet file containing the data.")],
+    metadata: DatasetMetadata,  # noqa: ARG001
+    user: Annotated[User | None, Depends(fetch_user)] = None,
+) -> None:
+    if user is None:
+        raise HTTPException(
+            status_code=HTTPStatus.UNAUTHORIZED,
+            detail="You need to authenticate to upload a dataset.",
+        )
+    # Spooled file-- where is it stored?
+    if file.filename is None or Path(file.filename).suffix != ".pq":
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail="The uploaded file needs to be a parquet file (.pq).",
+        )
+    # use async interface
+
+
 @router.get(
     path="/{dataset_id}",
     description="Get meta-data for dataset with ID `dataset_id`.",
@@ -379,7 +406,7 @@ def get_dataset(
     user: Annotated[User | None, Depends(fetch_user)] = None,
     user_db: Annotated[Connection, Depends(userdb_connection)] = None,
     expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
-) -> DatasetMetadata:
+) -> DatasetMetadataView:
     dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db)
     if not (
         dataset_file := database.datasets.get_file(file_id=dataset.file_id, connection=user_db)
@@ -411,7 +438,7 @@ def get_dataset(
     original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True)
     default_target_attribute = _csv_as_list(dataset.default_target_attribute, unquote_items=True)
 
-    return DatasetMetadata(
+    return DatasetMetadataView(
         id=dataset.did,
         visibility=dataset.visibility,
         status=status_,
diff --git a/src/schemas/datasets/convertor.py b/src/schemas/datasets/convertor.py
index 3c35cd2..7b4a305 100644
--- a/src/schemas/datasets/convertor.py
+++ b/src/schemas/datasets/convertor.py
@@ -6,10 +6,10 @@
     SpdxChecksum,
     XSDDateTime,
 )
-from schemas.datasets.openml import DatasetMetadata
+from schemas.datasets.openml import DatasetMetadataView
 
 
-def openml_dataset_to_dcat(metadata: DatasetMetadata) -> DcatApWrapper:
+def openml_dataset_to_dcat(metadata: DatasetMetadataView) -> DcatApWrapper:
     checksum = SpdxChecksum(
         id_=metadata.md5_checksum,
         algorithm="md5",
diff --git a/src/schemas/datasets/openml.py b/src/schemas/datasets/openml.py
index 8edb373..eee3e84 100644
--- a/src/schemas/datasets/openml.py
+++ b/src/schemas/datasets/openml.py
@@ -59,20 +59,20 @@ class EstimationProcedure(BaseModel):
 
 
 class DatasetMetadata(BaseModel):
-    id_: int = Field(json_schema_extra={"example": 1}, alias="id")
-    visibility: Visibility = Field(json_schema_extra={"example": Visibility.PUBLIC})
-    status: DatasetStatus = Field(json_schema_extra={"example": DatasetStatus.ACTIVE})
-    name: str = Field(json_schema_extra={"example": "Anneal"})
     licence: str = Field(json_schema_extra={"example": "CC0"})
     version: int = Field(json_schema_extra={"example": 2})
-    version_label: str = Field(
+    version_label: str | None = Field(
         json_schema_extra={
             "example": 2,
             "description": "Not sure how this relates to `version`.",
         },
+        max_length=128,
+    )
+    language: str | None = Field(
+        json_schema_extra={"example": "English"},
+        max_length=128,
     )
-    language: str = Field(json_schema_extra={"example": "English"})
 
     creators: list[str] = Field(
         json_schema_extra={"example": ["David Sterling", "Wray Buntine"]},
@@ -82,7 +82,7 @@
         json_schema_extra={"example": ["David Sterling", "Wray Buntine"]},
         alias="contributor",
     )
-    citation: str = Field(
+    citation: str | None = Field(
         json_schema_extra={"example": "https://archive.ics.uci.edu/ml/citation_policy.html"},
     )
     paper_url: HttpUrl | None = Field(
@@ -90,6 +90,30 @@
             "example": "http://digital.library.adelaide.edu.au/dspace/handle/2440/15227",
         },
     )
+    collection_date: str | None = Field(json_schema_extra={"example": "1990"})
+
+    description: str = Field(
+        json_schema_extra={"example": "The original Annealing dataset from UCI."},
+    )
+    default_target_attribute: list[str] = Field(json_schema_extra={"example": "class"})
+    ignore_attribute: list[str] = Field(json_schema_extra={"example": "sensitive_feature"})
+    row_id_attribute: list[str] = Field(json_schema_extra={"example": "ssn"})
+
+    format_: DatasetFileFormat = Field(
+        json_schema_extra={"example": DatasetFileFormat.ARFF},
+        alias="format",
+    )
+    original_data_url: list[HttpUrl] | None = Field(
+        json_schema_extra={"example": "https://www.openml.org/d/2"},
+    )
+
+
+class DatasetMetadataView(DatasetMetadata):
+    id_: int = Field(json_schema_extra={"example": 1}, alias="id")
+    visibility: Visibility = Field(json_schema_extra={"example": Visibility.PUBLIC})
+    status: DatasetStatus = Field(json_schema_extra={"example": DatasetStatus.ACTIVE})
+    description_version: int = Field(json_schema_extra={"example": 2})
+    tags: list[str] = Field(json_schema_extra={"example": ["study_1", "uci"]}, alias="tag")
     upload_date: datetime = Field(
         json_schema_extra={"example": str(datetime(2014, 4, 6, 23, 12, 20))},
     )
@@ -101,17 +125,7 @@
         alias="error",
     )
     processing_warning: str | None = Field(alias="warning")
-    collection_date: str | None = Field(json_schema_extra={"example": "1990"})
-
-    description: str = Field(
-        json_schema_extra={"example": "The original Annealing dataset from UCI."},
-    )
-    description_version: int = Field(json_schema_extra={"example": 2})
-    tags: list[str] = Field(json_schema_extra={"example": ["study_1", "uci"]}, alias="tag")
-    default_target_attribute: list[str] = Field(json_schema_extra={"example": "class"})
-    ignore_attribute: list[str] = Field(json_schema_extra={"example": "sensitive_feature"})
-    row_id_attribute: list[str] = Field(json_schema_extra={"example": "ssn"})
-
+    file_id: int = Field(json_schema_extra={"example": 1})
     url: HttpUrl = Field(
         json_schema_extra={
             "example": "https://www.openml.org/data/download/1/dataset_1_anneal.arff",
@@ -124,14 +138,6 @@
             "description": "URL of the parquet dataset data file.",
         },
     )
-    file_id: int = Field(json_schema_extra={"example": 1})
-    format_: DatasetFileFormat = Field(
-        json_schema_extra={"example": DatasetFileFormat.ARFF},
-        alias="format",
-    )
-    original_data_url: list[HttpUrl] | None = Field(
-        json_schema_extra={"example": "https://www.openml.org/d/2"},
-    )
     md5_checksum: str = Field(json_schema_extra={"example": "d01f6ccd68c88b749b20bbe897de3713"})
diff --git a/tests/routers/openml/datasets_test.py b/tests/routers/openml/datasets_test.py
index 7c1457f..2bb2ede 100644
--- a/tests/routers/openml/datasets_test.py
+++ b/tests/routers/openml/datasets_test.py
@@ -1,13 +1,14 @@
 from http import HTTPStatus
+from io import BytesIO
 
 import pytest
-from fastapi import HTTPException
+from fastapi import HTTPException, UploadFile
 from sqlalchemy import Connection
 from starlette.testclient import TestClient
 
 from database.users import User
-from routers.openml.datasets import get_dataset
-from schemas.datasets.openml import DatasetMetadata, DatasetStatus
+from routers.openml.datasets import get_dataset, upload_data
+from schemas.datasets.openml import DatasetMetadataView, DatasetStatus
 from tests.users import ADMIN_USER, NO_USER, OWNER_USER, SOME_USER, ApiKey
 
 
@@ -100,7 +101,7 @@ def test_private_dataset_access(user: User, expdb_test: Connection, user_test: C
         user_db=user_test,
         expdb_db=expdb_test,
     )
-    assert isinstance(dataset, DatasetMetadata)
+    assert isinstance(dataset, DatasetMetadataView)
 
 
 def test_dataset_features(py_api: TestClient) -> None:
@@ -269,3 +270,29 @@ def test_dataset_status_unauthorized(
         json={"dataset_id": dataset_id, "status": status},
     )
     assert response.status_code == HTTPStatus.FORBIDDEN
+
+
+def test_dataset_upload_needs_authentication() -> None:
+    with pytest.raises(HTTPException) as e:
+        upload_data(user=None, file=None, metadata=None)  # type: ignore[arg-type]
+
+    assert e.value.status_code == HTTPStatus.UNAUTHORIZED
+    assert e.value.detail == "You need to authenticate to upload a dataset."
+
+
+@pytest.mark.parametrize(
+    "file_name", ["parquet.csv", pytest.param("parquet.pq", marks=pytest.mark.xfail)]
+)
+def test_dataset_upload_error_if_not_parquet(file_name: str) -> None:
+    # we do not expect the server to actually check the parquet content
+    file = UploadFile(filename=file_name, file=BytesIO(b""))
+
+    with pytest.raises(HTTPException) as e:
+        upload_data(file=file, user=SOME_USER, metadata=None)  # type: ignore[arg-type]
+
+    assert e.value.status_code == HTTPStatus.BAD_REQUEST
+    assert e.value.detail == "The uploaded file needs to be a parquet file (.pq)."
+
+
+def test_dataset_upload() -> None:
+    pass
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index 3faca11..f09f86a 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -1,11 +1,11 @@
 import json
 from http import HTTPStatus
 
-import constants
 import httpx
 import pytest
 from starlette.testclient import TestClient
 
+import tests.constants
 from core.conversions import nested_remove_single_element_list
 from tests.users import ApiKey
 
@@ -127,7 +127,7 @@ def test_private_dataset_owner_access(
     php_api: TestClient,
     api_key: str,
 ) -> None:
-    [private_dataset] = constants.PRIVATE_DATASET_ID
+    [private_dataset] = tests.constants.PRIVATE_DATASET_ID
     new_response = py_api.get(f"/datasets/{private_dataset}?api_key={api_key}")
     old_response = php_api.get(f"/data/{private_dataset}?api_key={api_key}")
     assert old_response.status_code == HTTPStatus.OK
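
Not part of the patch above: upload_data validates the upload but leaves two open notes ("Spooled file-- where is it stored?" and "use async interface"). A minimal sketch of one way those could be resolved follows. It assumes an aiofiles dependency (not in the pyproject.toml hunk above) and a configurable storage root; store_upload and DATA_DIR are hypothetical names, not existing project code.

# Sketch only -- not part of this diff. Assumes aiofiles is available and that the
# ".pq" suffix has already been validated by the endpoint.
from pathlib import Path

import aiofiles
from fastapi import UploadFile

DATA_DIR = Path("/data/openml/datasets")  # hypothetical storage root


async def store_upload(file: UploadFile, dataset_id: int) -> Path:
    """Stream the spooled upload to disk in 1 MiB chunks via the async file interface."""
    target = DATA_DIR / str(dataset_id) / "dataset.pq"
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(target, "wb") as destination:
        while chunk := await file.read(1024 * 1024):  # UploadFile.read is a coroutine
            await destination.write(chunk)
    return target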
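Also not part of the patch: test_dataset_upload is left as a stub. Given only the behaviour shown in the diff (the endpoint validates and returns None without persisting anything), a direct-call happy-path test could look like the sketch below; it reuses the imports already added to tests/routers/openml/datasets_test.py, and the test name is hypothetical.

def test_dataset_upload_accepts_parquet_filename() -> None:
    # Mirrors the style of the tests above: call the endpoint function directly.
    # With an authenticated user and a ".pq" filename, no HTTPException is raised
    # and the endpoint currently returns None.
    file = UploadFile(filename="dataset.pq", file=BytesIO(b""))
    assert upload_data(file=file, user=SOME_USER, metadata=None) is None  # type: ignore[arg-type]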