Commit 422cc6d

4258 Add TCIA dataset (#4610)
* Add TCIA dataset

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent 9fb9d98 commit 422cc6d

9 files changed: +777 -31 lines changed


docs/source/apps.rst

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ Applications
 .. autoclass:: DecathlonDataset
     :members:
 
+.. autoclass:: TciaDataset
+    :members:
+
 .. autoclass:: CrossValidation
     :members:

monai/apps/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,6 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset
+from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset
 from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar
 from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger
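With this re-export in place, the new class can be imported from the package root just like the existing datasets. A minimal sketch:

    # TciaDataset is now available alongside MedNISTDataset and DecathlonDataset
    from monai.apps import TciaDataset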

monai/apps/datasets.py

Lines changed: 277 additions & 3 deletions
@@ -9,16 +9,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import shutil
 import sys
+import warnings
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
+from monai.apps.tcia import (
+    download_tcia_series_instance,
+    get_tcia_metadata,
+    get_tcia_ref_uid,
+    match_tcia_ref_uid_in_study,
+)
 from monai.apps.utils import download_and_extract
 from monai.config.type_definitions import PathLike
 from monai.data import (
     CacheDataset,
+    PydicomReader,
     load_decathlon_datalist,
     load_decathlon_properties,
     partition_dataset,
@@ -27,7 +37,7 @@
 from monai.transforms import LoadImaged, Randomizable
 from monai.utils import ensure_tuple
 
-__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"]
+__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation", "TciaDataset"]
 
 
 class MedNISTDataset(Randomizable, CacheDataset):
@@ -194,8 +204,8 @@ class DecathlonDataset(Randomizable, CacheDataset):
             for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D].
         download: whether to download and extract the Decathlon from resource link, default is False.
             if expected file already exists, skip downloading even set it to True.
-        val_frac: percentage of of validation fraction in the whole dataset, default is 0.2.
             user can manually copy tar file or dataset folder to the root directory.
+        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
         seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
             note to set same seed for `training` and `validation` sections.
         cache_num: number of items to be cached. Default is `sys.maxsize`.
@@ -379,6 +389,270 @@ def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
         return [datalist[i] for i in self.indices]
 
 
+class TciaDataset(Randomizable, CacheDataset):
+    """
+    The Dataset to automatically download data from a public The Cancer Imaging Archive (TCIA) collection
+    and generate items for training, validation or test.
+
+    The Highdicom library is used to load dicom data with modality "SEG", but only a subset of collections is
+    supported, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1", "QIN-PROSTATE-Repeatability"
+    and "PROSTATEx". Therefore, if "seg" is included in the `keys` of the `LoadImaged` transform while other
+    collections are being loaded, errors may be raised. For supported collections, the original "SEG" information
+    may not always be consistent for each dicom file. Therefore, to avoid creating different formats of labels,
+    please use the `label_dict` argument of `PydicomReader` when calling the `LoadImaged` transform. The prepared
+    label dicts of the collections mentioned above are also saved in `monai.apps.tcia.TCIA_LABEL_DICT`. You can
+    also refer to the second example below.
+
+    This class is based on :py:class:`monai.data.CacheDataset` to accelerate the training process.
+
+    Args:
+        root_dir: user's local directory for caching and loading the TCIA dataset.
+        collection: name of a TCIA collection.
+            a TCIA dataset is defined as a collection. Please check the following list to browse
+            the collection list (only public collections can be downloaded):
+            https://www.cancerimagingarchive.net/collections/
+        section: expected data section, can be: `training`, `validation` or `test`.
+        transform: transforms to execute operations on input data.
+            for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D].
+            If not specified, `LoadImaged(reader="PydicomReader", keys=["image"])` will be used as the default
+            transform. In addition, we suggest setting the argument `label_dict` for `PydicomReader` if
+            segmentations need to be loaded. The original labels for each dicom series may be different;
+            using this argument helps unify the format of the labels.
+        download: whether to download and extract the dataset, default is False.
+            if expected file already exists, skip downloading even set it to True.
+            user can manually copy tar file or dataset folder to the root directory.
+        download_len: number of series that will be downloaded, the value should be larger than 0 or -1, where -1 means
+            all series will be downloaded. Default is -1.
+        seg_type: modality type of segmentation that is used to do the first step download. Default is "SEG".
+        modality_tag: tag of modality. Default is (0x0008, 0x0060).
+        ref_series_uid_tag: tag of referenced Series Instance UID. Default is (0x0020, 0x000e).
+        ref_sop_uid_tag: tag of referenced SOP Instance UID. Default is (0x0008, 0x1155).
+        specific_tags: tags that will be loaded for "SEG" series. This argument will be used in
+            `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008, 0x1140), (0x3006, 0x0010),
+            (0x0020, 0x000D), (0x0010, 0x0010), (0x0010, 0x0020), (0x0020, 0x0011), (0x0020, 0x0012)].
+        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
+        seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
+            note to set same seed for `training` and `validation` sections.
+        cache_num: number of items to be cached. Default is `sys.maxsize`.
+            will take the minimum of (cache_num, data_length x cache_rate, data_length).
+        cache_rate: percentage of cached data in total, default is 0.0 (no cache).
+            will take the minimum of (cache_num, data_length x cache_rate, data_length).
+        num_workers: the number of worker threads to use.
+            If num_workers is None then the number returned by os.cpu_count() is used.
+            If a value less than 1 is specified, 1 will be used instead.
+        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
+        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+            default to `True`. if the random transforms don't modify the cached content
+            (for example, randomly crop from the cached image and deepcopy the crop region)
+            or if every cache item is only used once in a `multi-processing` environment,
+            may set `copy=False` for better performance.
+        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
+            it may help improve the performance of following logic.
+
+    Example::
+
+        # collection is "Pancreatic-CT-CBCT-SEG", seg_type is "RTSTRUCT"
+        data = TciaDataset(
+            root_dir="./", collection="Pancreatic-CT-CBCT-SEG", seg_type="RTSTRUCT", download=True
+        )
+
+        # collection is "C4KC-KiTS", seg_type is "SEG", and load both images and segmentations
+        from monai.apps.tcia import TCIA_LABEL_DICT
+        transform = Compose(
+            [
+                LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=TCIA_LABEL_DICT["C4KC-KiTS"]),
+                EnsureChannelFirstd(keys=["image", "seg"]),
+                ResampleToMatchd(keys="image", key_dst="seg"),
+            ]
+        )
+        data = TciaDataset(
+            root_dir="./", collection="C4KC-KiTS", section="validation", seed=12345, download=True,
+            transform=transform,
+        )
+
+        print(data[0]["seg"].shape)
+
+    """
+
+    def __init__(
+        self,
+        root_dir: PathLike,
+        collection: str,
+        section: str,
+        transform: Union[Sequence[Callable], Callable] = (),
+        download: bool = False,
+        download_len: int = -1,
+        seg_type: str = "SEG",
+        modality_tag: Tuple = (0x0008, 0x0060),
+        ref_series_uid_tag: Tuple = (0x0020, 0x000E),
+        ref_sop_uid_tag: Tuple = (0x0008, 0x1155),
+        specific_tags: Tuple = (
+            (0x0008, 0x1115),  # Referenced Series Sequence
+            (0x0008, 0x1140),  # Referenced Image Sequence
+            (0x3006, 0x0010),  # Referenced Frame of Reference Sequence
+            (0x0020, 0x000D),  # Study Instance UID
+            (0x0010, 0x0010),  # Patient's Name
+            (0x0010, 0x0020),  # Patient ID
+            (0x0020, 0x0011),  # Series Number
+            (0x0020, 0x0012),  # Acquisition Number
+        ),
+        seed: int = 0,
+        val_frac: float = 0.2,
+        cache_num: int = sys.maxsize,
+        cache_rate: float = 0.0,
+        num_workers: int = 1,
+        progress: bool = True,
+        copy_cache: bool = True,
+        as_contiguous: bool = True,
+    ) -> None:
+        root_dir = Path(root_dir)
+        if not root_dir.is_dir():
+            raise ValueError("Root directory root_dir must be a directory.")
+
+        self.section = section
+        self.val_frac = val_frac
+        self.seg_type = seg_type
+        self.modality_tag = modality_tag
+        self.ref_series_uid_tag = ref_series_uid_tag
+        self.ref_sop_uid_tag = ref_sop_uid_tag
+
+        self.set_random_state(seed=seed)
+        download_dir = os.path.join(root_dir, collection)
+        load_tags = list(specific_tags)
+        load_tags += [modality_tag]
+        self.load_tags = load_tags
+        if download:
+            seg_series_list = get_tcia_metadata(
+                query=f"getSeries?Collection={collection}&Modality={seg_type}", attribute="SeriesInstanceUID"
+            )
+            if download_len > 0:
+                seg_series_list = seg_series_list[:download_len]
+            if len(seg_series_list) == 0:
+                raise ValueError(f"Cannot find data with collection: {collection} seg_type: {seg_type}")
+            for series_uid in seg_series_list:
+                self._download_series_reference_data(series_uid, download_dir)
+
+        if not os.path.exists(download_dir):
+            raise RuntimeError(f"Cannot find dataset directory: {download_dir}.")
+
+        self.indices: np.ndarray = np.array([])
+        self.datalist = self._generate_data_list(download_dir)
+
+        if transform == ():
+            transform = LoadImaged(reader="PydicomReader", keys=["image"])
+        CacheDataset.__init__(
+            self,
+            data=self.datalist,
+            transform=transform,
+            cache_num=cache_num,
+            cache_rate=cache_rate,
+            num_workers=num_workers,
+            progress=progress,
+            copy_cache=copy_cache,
+            as_contiguous=as_contiguous,
+        )
+
+    def get_indices(self) -> np.ndarray:
+        """
+        Get the indices of datalist used in this dataset.
+
+        """
+        return self.indices
+
+    def randomize(self, data: np.ndarray) -> None:
+        self.R.shuffle(data)
+
+    def _download_series_reference_data(self, series_uid: str, download_dir: str):
+        """
+        First of all, download a series from TCIA according to `series_uid`.
+        Then find all referenced series and download.
+        """
+        seg_first_dir = os.path.join(download_dir, "raw", series_uid)
+        download_tcia_series_instance(
+            series_uid=series_uid, download_dir=download_dir, output_dir=seg_first_dir, check_md5=False
+        )
+        dicom_files = [f for f in os.listdir(seg_first_dir) if f.endswith(".dcm")]
+        # retrieve the series number and patient id from the first dicom file
+        dcm_path = os.path.join(seg_first_dir, dicom_files[0])
+        ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)
+        # (0x0010, 0x0020) and (0x0010, 0x0010), better to be contained in `specific_tags`
+        patient_id = ds.PatientID if ds.PatientID else ds.PatientName
+        if not patient_id:
+            warnings.warn(f"unable to find patient name of dicom file: {dcm_path}, use 'patient' instead.")
+            patient_id = "patient"
+        # (0x0020, 0x0011) and (0x0020, 0x0012), better to be contained in `specific_tags`
+        series_num = ds.SeriesNumber if ds.SeriesNumber else ds.AcquisitionNumber
+        if not series_num:
+            warnings.warn(f"unable to find series number of dicom file: {dcm_path}, use '0' instead.")
+            series_num = 0
+
+        series_num = str(series_num)
+        seg_dir = os.path.join(download_dir, patient_id, series_num, self.seg_type.lower())
+        dcm_dir = os.path.join(download_dir, patient_id, series_num, "image")
+
+        # get the referenced series uid
+        ref_uid_list = []
+        for dcm_file in dicom_files:
+            dcm_path = os.path.join(seg_first_dir, dcm_file)
+            ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)
+            if ds[self.modality_tag].value == self.seg_type:
+                ref_uid = get_tcia_ref_uid(
+                    ds, find_sop=False, ref_series_uid_tag=self.ref_series_uid_tag, ref_sop_uid_tag=self.ref_sop_uid_tag
+                )
+                if ref_uid == "":
+                    ref_sop_uid = get_tcia_ref_uid(
+                        ds,
+                        find_sop=True,
+                        ref_series_uid_tag=self.ref_series_uid_tag,
+                        ref_sop_uid_tag=self.ref_sop_uid_tag,
+                    )
+                    ref_uid = match_tcia_ref_uid_in_study(ds.StudyInstanceUID, ref_sop_uid)
+                if ref_uid != "":
+                    ref_uid_list.append(ref_uid)
+        if not ref_uid_list:
+            warnings.warn(f"Cannot find the referenced Series Instance UID from series: {series_uid}.")
+        else:
+            download_tcia_series_instance(
+                series_uid=ref_uid_list[0], download_dir=download_dir, output_dir=dcm_dir, check_md5=False
+            )
+        if not os.path.exists(seg_dir):
+            shutil.copytree(seg_first_dir, seg_dir)
+
+    def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
+        # the types of the item in data list should be compatible with the dataloader
+        dataset_dir = Path(dataset_dir)
+        datalist = []
+        patient_list = [f.name for f in os.scandir(dataset_dir) if f.is_dir() and f.name != "raw"]
+        for patient_id in patient_list:
+            series_list = [f.name for f in os.scandir(os.path.join(dataset_dir, patient_id)) if f.is_dir()]
+            for series_num in series_list:
+                seg_key = self.seg_type.lower()
+                image_path = os.path.join(dataset_dir, patient_id, series_num, "image")
+                mask_path = os.path.join(dataset_dir, patient_id, series_num, seg_key)
+
+                if os.path.exists(image_path):
+                    datalist.append({"image": image_path, seg_key: mask_path})
+                else:
+                    datalist.append({seg_key: mask_path})
+
+        return self._split_datalist(datalist)
+
+    def _split_datalist(self, datalist: List[Dict]) -> List[Dict]:
+        if self.section == "test":
+            return datalist
+        length = len(datalist)
+        indices = np.arange(length)
+        self.randomize(indices)
+
+        val_length = int(length * self.val_frac)
+        if self.section == "training":
+            self.indices = indices[val_length:]
+        else:
+            self.indices = indices[:val_length]
+
+        return [datalist[i] for i in self.indices]
+
+
 class CrossValidation:
     """
     Cross validation dataset based on the general dataset which must have `_split_datalist` API.
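Taken together, the class above downloads a collection once into `root_dir/collection` and then behaves like any other `CacheDataset`. A minimal sketch of wiring it into a loader; the `Compose` pipeline mirrors the docstring example, while the batch size, `num_workers`, and the use of `DataLoader` are illustrative choices, not part of this commit:

    from monai.apps import TciaDataset
    from monai.apps.tcia import TCIA_LABEL_DICT
    from monai.data import DataLoader
    from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatchd

    # load image + SEG and resample the image onto the segmentation grid
    transform = Compose(
        [
            LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=TCIA_LABEL_DICT["C4KC-KiTS"]),
            EnsureChannelFirstd(keys=["image", "seg"]),
            ResampleToMatchd(keys="image", key_dst="seg"),
        ]
    )
    train_ds = TciaDataset(
        root_dir="./", collection="C4KC-KiTS", section="training", transform=transform, download=True, seed=0
    )
    val_ds = TciaDataset(
        root_dir="./", collection="C4KC-KiTS", section="validation", transform=transform, download=False, seed=0
    )
    # same seed for training/validation so the 80/20 split is complementary, as the docstring notes
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=2)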

monai/apps/tcia/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .label_desc import TCIA_LABEL_DICT
+from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study
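For context, the helpers re-exported from `.utils` are what `TciaDataset` uses for its download step. A minimal sketch of calling two of them directly, with the query string following the `getSeries?Collection=...&Modality=...` pattern used in `TciaDataset.__init__` (the collection name and local paths here are illustrative):

    from monai.apps.tcia import download_tcia_series_instance, get_tcia_metadata

    # list the SeriesInstanceUID of every "SEG" series in a collection
    series_uids = get_tcia_metadata(
        query="getSeries?Collection=C4KC-KiTS&Modality=SEG", attribute="SeriesInstanceUID"
    )
    # fetch the first series; md5 checking is skipped, matching the dataset code above
    download_tcia_series_instance(
        series_uid=series_uids[0], download_dir="./C4KC-KiTS", output_dir="./C4KC-KiTS/raw", check_md5=False
    )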

monai/apps/tcia/label_desc.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict
+
+__all__ = ["TCIA_LABEL_DICT"]
+
+
+TCIA_LABEL_DICT: Dict[str, Dict[str, int]] = {
+    "C4KC-KiTS": {"Kidney": 0, "Renal Tumor": 1},
+    "NSCLC-Radiomics": {
+        "Esophagus": 0,
+        "GTV-1": 1,
+        "Lungs-Total": 2,
+        "Spinal-Cord": 3,
+        "Lung-Left": 4,
+        "Lung-Right": 5,
+        "Heart": 6,
+    },
+    "NSCLC-Radiomics-Interobserver1": {
+        "GTV-1auto-1": 0,
+        "GTV-1auto-2": 1,
+        "GTV-1auto-3": 2,
+        "GTV-1auto-4": 3,
+        "GTV-1auto-5": 4,
+        "GTV-1vis-1": 5,
+        "GTV-1vis-2": 6,
+        "GTV-1vis-3": 7,
+        "GTV-1vis-4": 8,
+        "GTV-1vis-5": 9,
+    },
+    "QIN-PROSTATE-Repeatability": {"NormalROI_PZ_1": 0, "TumorROI_PZ_1": 1, "PeripheralZone": 2, "WholeGland": 3},
+    "PROSTATEx": {
+        "Prostate": 0,
+        "Peripheral zone of prostate": 1,
+        "Transition zone of prostate": 2,
+        "Distal prostatic urethra": 3,
+        "Anterior fibromuscular stroma of prostate": 4,
+    },
+}
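Each entry above maps a structure name to a fixed integer index, so that segmentations from different series of the same collection share one label convention when the dict is passed as `label_dict`, as the `TciaDataset` docstring recommends. A small illustrative sketch (the collection choice is arbitrary):

    from monai.apps.tcia import TCIA_LABEL_DICT
    from monai.transforms import LoadImaged

    label_dict = TCIA_LABEL_DICT["NSCLC-Radiomics"]
    print(label_dict["Heart"])  # -> 6, the index assigned to the "Heart" segment in this collection
    loader = LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=label_dict)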
