diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py index f7fdd667..7e8dacf9 100644 --- a/pyhealth/datasets/tuev.py +++ b/pyhealth/datasets/tuev.py @@ -81,6 +81,24 @@ def __init__( raise ValueError("subset must be one of 'train', 'eval', or 'both'") self.prepare_metadata() + + # Determine where the CSVs are located (shared directory or cache) + root_path = Path(root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev" + + # Check if CSVs exist in cache and not in shared location + use_cache = False + for table in tables: + shared_csv = root_path / f"tuev-{table}-pyhealth.csv" + cache_csv = cache_dir / f"tuev-{table}-pyhealth.csv" + if not shared_csv.exists() and cache_csv.exists(): + use_cache = True + break + + # Use cache directory as root if CSVs are there + if use_cache: + logger.info(f"Using cached metadata from {cache_dir}") + root = str(cache_dir) super().__init__( root=root, @@ -107,12 +125,16 @@ def prepare_metadata(self) -> None: - segment_id = a_ / a_1 / ... """ root = Path(self.root) + cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev" train_rows: list[dict] = [] eval_rows: list[dict] = [] for split in ("train", "eval"): - if os.path.exists(root / f"tuev-{split}-pyhealth.csv"): + # Check if metadata exists in either shared location or cache + shared_csv = root / f"tuev-{split}-pyhealth.csv" + cache_csv = cache_dir / f"tuev-{split}-pyhealth.csv" + if shared_csv.exists() or cache_csv.exists(): continue split_dir = root / split @@ -153,7 +175,9 @@ def prepare_metadata(self) -> None: } ) - root.mkdir(parents=True, exist_ok=True) + # Setup cache directory as fallback for metadata CSVs + cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev" + cache_dir.mkdir(parents=True, exist_ok=True) # Write train metadata if train_rows: @@ -162,9 +186,18 @@ def prepare_metadata(self) -> None: ["patient_id", "record_id"], inplace=True, na_position="last" ) train_df.reset_index(drop=True, inplace=True) - train_csv = root / "tuev-train-pyhealth.csv" - train_df.to_csv(train_csv, index=False) - + + # Try shared location first, fall back to cache if no write permission + train_csv_shared = root / "tuev-train-pyhealth.csv" + train_csv_cache = cache_dir / "tuev-train-pyhealth.csv" + + try: + train_csv_shared.parent.mkdir(parents=True, exist_ok=True) + train_df.to_csv(train_csv_shared, index=False) + logger.info(f"Wrote train metadata to {train_csv_shared}") + except (PermissionError, OSError): + train_df.to_csv(train_csv_cache, index=False) + logger.info(f"Wrote train metadata to cache: {train_csv_cache}") # Write eval metadata if eval_rows: @@ -175,8 +208,18 @@ def prepare_metadata(self) -> None: na_position="last", ) eval_df.reset_index(drop=True, inplace=True) - eval_csv = root / "tuev-eval-pyhealth.csv" - eval_df.to_csv(eval_csv, index=False) + + # Try shared location first, fall back to cache if no write permission + eval_csv_shared = root / "tuev-eval-pyhealth.csv" + eval_csv_cache = cache_dir / "tuev-eval-pyhealth.csv" + + try: + eval_csv_shared.parent.mkdir(parents=True, exist_ok=True) + eval_df.to_csv(eval_csv_shared, index=False) + logger.info(f"Wrote eval metadata to {eval_csv_shared}") + except (PermissionError, OSError): + eval_df.to_csv(eval_csv_cache, index=False) + logger.info(f"Wrote eval metadata to cache: {eval_csv_cache}") @property def default_task(self) -> EEGEventsTUEV: