From 6dc653f5a9b202e1abd5ad85f5c69c2e675254e1 Mon Sep 17 00:00:00 2001 From: NeurArk Date: Thu, 22 May 2025 16:46:40 +0200 Subject: [PATCH] Validate upload file size and test --- TODO.md | 5 +++++ tests/test_data_utils.py | 40 +++++++++++++++++++++++++++---------- utils/data.py | 43 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 11 deletions(-) diff --git a/TODO.md b/TODO.md index fc05dc2..df669ab 100644 --- a/TODO.md +++ b/TODO.md @@ -195,6 +195,11 @@ After completing a milestone, create a pull request with your changes for review - [x] Provide wrapper storing uploaded data in `st.session_state` - [x] Replace repetitive upload code across pages +## PR19: File Size Validation + +- [x] Enforce maximum file size during upload +- [x] Add tests covering oversized file uploads + ## Notes for Development - Create comprehensive commit messages that clearly describe changes diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index d3e5687..d491d16 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -6,39 +6,39 @@ def test_load_data_csv(tmp_path): - df_exp = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) - csv_file = tmp_path / 'test.csv' + df_exp = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + csv_file = tmp_path / "test.csv" df_exp.to_csv(csv_file, index=False) df = data.load_data(csv_file) pd.testing.assert_frame_equal(df, df_exp) def test_load_data_excel(tmp_path): - df_exp = pd.DataFrame({'a': [1, 2]}) - xls_file = tmp_path / 'test.xlsx' + df_exp = pd.DataFrame({"a": [1, 2]}) + xls_file = tmp_path / "test.xlsx" df_exp.to_excel(xls_file, index=False) df = data.load_data(xls_file) pd.testing.assert_frame_equal(df, df_exp) def test_load_data_invalid(tmp_path): - file = tmp_path / 'bad.txt' - file.write_text('x') + file = tmp_path / "bad.txt" + file.write_text("x") with pytest.raises(ValueError): data.load_data(file) def test_convert_dtypes(): - df = pd.DataFrame({'num': ['1', '2'], 'date': ['2020-01-01', '2020-01-02']}) + df = pd.DataFrame({"num": ["1", "2"], "date": ["2020-01-01", "2020-01-02"]}) conv = data.convert_dtypes(df) - assert conv['num'].dtype.kind in {'i', 'f'} - assert pd.api.types.is_datetime64_any_dtype(conv['date']) + assert conv["num"].dtype.kind in {"i", "f"} + assert pd.api.types.is_datetime64_any_dtype(conv["date"]) def test_data_summary(): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) summary = data.data_summary(df) - assert 'a' in summary.columns + assert "a" in summary.columns def test_sample_dataset_loads(): @@ -98,3 +98,21 @@ def test_upload_data_to_session(monkeypatch, tmp_path): data.upload_data_to_session("Upload", session_key="foo") pd.testing.assert_frame_equal(st.session_state["foo"], df) + +def test_validate_file_size_raises(tmp_path): + file = tmp_path / "big.csv" + file.write_bytes(b"x" * 10) + with pytest.raises(ValueError): + data.validate_file_size(file, max_mb=0) + + +def test_process_uploaded_file_too_large(tmp_path): + import streamlit as st + + df = pd.DataFrame({"a": [1]}) + file = tmp_path / "big.csv" + df.to_csv(file, index=False) + st.session_state.clear() + result = data.process_uploaded_file(file, session_key="up", max_size_mb=0) + assert result is None + assert "up" not in st.session_state diff --git a/utils/data.py b/utils/data.py index 393a589..1652264 100644 --- a/utils/data.py +++ b/utils/data.py @@ -5,6 +5,45 @@ from pathlib import Path from typing import Any, Iterable +MAX_UPLOAD_SIZE_MB = 100 + + +def validate_file_size(file: Any, max_mb: int = MAX_UPLOAD_SIZE_MB) -> int: + """Validate that file size does not exceed ``max_mb`` megabytes. + + Parameters + ---------- + file: + A file-like object or path. + max_mb: + Maximum size in megabytes allowed. + + Returns + ------- + int + Size of the file in bytes. + + Raises + ------ + ValueError + If the file exceeds ``max_mb`` megabytes. + """ + limit = max_mb * 1024 * 1024 + size = getattr(file, "size", None) + if size is None: + try: + if isinstance(file, (str, Path)): + path = Path(file) + else: + path = Path(getattr(file, "name")) + size = path.stat().st_size + except OSError: + size = None + if size is not None and size > limit: + raise ValueError(f"File size {size} exceeds limit of {limit} bytes") + return size or 0 + + import streamlit as st import pandas as pd @@ -78,11 +117,13 @@ def process_uploaded_file( session_key: str, detect_datetime: bool = False, datetime_key: str = "datetime_cols", + max_size_mb: int = MAX_UPLOAD_SIZE_MB, ) -> pd.DataFrame | None: """Load an uploaded file and store the DataFrame in session state.""" if uploaded_file is None: return None try: + _ = validate_file_size(uploaded_file, max_size_mb) _ = validate_file_type(uploaded_file, ["csv", "xls", "xlsx"]) df = load_data(uploaded_file) df = convert_dtypes(df) @@ -105,6 +146,7 @@ def upload_data_to_session( uploader_key: str | None = None, help: str | None = None, types: Iterable[str] = ("csv", "xlsx", "xls"), + max_size_mb: int = MAX_UPLOAD_SIZE_MB, ) -> pd.DataFrame | None: """Upload a file and store the loaded DataFrame in session state.""" file = st.file_uploader( @@ -118,4 +160,5 @@ def upload_data_to_session( session_key=session_key, detect_datetime=datetime_key is not None, datetime_key=datetime_key or "datetime_cols", + max_size_mb=max_size_mb, )