2020import torch_tensorrt
2121from monai .inferers import sliding_window_inference
2222from monai .networks .nets import SegResNet
23- from monai .transforms import (Activationsd , AsDiscreted , Compose ,
24- EnsureChannelFirstd , EnsureTyped , Invertd ,
25- LoadImaged , NormalizeIntensityd , Orientationd ,
26- ScaleIntensityd , Spacingd )
27-
28- from utils import (prepare_model_weights , prepare_tensorrt_model ,
29- prepare_test_datalist )
23+ from monai .transforms import (
24+ Activationsd ,
25+ AsDiscreted ,
26+ Compose ,
27+ EnsureChannelFirstd ,
28+ EnsureTyped ,
29+ Invertd ,
30+ LoadImaged ,
31+ NormalizeIntensityd ,
32+ Orientationd ,
33+ ScaleIntensityd ,
34+ Spacingd ,
35+ )
36+
37+ from utils import prepare_model_weights , prepare_tensorrt_model , prepare_test_datalist
3038
3139
3240def get_transforms (device , gpu_loading_flag = False , gpu_transforms_flag = False ):
@@ -49,6 +57,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
4957
5058 return infer_transforms
5159
60+
5261def get_post_transforms (infer_transforms ):
5362 post_transforms = Compose (
5463 [
@@ -65,6 +74,7 @@ def get_post_transforms(infer_transforms):
6574 )
6675 return post_transforms
6776
77+
6878def get_model (device , weights_path , trt_model_path , trt_flag = False ):
6979 if not trt_flag :
7080 model = SegResNet (
@@ -84,11 +94,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
8494 model = torch .jit .load (trt_model_path )
8595 return model
8696
97+
8798def run_inference (data_list , infer_transforms , model , device , benchmark_type ):
8899 total_time_dict = {}
89100 roi_size = (96 , 96 , 96 )
90101 sw_batch_size = 1
91-
102+
92103 for idx , sample in enumerate (data_list ):
93104 start = timer ()
94105 data = infer_transforms ({"image" : sample })
@@ -117,6 +128,7 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
117128
118129 return total_time_dict
119130
131+
120132def main ():
121133 parser = argparse .ArgumentParser (description = "Run inference benchmark." )
122134 parser .add_argument ("--benchmark_type" , type = str , default = "original" , help = "Type of benchmark to run" )
@@ -146,5 +158,6 @@ def main():
146158 df = pd .DataFrame (list (total_time_dict .items ()), columns = ["file_name" , "time" ])
147159 df .to_csv (os .path .join (root_dir , f"time_{ args .benchmark_type } .csv" ), index = False )
148160
161+
149162if __name__ == "__main__" :
150163 main ()
0 commit comments