11# -*- coding: utf-8 -*-
22from __future__ import print_function , division
3+ import sys
4+ import numpy as np
5+ from pymic .util .parse_config import *
6+ from pymic .net_run .agent_cls import ClassificationAgent
7+ from pymic .net_run .agent_seg import SegmentationAgent
8+ import SimpleITK as sitk
39
4- import os
5- import torch
6- import pandas as pd
7- import numpy as np
8- from skimage import io , transform
9- from torch .utils .data import Dataset , DataLoader
10- from torchvision import transforms , utils
11- from pymic .io .image_read_write import *
12- from pymic .io .nifty_dataset import NiftyDataset
13- from pymic .io .transform3d import *
10+ def save_array_as_nifty_volume (data , image_name , reference_name = None ):
11+ """
12+ Save a numpy array as nifty image
1413
15- if __name__ == "__main__" :
16- root_dir = '/home/guotai/data/brats/BraTS2018_Training'
17- csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv'
18-
19- crop1 = CropWithBoundingBox (start = None , output_size = [4 , 144 , 180 , 144 ])
20- norm = ChannelWiseNormalize (mean = None , std = None , zero_to_random = True )
21- labconv = LabelConvert ([0 , 1 , 2 , 4 ], [0 , 1 , 2 , 3 ])
22- crop2 = RandomCrop ([128 , 128 , 128 ])
23- rescale = Rescale ([64 , 64 , 64 ])
24- transform_list = [crop1 , norm , labconv , crop2 ,rescale , ToTensor ()]
25- transformed_dataset = NiftyDataset (root_dir = root_dir ,
26- csv_file = csv_file ,
27- modal_num = 4 ,
28- transform = transforms .Compose (transform_list ))
29- dataloader = DataLoader (transformed_dataset , batch_size = 4 ,
30- shuffle = True , num_workers = 4 )
31- # Helper function to show a batch
14+ :param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width].
15+ :param image_name: (str) The ouput file name.
16+ :param reference_name: (str) File name of the reference image of which
17+ meta information is used.
18+ """
19+ img = sitk .GetImageFromArray (data )
20+ if (reference_name is not None ):
21+ img_ref = sitk .ReadImage (reference_name )
22+ #img.CopyInformation(img_ref)
23+ img .SetSpacing (img_ref .GetSpacing ())
24+ img .SetOrigin (img_ref .GetOrigin ())
25+ img .SetDirection (img_ref .GetDirection ())
26+ sitk .WriteImage (img , image_name )
3227
28+ def main ():
29+ """
30+ The main function for running a network for training or inference.
31+ """
32+ if (len (sys .argv ) < 3 ):
33+ print ('Number of arguments should be 3. e.g.' )
34+ print ('python test_nifty_dataset.py train config.cfg' )
35+ exit ()
36+ stage = str (sys .argv [1 ])
37+ cfg_file = str (sys .argv [2 ])
38+ config = parse_config (cfg_file )
39+ config = synchronize_config (config )
40+ # task = config['dataset']['task_type']
41+ # assert task in ['cls', 'cls_nexcl', 'seg']
42+ # if(task == 'cls' or task == 'cls_nexcl'):
43+ # agent = ClassificationAgent(config, stage)
44+ # else:
45+ # agent = SegmentationAgent(config, stage)
46+ agent = SegmentationAgent (config , stage )
47+ agent .create_dataset ()
48+ data_loader = agent .train_loader if stage == "train" else agent .test_loader
49+ it = 0
50+ for data in data_loader :
51+ inputs = agent .convert_tensor_type (data ['image' ])
52+ labels_prob = agent .convert_tensor_type (data ['label_prob' ])
53+ for i in range (inputs .shape [0 ]):
54+ image_i = inputs [i ][0 ]
55+ label_i = np .argmax (labels_prob [i ], axis = 0 )
56+ print (image_i .shape , label_i .shape )
57+ image_name = "temp/image_{0:}_{1:}.nii.gz" .format (it , i )
58+ label_name = "temp/label_{0:}_{1:}.nii.gz" .format (it , i )
59+ save_array_as_nifty_volume (image_i , image_name , reference_name = None )
60+ save_array_as_nifty_volume (label_i , label_name , reference_name = None )
61+ it = it + 1
62+ if (it == 10 ):
63+ break
3364
34- for i_batch , sample_batched in enumerate ( dataloader ) :
35- print ( i_batch , sample_batched [ 'image' ]. size (),
36- sample_batched [ 'label' ]. size ())
65+ if __name__ == "__main__" :
66+ main ()
67+
3768
38- # # observe 4th batch and stop.
39- modals = ['flair' , 't1ce' , 't1' , 't2' ]
40- if i_batch == 0 :
41- image = sample_batched ['image' ].numpy ()
42- label = sample_batched ['label' ].numpy ()
43- for i in range (image .shape [0 ]):
44- for mod in range (4 ):
45- image_i = image [i ][mod ]
46- label_i = label [i ][0 ]
47- image_name = "temp/image_{0:}_{1:}.nii.gz" .format (i , modals [mod ])
48- label_name = "temp/label_{0:}.nii.gz" .format (i )
49- save_array_as_nifty_volume (image_i , image_name , reference_name = None )
50- save_array_as_nifty_volume (label_i , label_name , reference_name = None )
0 commit comments