|
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
|
| 12 | +import os |
| 13 | +import shutil |
12 | 14 | import sys |
| 15 | +import warnings |
13 | 16 | from pathlib import Path |
14 | | -from typing import Callable, Dict, List, Optional, Sequence, Union |
| 17 | +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union |
15 | 18 |
|
16 | 19 | import numpy as np |
17 | 20 |
|
| 21 | +from monai.apps.tcia import ( |
| 22 | + download_tcia_series_instance, |
| 23 | + get_tcia_metadata, |
| 24 | + get_tcia_ref_uid, |
| 25 | + match_tcia_ref_uid_in_study, |
| 26 | +) |
18 | 27 | from monai.apps.utils import download_and_extract |
19 | 28 | from monai.config.type_definitions import PathLike |
20 | 29 | from monai.data import ( |
21 | 30 | CacheDataset, |
| 31 | + PydicomReader, |
22 | 32 | load_decathlon_datalist, |
23 | 33 | load_decathlon_properties, |
24 | 34 | partition_dataset, |
|
27 | 37 | from monai.transforms import LoadImaged, Randomizable |
28 | 38 | from monai.utils import ensure_tuple |
29 | 39 |
|
30 | | -__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation"] |
| 40 | +__all__ = ["MedNISTDataset", "DecathlonDataset", "CrossValidation", "TciaDataset"] |
31 | 41 |
|
32 | 42 |
|
33 | 43 | class MedNISTDataset(Randomizable, CacheDataset): |
@@ -194,8 +204,8 @@ class DecathlonDataset(Randomizable, CacheDataset): |
194 | 204 | for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D]. |
195 | 205 | download: whether to download and extract the Decathlon from resource link, default is False. |
196 | 206 | if expected file already exists, skip downloading even set it to True. |
197 | | - val_frac: percentage of of validation fraction in the whole dataset, default is 0.2. |
198 | 207 | user can manually copy tar file or dataset folder to the root directory. |
| 208 | + val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
199 | 209 | seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0. |
200 | 210 | note to set same seed for `training` and `validation` sections. |
201 | 211 | cache_num: number of items to be cached. Default is `sys.maxsize`. |
@@ -379,6 +389,270 @@ def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: |
379 | 389 | return [datalist[i] for i in self.indices] |
380 | 390 |
|
381 | 391 |
|
| 392 | +class TciaDataset(Randomizable, CacheDataset): |
| 393 | + """ |
| 394 | + The Dataset to automatically download data from a public collection of The Cancer Imaging Archive (TCIA)
| 395 | + and generate items for training, validation or test.
| 396 | +
|
| 397 | + The Highdicom library is used to load dicom data with modality "SEG", but only a part of collections are
| 398 | + supported, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1", "QIN-PROSTATE-Repeatability"
| 399 | + and "PROSTATEx". Therefore, if "seg" is included in the `keys` of the `LoadImaged` transform while loading some
| 400 | + other collections, errors may be raised. For supported collections, the original "SEG" information may not
| 401 | + always be consistent across dicom files. Therefore, to avoid creating labels in different formats, please use
| 402 | + the `label_dict` argument of `PydicomReader` when calling the `LoadImaged` transform. The prepared label dicts
| 403 | + of the collections mentioned above are also saved in: `monai.apps.tcia.TCIA_LABEL_DICT`. You can also refer
| 404 | + to the second example below.
| 405 | +
|
| 406 | +
|
| 407 | + This class is based on :py:class:`monai.data.CacheDataset` to accelerate the training process. |
| 408 | +
|
| 409 | + Args: |
| 410 | + root_dir: user's local directory for caching and loading the TCIA dataset. |
| 411 | + collection: name of a TCIA collection. |
| 412 | + a TCIA dataset is defined as a collection. Please browse the following list of collections
| 413 | + (only public collections can be downloaded):
| 414 | + https://www.cancerimagingarchive.net/collections/ |
| 415 | + section: expected data section, can be: `training`, `validation` or `test`. |
| 416 | + transform: transforms to execute operations on input data. |
| 417 | + for further usage, use `AddChanneld` or `AsChannelFirstd` to convert the shape to [C, H, W, D]. |
| 418 | + If not specified, `LoadImaged(reader="PydicomReader", keys=["image"])` will be used as the default |
| 419 | + transform. In addition, we suggest setting the `label_dict` argument of `PydicomReader` if segmentations
| 420 | + need to be loaded. The original labels for each dicom series may be different; using this argument
| 421 | + helps unify the format of the loaded labels.
| 422 | + download: whether to download and extract the dataset, default is False. |
| 423 | + if the expected file already exists, skip downloading even if set to True.
| 424 | + the user can manually copy the tar file or dataset folder to the root directory.
| 425 | + download_len: number of series that will be downloaded, the value should be a positive integer or -1, where -1 means
| 426 | + all series will be downloaded. Default is -1.
| 427 | + seg_type: modality type of the segmentation that is used for the first-step download. Default is "SEG".
| 428 | + modality_tag: tag of modality. Default is (0x0008, 0x0060). |
| 429 | + ref_series_uid_tag: tag of referenced Series Instance UID. Default is (0x0020, 0x000e). |
| 430 | + ref_sop_uid_tag: tag of referenced SOP Instance UID. Default is (0x0008, 0x1155). |
| 431 | + specific_tags: tags that will be loaded for "SEG" series. This argument will be used in |
| 432 | + `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010), |
| 433 | + (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)]. |
| 434 | + val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
| 435 | + seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0. |
| 436 | + note to set same seed for `training` and `validation` sections. |
| 437 | + cache_num: number of items to be cached. Default is `sys.maxsize`. |
| 438 | + will take the minimum of (cache_num, data_length x cache_rate, data_length). |
| 439 | + cache_rate: percentage of cached data in total, default is 0.0 (no cache). |
| 440 | + will take the minimum of (cache_num, data_length x cache_rate, data_length). |
| 441 | + num_workers: the number of worker threads to use. |
| 442 | + If num_workers is None then the number returned by os.cpu_count() is used. |
| 443 | + If a value less than 1 is specified, 1 will be used instead.
| 444 | + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. |
| 445 | + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, |
| 446 | + default to `True`. if the random transforms don't modify the cached content |
| 447 | + (for example, randomly crop from the cached image and deepcopy the crop region) |
| 448 | + or if every cache item is only used once in a `multi-processing` environment, |
| 449 | + may set `copy_cache=False` for better performance.
| 450 | + as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. |
| 451 | + it may help improve the performance of the following logic.
| 452 | +
|
| 453 | + Example:: |
| 454 | +
|
| 455 | + # collection is "Pancreatic-CT-CBCT-SEG", seg_type is "RTSTRUCT" |
| 456 | + data = TciaDataset( |
| 457 | + root_dir="./", collection="Pancreatic-CT-CBCT-SEG", seg_type="RTSTRUCT", download=True |
| 458 | + ) |
| 459 | +
|
| 460 | + # collection is "C4KC-KiTS", seg_type is "SEG", and load both images and segmentations |
| 461 | + from monai.apps.tcia import TCIA_LABEL_DICT |
| 462 | + transform = Compose( |
| 463 | + [ |
| 464 | + LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=TCIA_LABEL_DICT["C4KC-KiTS"]), |
| 465 | + EnsureChannelFirstd(keys=["image", "seg"]), |
| 466 | + ResampleToMatchd(keys="image", key_dst="seg"), |
| 467 | + ] |
| 468 | + ) |
| 469 | + data = TciaDataset( |
| 470 | + root_dir="./", collection="C4KC-KiTS", section="validation", transform=transform, seed=12345, download=True
| 471 | + ) |
| 472 | +
|
| 473 | + print(data[0]["seg"].shape) |
| 474 | +
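| | + # a minimal illustrative sketch: iterate the dataset with monai.data.DataLoader
| | + # (assumes torch is installed and the transform defined above is passed to the dataset)
| | + from monai.data import DataLoader
| | + loader = DataLoader(data, batch_size=1, num_workers=0)
| | + for batch in loader:
| | +     print(batch["image"].shape, batch["seg"].shape)
| | +     break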
|
| 475 | + """ |
| 476 | + |
| 477 | + def __init__( |
| 478 | + self, |
| 479 | + root_dir: PathLike, |
| 480 | + collection: str, |
| 481 | + section: str, |
| 482 | + transform: Union[Sequence[Callable], Callable] = (), |
| 483 | + download: bool = False, |
| 484 | + download_len: int = -1, |
| 485 | + seg_type: str = "SEG", |
| 486 | + modality_tag: Tuple = (0x0008, 0x0060), |
| 487 | + ref_series_uid_tag: Tuple = (0x0020, 0x000E), |
| 488 | + ref_sop_uid_tag: Tuple = (0x0008, 0x1155), |
| 489 | + specific_tags: Tuple = ( |
| 490 | + (0x0008, 0x1115), # Referenced Series Sequence |
| 491 | + (0x0008, 0x1140), # Referenced Image Sequence |
| 492 | + (0x3006, 0x0010), # Referenced Frame of Reference Sequence |
| 493 | + (0x0020, 0x000D), # Study Instance UID |
| 494 | + (0x0010, 0x0010), # Patient's Name |
| 495 | + (0x0010, 0x0020), # Patient ID |
| 496 | + (0x0020, 0x0011), # Series Number |
| 497 | + (0x0020, 0x0012), # Acquisition Number |
| 498 | + ), |
| 499 | + seed: int = 0, |
| 500 | + val_frac: float = 0.2, |
| 501 | + cache_num: int = sys.maxsize, |
| 502 | + cache_rate: float = 0.0, |
| 503 | + num_workers: int = 1, |
| 504 | + progress: bool = True, |
| 505 | + copy_cache: bool = True, |
| 506 | + as_contiguous: bool = True, |
| 507 | + ) -> None: |
| 508 | + root_dir = Path(root_dir) |
| 509 | + if not root_dir.is_dir(): |
| 510 | + raise ValueError("Root directory root_dir must be a directory.") |
| 511 | + |
| 512 | + self.section = section |
| 513 | + self.val_frac = val_frac |
| 514 | + self.seg_type = seg_type |
| 515 | + self.modality_tag = modality_tag |
| 516 | + self.ref_series_uid_tag = ref_series_uid_tag |
| 517 | + self.ref_sop_uid_tag = ref_sop_uid_tag |
| 518 | + |
| 519 | + self.set_random_state(seed=seed) |
| 520 | + download_dir = os.path.join(root_dir, collection) |
| 521 | + load_tags = list(specific_tags) |
| 522 | + load_tags += [modality_tag] |
| 523 | + self.load_tags = load_tags |
| 524 | + if download: |
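| | + # query TCIA for the Series Instance UIDs of all series in the collection with the given modality (e.g. "SEG")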
| 525 | + seg_series_list = get_tcia_metadata( |
| 526 | + query=f"getSeries?Collection={collection}&Modality={seg_type}", attribute="SeriesInstanceUID" |
| 527 | + ) |
| 528 | + if download_len > 0: |
| 529 | + seg_series_list = seg_series_list[:download_len] |
| 530 | + if len(seg_series_list) == 0: |
| 531 | + raise ValueError(f"Cannot find data with collection: {collection} seg_type: {seg_type}") |
| 532 | + for series_uid in seg_series_list: |
| 533 | + self._download_series_reference_data(series_uid, download_dir) |
| 534 | + |
| 535 | + if not os.path.exists(download_dir): |
| 536 | + raise RuntimeError(f"Cannot find dataset directory: {download_dir}.") |
| 537 | + |
| 538 | + self.indices: np.ndarray = np.array([]) |
| 539 | + self.datalist = self._generate_data_list(download_dir) |
| 540 | + |
| 541 | + if transform == (): |
| 542 | + transform = LoadImaged(reader="PydicomReader", keys=["image"]) |
| 543 | + CacheDataset.__init__( |
| 544 | + self, |
| 545 | + data=self.datalist, |
| 546 | + transform=transform, |
| 547 | + cache_num=cache_num, |
| 548 | + cache_rate=cache_rate, |
| 549 | + num_workers=num_workers, |
| 550 | + progress=progress, |
| 551 | + copy_cache=copy_cache, |
| 552 | + as_contiguous=as_contiguous, |
| 553 | + ) |
| 554 | + |
| 555 | + def get_indices(self) -> np.ndarray: |
| 556 | + """ |
| 557 | + Get the indices of datalist used in this dataset. |
| 558 | +
|
| 559 | + """ |
| 560 | + return self.indices |
| 561 | + |
| 562 | + def randomize(self, data: np.ndarray) -> None: |
| 563 | + self.R.shuffle(data) |
| 564 | + |
| 565 | + def _download_series_reference_data(self, series_uid: str, download_dir: str): |
| 566 | + """ |
| 567 | + First of all, download a series from TCIA according to `series_uid`. |
| 568 | + Then find all referenced series and download. |
| 569 | + """ |
| 570 | + seg_first_dir = os.path.join(download_dir, "raw", series_uid) |
| 571 | + download_tcia_series_instance( |
| 572 | + series_uid=series_uid, download_dir=download_dir, output_dir=seg_first_dir, check_md5=False |
| 573 | + ) |
| 574 | + dicom_files = [f for f in os.listdir(seg_first_dir) if f.endswith(".dcm")] |
| 575 | + # retrieve the series number and patient id from the first dicom file
| 576 | + dcm_path = os.path.join(seg_first_dir, dicom_files[0]) |
| 577 | + ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path) |
| 578 | + # (0x0010,0x0020) and (0x0010,0x0010) should be contained in `specific_tags`
| 579 | + patient_id = ds.PatientID if ds.PatientID else ds.PatientName |
| 580 | + if not patient_id: |
| 581 | + warnings.warn(f"unable to find patient name of dicom file: {dcm_path}, use 'patient' instead.") |
| 582 | + patient_id = "patient" |
| 583 | + # (0x0020,0x0011) and (0x0020,0x0012) should be contained in `specific_tags`
| 584 | + series_num = ds.SeriesNumber if ds.SeriesNumber else ds.AcquisitionNumber |
| 585 | + if not series_num: |
| 586 | + warnings.warn(f"unable to find series number of dicom file: {dcm_path}, use '0' instead.") |
| 587 | + series_num = 0 |
| 588 | + |
| 589 | + series_num = str(series_num) |
| 590 | + seg_dir = os.path.join(download_dir, patient_id, series_num, self.seg_type.lower()) |
| 591 | + dcm_dir = os.path.join(download_dir, patient_id, series_num, "image") |
| 592 | + |
| 593 | + # get the referenced series uid; if absent, fall back to matching referenced SOP uids within the study
| 594 | + ref_uid_list = [] |
| 595 | + for dcm_file in dicom_files: |
| 596 | + dcm_path = os.path.join(seg_first_dir, dcm_file) |
| 597 | + ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path) |
| 598 | + if ds[self.modality_tag].value == self.seg_type: |
| 599 | + ref_uid = get_tcia_ref_uid( |
| 600 | + ds, find_sop=False, ref_series_uid_tag=self.ref_series_uid_tag, ref_sop_uid_tag=self.ref_sop_uid_tag |
| 601 | + ) |
| 602 | + if ref_uid == "": |
| 603 | + ref_sop_uid = get_tcia_ref_uid( |
| 604 | + ds, |
| 605 | + find_sop=True, |
| 606 | + ref_series_uid_tag=self.ref_series_uid_tag, |
| 607 | + ref_sop_uid_tag=self.ref_sop_uid_tag, |
| 608 | + ) |
| 609 | + ref_uid = match_tcia_ref_uid_in_study(ds.StudyInstanceUID, ref_sop_uid) |
| 610 | + if ref_uid != "": |
| 611 | + ref_uid_list.append(ref_uid) |
| 612 | + if not ref_uid_list: |
| 613 | + warnings.warn(f"Cannot find the referenced Series Instance UID from series: {series_uid}.") |
| 614 | + else: |
| 615 | + download_tcia_series_instance( |
| 616 | + series_uid=ref_uid_list[0], download_dir=download_dir, output_dir=dcm_dir, check_md5=False |
| 617 | + ) |
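| | + # copy the raw SEG/RTSTRUCT series into the per-patient, per-series folder consumed by _generate_data_list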
| 618 | + if not os.path.exists(seg_dir): |
| 619 | + shutil.copytree(seg_first_dir, seg_dir) |
| 620 | + |
| 621 | + def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]: |
| 622 | + # the types of the item in data list should be compatible with the dataloader |
| 623 | + dataset_dir = Path(dataset_dir) |
| 624 | + datalist = [] |
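| | + # expected layout produced by `_download_series_reference_data`:
| | + #   <dataset_dir>/<patient_id>/<series_number>/image and .../<seg_type.lower()>
| | + # the "raw" folder only holds the original downloads and is skipped here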
| 625 | + patient_list = [f.name for f in os.scandir(dataset_dir) if f.is_dir() and f.name != "raw"] |
| 626 | + for patient_id in patient_list: |
| 627 | + series_list = [f.name for f in os.scandir(os.path.join(dataset_dir, patient_id)) if f.is_dir()] |
| 628 | + for series_num in series_list: |
| 629 | + seg_key = self.seg_type.lower() |
| 630 | + image_path = os.path.join(dataset_dir, patient_id, series_num, "image") |
| 631 | + mask_path = os.path.join(dataset_dir, patient_id, series_num, seg_key) |
| 632 | + |
| 633 | + if os.path.exists(image_path): |
| 634 | + datalist.append({"image": image_path, seg_key: mask_path}) |
| 635 | + else: |
| 636 | + datalist.append({seg_key: mask_path}) |
| 637 | + |
| 638 | + return self._split_datalist(datalist) |
| 639 | + |
| 640 | + def _split_datalist(self, datalist: List[Dict]) -> List[Dict]: |
| 641 | + if self.section == "test": |
| 642 | + return datalist |
| 643 | + length = len(datalist) |
| 644 | + indices = np.arange(length) |
| 645 | + self.randomize(indices) |
| 646 | + |
| 647 | + val_length = int(length * self.val_frac) |
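| | + # the first `val_frac` portion of the shuffled indices becomes validation, the rest training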
| 648 | + if self.section == "training": |
| 649 | + self.indices = indices[val_length:] |
| 650 | + else: |
| 651 | + self.indices = indices[:val_length] |
| 652 | + |
| 653 | + return [datalist[i] for i in self.indices] |
| 654 | + |
| 655 | + |
382 | 656 | class CrossValidation: |
383 | 657 | """ |
384 | 658 | Cross validation dataset based on the general dataset which must have `_split_datalist` API. |
|