1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"mysqlclient",
"python_dotenv",
"xmltodict",
"python-multipart",
]

[project.optional-dependencies]
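
The new `python-multipart` dependency is what FastAPI uses to parse `multipart/form-data` request bodies, which the `File`/`UploadFile` parameters introduced below rely on. A minimal sketch (not part of this PR) of the kind of route that needs it:

```python
# Minimal sketch, assuming only fastapi and python-multipart are installed.
# Declaring a File/UploadFile parameter makes FastAPI parse multipart/form-data,
# which requires the python-multipart package.
from fastapi import FastAPI, File, UploadFile

app = FastAPI()


@app.post("/upload")
def upload(file: UploadFile = File(...)) -> dict[str, str | None]:
    # Without python-multipart installed, FastAPI raises a RuntimeError
    # when a route with form/file parameters is declared.
    return {"filename": file.filename}
```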
35 changes: 31 additions & 4 deletions src/routers/openml/datasets.py
@@ -2,9 +2,10 @@
from datetime import datetime
from enum import StrEnum
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Any, Literal, NamedTuple

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, File, HTTPException, UploadFile
from sqlalchemy import Connection, text
from sqlalchemy.engine import Row

@@ -21,7 +22,13 @@
from database.users import User, UserGroup
from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
from schemas.datasets.openml import (
DatasetMetadata,
DatasetMetadataView,
DatasetStatus,
Feature,
FeatureType,
)

router = APIRouter(prefix="/datasets", tags=["datasets"])

@@ -370,6 +377,26 @@ def update_dataset_status(
return {"dataset_id": dataset_id, "status": status}


@router.post(path="")
def upload_data(
file: Annotated[UploadFile, File(description="A pyarrow parquet file containing the data.")],
metadata: DatasetMetadata, # noqa: ARG001
user: Annotated[User | None, Depends(fetch_user)] = None,
) -> None:
if user is None:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="You need to authenticate to upload a dataset.",
)
# Spooled file-- where is it stored?
if file.filename is None or Path(file.filename).suffix != ".pq":
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="The uploaded file needs to be a parquet file (.pq).",
)
# use async interface


@router.get(
path="/{dataset_id}",
description="Get meta-data for dataset with ID `dataset_id`.",
@@ -379,7 +406,7 @@ def get_dataset(
user: Annotated[User | None, Depends(fetch_user)] = None,
user_db: Annotated[Connection, Depends(userdb_connection)] = None,
expdb_db: Annotated[Connection, Depends(expdb_connection)] = None,
) -> DatasetMetadata:
) -> DatasetMetadataView:
dataset = _get_dataset_raise_otherwise(dataset_id, user, expdb_db)
if not (
dataset_file := database.datasets.get_file(file_id=dataset.file_id, connection=user_db)
@@ -411,7 +438,7 @@ def get_dataset(
original_data_url = _csv_as_list(dataset.original_data_url, unquote_items=True)
default_target_attribute = _csv_as_list(dataset.default_target_attribute, unquote_items=True)

return DatasetMetadata(
return DatasetMetadataView(
id=dataset.did,
visibility=dataset.visibility,
status=status_,
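
The new `upload_data` handler stops at two TODO comments (where the spooled upload is stored, and switching to the async file interface). A possible continuation, sketched under the assumption that pyarrow is used to sanity-check the payload; the `_validate_parquet` helper is hypothetical and not part of this PR:

```python
# Hypothetical helper (not part of this PR): read the spooled upload through
# Starlette's async interface and confirm it parses as a parquet table.
from http import HTTPStatus
from io import BytesIO

import pyarrow.parquet as pq
from fastapi import HTTPException, UploadFile


async def _validate_parquet(file: UploadFile) -> None:
    contents = await file.read()  # UploadFile wraps a SpooledTemporaryFile; read() is async
    try:
        pq.read_table(BytesIO(contents))
    except Exception as exc:  # pyarrow raises ArrowInvalid/OSError on unreadable input
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail="The uploaded file is not a readable parquet file.",
        ) from exc
```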
4 changes: 2 additions & 2 deletions src/schemas/datasets/convertor.py
@@ -6,10 +6,10 @@
SpdxChecksum,
XSDDateTime,
)
from schemas.datasets.openml import DatasetMetadata
from schemas.datasets.openml import DatasetMetadataView


def openml_dataset_to_dcat(metadata: DatasetMetadata) -> DcatApWrapper:
def openml_dataset_to_dcat(metadata: DatasetMetadataView) -> DcatApWrapper:
checksum = SpdxChecksum(
id_=metadata.md5_checksum,
algorithm="md5",
58 changes: 32 additions & 26 deletions src/schemas/datasets/openml.py
@@ -59,20 +59,20 @@ class EstimationProcedure(BaseModel):


class DatasetMetadata(BaseModel):
id_: int = Field(json_schema_extra={"example": 1}, alias="id")
visibility: Visibility = Field(json_schema_extra={"example": Visibility.PUBLIC})
status: DatasetStatus = Field(json_schema_extra={"example": DatasetStatus.ACTIVE})

name: str = Field(json_schema_extra={"example": "Anneal"})
licence: str = Field(json_schema_extra={"example": "CC0"})
version: int = Field(json_schema_extra={"example": 2})
version_label: str = Field(
version_label: str | None = Field(
json_schema_extra={
"example": 2,
"description": "Not sure how this relates to `version`.",
},
max_length=128,
)
language: str | None = Field(
json_schema_extra={"example": "English"},
max_length=128,
)
language: str = Field(json_schema_extra={"example": "English"})

creators: list[str] = Field(
json_schema_extra={"example": ["David Sterling", "Wray Buntine"]},
@@ -82,14 +82,38 @@
json_schema_extra={"example": ["David Sterling", "Wray Buntine"]},
alias="contributor",
)
citation: str = Field(
citation: str | None = Field(
json_schema_extra={"example": "https://archive.ics.uci.edu/ml/citation_policy.html"},
)
paper_url: HttpUrl | None = Field(
json_schema_extra={
"example": "http://digital.library.adelaide.edu.au/dspace/handle/2440/15227",
},
)
collection_date: str | None = Field(json_schema_extra={"example": "1990"})

description: str = Field(
json_schema_extra={"example": "The original Annealing dataset from UCI."},
)
default_target_attribute: list[str] = Field(json_schema_extra={"example": "class"})
ignore_attribute: list[str] = Field(json_schema_extra={"example": "sensitive_feature"})
row_id_attribute: list[str] = Field(json_schema_extra={"example": "ssn"})

format_: DatasetFileFormat = Field(
json_schema_extra={"example": DatasetFileFormat.ARFF},
alias="format",
)
original_data_url: list[HttpUrl] | None = Field(
json_schema_extra={"example": "https://www.openml.org/d/2"},
)


class DatasetMetadataView(DatasetMetadata):
id_: int = Field(json_schema_extra={"example": 1}, alias="id")
visibility: Visibility = Field(json_schema_extra={"example": Visibility.PUBLIC})
status: DatasetStatus = Field(json_schema_extra={"example": DatasetStatus.ACTIVE})
description_version: int = Field(json_schema_extra={"example": 2})
tags: list[str] = Field(json_schema_extra={"example": ["study_1", "uci"]}, alias="tag")
upload_date: datetime = Field(
json_schema_extra={"example": str(datetime(2014, 4, 6, 23, 12, 20))},
)
@@ -101,17 +125,7 @@ class DatasetMetadata(BaseModel):
alias="error",
)
processing_warning: str | None = Field(alias="warning")
collection_date: str | None = Field(json_schema_extra={"example": "1990"})

description: str = Field(
json_schema_extra={"example": "The original Annealing dataset from UCI."},
)
description_version: int = Field(json_schema_extra={"example": 2})
tags: list[str] = Field(json_schema_extra={"example": ["study_1", "uci"]}, alias="tag")
default_target_attribute: list[str] = Field(json_schema_extra={"example": "class"})
ignore_attribute: list[str] = Field(json_schema_extra={"example": "sensitive_feature"})
row_id_attribute: list[str] = Field(json_schema_extra={"example": "ssn"})

file_id: int = Field(json_schema_extra={"example": 1})
url: HttpUrl = Field(
json_schema_extra={
"example": "https://www.openml.org/data/download/1/dataset_1_anneal.arff",
@@ -124,14 +138,6 @@ class DatasetMetadata(BaseModel):
"description": "URL of the parquet dataset data file.",
},
)
file_id: int = Field(json_schema_extra={"example": 1})
format_: DatasetFileFormat = Field(
json_schema_extra={"example": DatasetFileFormat.ARFF},
alias="format",
)
original_data_url: list[HttpUrl] | None = Field(
json_schema_extra={"example": "https://www.openml.org/d/2"},
)
md5_checksum: str = Field(json_schema_extra={"example": "d01f6ccd68c88b749b20bbe897de3713"})


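
With this split, `DatasetMetadata` now only carries the fields a client supplies at upload time, while `DatasetMetadataView` extends it with the server-populated fields returned by `GET /datasets/{dataset_id}`. A quick way to see the difference (a sketch, assuming pydantic v2 as used above):

```python
# Sketch: compare the client-facing upload model with the server-side view model.
from schemas.datasets.openml import DatasetMetadata, DatasetMetadataView

upload_fields = set(DatasetMetadata.model_fields)
view_only_fields = set(DatasetMetadataView.model_fields) - upload_fields

print(sorted(upload_fields))     # e.g. name, licence, version, description, format_, ...
print(sorted(view_only_fields))  # e.g. id_, status, upload_date, file_id, md5_checksum, ...
```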
35 changes: 31 additions & 4 deletions tests/routers/openml/datasets_test.py
@@ -1,13 +1,14 @@
from http import HTTPStatus
from io import BytesIO

import pytest
from fastapi import HTTPException
from fastapi import HTTPException, UploadFile
from sqlalchemy import Connection
from starlette.testclient import TestClient

from database.users import User
from routers.openml.datasets import get_dataset
from schemas.datasets.openml import DatasetMetadata, DatasetStatus
from routers.openml.datasets import get_dataset, upload_data
from schemas.datasets.openml import DatasetMetadataView, DatasetStatus
from tests.users import ADMIN_USER, NO_USER, OWNER_USER, SOME_USER, ApiKey


@@ -100,7 +101,7 @@ def test_private_dataset_access(user: User, expdb_test: Connection, user_test: C
user_db=user_test,
expdb_db=expdb_test,
)
assert isinstance(dataset, DatasetMetadata)
assert isinstance(dataset, DatasetMetadataView)


def test_dataset_features(py_api: TestClient) -> None:
@@ -269,3 +270,29 @@ def test_dataset_status_unauthorized(
json={"dataset_id": dataset_id, "status": status},
)
assert response.status_code == HTTPStatus.FORBIDDEN


def test_dataset_upload_needs_authentication() -> None:
with pytest.raises(HTTPException) as e:
upload_data(user=None, file=None, metadata=None) # type: ignore[arg-type]

assert e.value.status_code == HTTPStatus.UNAUTHORIZED
assert e.value.detail == "You need to authenticate to upload a dataset."


@pytest.mark.parametrize(
"file_name", ["parquet.csv", pytest.param("parquet.pq", marks=pytest.mark.xfail)]
)
def test_dataset_upload_error_if_not_parquet(file_name: str) -> None:
# we do not expect the server to actually check the parquet content
file = UploadFile(filename=file_name, file=BytesIO(b""))

with pytest.raises(HTTPException) as e:
upload_data(file=file, user=SOME_USER, metadata=None) # type: ignore[arg-type]

assert e.value.status_code == HTTPStatus.BAD_REQUEST
assert e.value.detail == "The uploaded file needs to be a parquet file (.pq)."


def test_dataset_upload() -> None:
pass
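
The `test_dataset_upload` stub is left empty; filling it in will need an in-memory parquet payload. A sketch of producing one with pyarrow (an assumption, since the PR does not pin a parquet library for the tests):

```python
# Sketch (not part of this PR): build a small in-memory parquet payload for upload tests.
from io import BytesIO

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"feature": [1.0, 2.0, 3.0], "class": ["a", "b", "a"]})
buffer = BytesIO()
pq.write_table(table, buffer)
buffer.seek(0)  # rewind so it can be passed as UploadFile(filename="data.pq", file=buffer)
```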
4 changes: 2 additions & 2 deletions tests/routers/openml/migration/datasets_migration_test.py
@@ -1,11 +1,11 @@
import json
from http import HTTPStatus

import constants
import httpx
import pytest
from starlette.testclient import TestClient

import tests.constants
from core.conversions import nested_remove_single_element_list
from tests.users import ApiKey

@@ -127,7 +127,7 @@
php_api: TestClient,
api_key: str,
) -> None:
[private_dataset] = constants.PRIVATE_DATASET_ID
[private_dataset] = tests.constants.PRIVATE_DATASET_ID

Check warning (Codecov / codecov/patch) on tests/routers/openml/migration/datasets_migration_test.py#L130: Added line #L130 was not covered by tests.
new_response = py_api.get(f"/datasets/{private_dataset}?api_key={api_key}")
old_response = php_api.get(f"/data/{private_dataset}?api_key={api_key}")
assert old_response.status_code == HTTPStatus.OK