@@ -49,6 +49,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
 
     return infer_transforms
 
+
 def get_post_transforms(infer_transforms):
     post_transforms = Compose(
         [
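
A note on the gpu_loading_flag / gpu_transforms_flag pair in the hunk header above: the usual MONAI pattern is to move the loaded image onto the device early so the rest of the chain runs on the GPU. A minimal sketch of that pattern, with an illustrative transform list rather than this file's actual one:

    # Sketch only: the real chain lives in get_transforms; the transforms
    # and intensity range below are assumptions for illustration.
    import torch
    from monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd, ScaleIntensityRanged

    def make_transforms(device: torch.device, gpu_transforms: bool = False) -> Compose:
        chain = [LoadImaged(keys="image", ensure_channel_first=True)]
        if gpu_transforms:
            # Move the image to the GPU right after loading so every
            # transform below this point executes on the device.
            chain.append(EnsureTyped(keys="image", device=device, track_meta=True))
        chain += [
            Orientationd(keys="image", axcodes="RAS"),
            ScaleIntensityRanged(keys="image", a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True),
        ]
        return Compose(chain)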
@@ -65,6 +66,7 @@ def get_post_transforms(infer_transforms):
     )
     return post_transforms
 
+
 def get_model(device, weights_path, trt_model_path, trt_flag=False):
     if not trt_flag:
         model = SegResNet(
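
get_post_transforms takes infer_transforms as an argument, which in MONAI usually means the preprocessing is inverted on the prediction before saving or scoring. A sketch of that common pattern; the keys and arguments are assumptions, not this repository's exact code:

    from monai.transforms import AsDiscreted, Compose, Invertd

    def make_post_transforms(infer_transforms: Compose) -> Compose:
        return Compose([
            # Undo resampling/orientation on the prediction using metadata
            # recorded by infer_transforms during preprocessing.
            Invertd(keys="pred", transform=infer_transforms, orig_keys="image",
                    nearest_interp=False, to_tensor=True),
            # Collapse channel probabilities to discrete labels.
            AsDiscreted(keys="pred", argmax=True),
        ])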
@@ -84,11 +86,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
         model = torch.jit.load(trt_model_path)
     return model
 
+
 def run_inference(data_list, infer_transforms, model, device, benchmark_type):
     total_time_dict = {}
     roi_size = (96, 96, 96)
-    sw_batch_size = 1
-
+    sw_batch_size = 4
+
     for idx, sample in enumerate(data_list):
         start = timer()
         data = infer_transforms({"image": sample})
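
Raising sw_batch_size from 1 to 4 lets four 96x96x96 patches go through the network per forward pass instead of one, trading peak GPU memory for throughput. roi_size and sw_batch_size are the standard arguments of MONAI's sliding-window inferer; a minimal sketch, assuming `image` is a channel-first tensor already on the device and `model` is the loaded network:

    # Sketch of how roi_size/sw_batch_size are typically consumed;
    # `image` and `model` are assumed inputs, not defined in this diff.
    import torch
    from monai.inferers import sliding_window_inference

    roi_size = (96, 96, 96)
    sw_batch_size = 4  # patches evaluated per forward pass

    with torch.no_grad():
        # Tile the volume into 96^3 windows, run them through `model`
        # four at a time, and stitch the overlapping outputs back together.
        pred = sliding_window_inference(image.unsqueeze(0), roi_size, sw_batch_size, model)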
@@ -114,9 +117,10 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
         sample_name = sample.split("/")[-1]
         if idx > 0:
             total_time_dict[sample_name] = end - start
-
+        print(f"Time taken for {sample_name}: {end - start} seconds")
     return total_time_dict
 
+
 def main():
     parser = argparse.ArgumentParser(description="Run inference benchmark.")
     parser.add_argument("--benchmark_type", type=str, default="original", help="Type of benchmark to run")
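
One caveat on the timer-based measurement above: CUDA launches kernels asynchronously, so a host-side timer can stop before the GPU has finished. If per-sample numbers look implausibly low, synchronizing around the timed region is the standard fix. A sketch, where run_one_sample is a hypothetical stand-in for the per-sample transform plus inference:

    import torch
    from timeit import default_timer as timer

    def timed_gpu(run_one_sample, *args):
        # `run_one_sample` is hypothetical; it stands for transform + inference.
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # drain queued work before starting the clock
        start = timer()
        out = run_one_sample(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for async kernels before stopping it
        return out, timer() - start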
@@ -128,8 +132,8 @@ def main():
     torch_tensorrt.runtime.set_multi_device_safe_mode(True)
     device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     train_files = prepare_test_datalist(root_dir)
-    # since the dataset is too large, the smallest 21 files are used for warm up (1 file) and benchmarking (11 files)
-    train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:21]
+    # since the dataset is too large, the smallest 31 files are used for warm-up (1 file) and benchmarking (30 files)
+    train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:31]
     weights_path = prepare_model_weights(root_dir=root_dir, bundle_name="wholeBody_ct_segmentation")
     trt_model_name = "model_trt.ts"
     trt_model_path = prepare_tensorrt_model(root_dir, weights_path, trt_model_name)
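
The [:31] slice pairs with the `if idx > 0` guard in run_inference: index 0 is the warm-up file whose time is discarded, and the remaining 30 are recorded. Equivalently, as a sketch:

    import os

    def split_warmup(paths, n_total=31):
        # Smallest files first; index 0 warms up the GPU and caches, the rest are timed.
        picked = sorted(paths, key=os.path.getsize)[:n_total]
        return picked[0], picked[1:]  # (1 warm-up file, 30 benchmark files)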
@@ -146,5 +150,6 @@ def main():
     df = pd.DataFrame(list(total_time_dict.items()), columns=["file_name", "time"])
     df.to_csv(os.path.join(root_dir, f"time_{args.benchmark_type}.csv"), index=False)
 
+
 if __name__ == "__main__":
     main()
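
Since every run writes time_<benchmark_type>.csv, comparing two runs afterwards takes only a few lines of pandas. A sketch, assuming the script was run twice with --benchmark_type set to "original" and "trt" (the second name is hypothetical):

    import pandas as pd

    orig = pd.read_csv("time_original.csv")
    trt = pd.read_csv("time_trt.csv")  # file name assumes a run with --benchmark_type trt
    print(f"original: {orig['time'].mean():.3f}s  trt: {trt['time'].mean():.3f}s  "
          f"speedup: {orig['time'].mean() / trt['time'].mean():.2f}x")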