pyhealth/datasets/tuev.py: 57 changes (50 additions, 7 deletions)
@@ -81,6 +81,24 @@ def __init__(
            raise ValueError("subset must be one of 'train', 'eval', or 'both'")

        self.prepare_metadata()

        # Determine where the CSVs are located (shared directory or cache)
        root_path = Path(root)
        cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev"

        # Check if CSVs exist in cache and not in shared location
        use_cache = False
        for table in tables:
            shared_csv = root_path / f"tuev-{table}-pyhealth.csv"
            cache_csv = cache_dir / f"tuev-{table}-pyhealth.csv"
            if not shared_csv.exists() and cache_csv.exists():
                use_cache = True
                break

        # Use cache directory as root if CSVs are there
        if use_cache:
            logger.info(f"Using cached metadata from {cache_dir}")
            root = str(cache_dir)

        super().__init__(
            root=root,
@@ -107,12 +125,16 @@ def prepare_metadata(self) -> None:
        - segment_id = a_ / a_1 / ...
        """
        root = Path(self.root)
        cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev"

        train_rows: list[dict] = []
        eval_rows: list[dict] = []

        for split in ("train", "eval"):
            if os.path.exists(root / f"tuev-{split}-pyhealth.csv"):
            # Check if metadata exists in either shared location or cache
            shared_csv = root / f"tuev-{split}-pyhealth.csv"
            cache_csv = cache_dir / f"tuev-{split}-pyhealth.csv"
            if shared_csv.exists() or cache_csv.exists():
                continue

            split_dir = root / split
@@ -153,7 +175,9 @@ def prepare_metadata(self) -> None:
                    }
                )

        root.mkdir(parents=True, exist_ok=True)
        # Set up the cache directory as a fallback for metadata CSVs
        cache_dir = Path.home() / ".cache" / "pyhealth" / "tuev"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Write train metadata
        if train_rows:
@@ -162,9 +186,18 @@ def prepare_metadata(self) -> None:
["patient_id", "record_id"], inplace=True, na_position="last"
)
train_df.reset_index(drop=True, inplace=True)
train_csv = root / "tuev-train-pyhealth.csv"
train_df.to_csv(train_csv, index=False)


# Try shared location first, fall back to cache if no write permission
train_csv_shared = root / "tuev-train-pyhealth.csv"
train_csv_cache = cache_dir / "tuev-train-pyhealth.csv"

try:
train_csv_shared.parent.mkdir(parents=True, exist_ok=True)
train_df.to_csv(train_csv_shared, index=False)
logger.info(f"Wrote train metadata to {train_csv_shared}")
except (PermissionError, OSError):
train_df.to_csv(train_csv_cache, index=False)
logger.info(f"Wrote train metadata to cache: {train_csv_cache}")

# Write eval metadata
if eval_rows:
@@ -175,8 +208,18 @@ def prepare_metadata(self) -> None:
                na_position="last",
            )
            eval_df.reset_index(drop=True, inplace=True)
            eval_csv = root / "tuev-eval-pyhealth.csv"
            eval_df.to_csv(eval_csv, index=False)

            # Try shared location first, fall back to cache if no write permission
            eval_csv_shared = root / "tuev-eval-pyhealth.csv"
            eval_csv_cache = cache_dir / "tuev-eval-pyhealth.csv"

            try:
                eval_csv_shared.parent.mkdir(parents=True, exist_ok=True)
                eval_df.to_csv(eval_csv_shared, index=False)
                logger.info(f"Wrote eval metadata to {eval_csv_shared}")
            except (PermissionError, OSError):
                eval_df.to_csv(eval_csv_cache, index=False)
                logger.info(f"Wrote eval metadata to cache: {eval_csv_cache}")

    @property
    def default_task(self) -> EEGEventsTUEV:
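
For context, the train and eval branches above apply the same write-then-fall-back pattern, which could be factored into a small helper. The sketch below is only illustrative: `write_metadata_with_fallback` and its parameters are hypothetical and are not part of this PR or of pyhealth's API.

```python
# Illustrative sketch (not part of the PR): write a metadata CSV to the shared
# dataset root when it is writable, otherwise fall back to a per-user cache.
import logging
from pathlib import Path

import pandas as pd

logger = logging.getLogger(__name__)


def write_metadata_with_fallback(
    df: pd.DataFrame, shared_csv: Path, cache_csv: Path
) -> Path:
    """Write df to shared_csv; fall back to cache_csv if the root is not writable."""
    try:
        shared_csv.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(shared_csv, index=False)
        logger.info(f"Wrote metadata to {shared_csv}")
        return shared_csv
    except (PermissionError, OSError):
        # Shared location is not writable (e.g. a read-only mount); use the cache.
        cache_csv.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(cache_csv, index=False)
        logger.info(f"Wrote metadata to cache: {cache_csv}")
        return cache_csv
```

Since PermissionError is a subclass of OSError, catching OSError alone already covers both permission failures and read-only filesystems; the explicit tuple in the PR simply makes the intent visible.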