|
72 | 72 | "!python -c \"import matplotlib\" || pip install -q matplotlib\n", |
73 | 73 | "!python -c \"import torch_tensorrt\" || pip install torch_tensorrt\n", |
74 | 74 | "!python -c \"import kvikio\" || pip install kvikio-cu12\n", |
75 | | - "!python -c \"import ignite\" || pip install pytorch-ignite\n", |
76 | 75 | "!python -c \"import pandas\" || pip install pandas\n", |
77 | 76 | "!python -c \"import requests\" || pip install requests\n", |
78 | | - "!python -c \"import fire\" || pip install fire\n", |
79 | 77 | "!python -c \"import onnx\" || pip install onnx\n", |
80 | 78 | "%matplotlib inline" |
81 | 79 | ] |
|
106 | 104 | " Spacingd,\n", |
107 | 105 | " NormalizeIntensityd,\n", |
108 | 106 | " ScaleIntensityd,\n", |
109 | | - " Invertd,\n", |
110 | | - " Activationsd,\n", |
111 | | - " AsDiscreted,\n", |
112 | 107 | " Compose,\n", |
113 | 108 | ")\n", |
114 | 109 | "from monai.inferers import sliding_window_inference\n", |
115 | 110 | "from monai.networks.nets import SegResNet\n", |
| 111 | + "import matplotlib.pyplot as plt\n", |
116 | 112 | "import torch\n", |
| 113 | + "import gc\n", |
117 | 114 | "import pandas as pd\n", |
118 | 115 | "from timeit import default_timer as timer\n", |
119 | 116 | "\n", |
120 | | - "print(f\"Torch-TensorRT version: {torch_tensorrt.__version__}.\")\n", |
121 | | - "\n", |
122 | 117 | "print_config()" |
123 | 118 | ] |
124 | 119 | }, |
|
163 | 158 | " precision=\"fp16\",\n", |
164 | 159 | " input_shape=[1, 1, 96, 96, 96],\n", |
165 | 160 | " dynamic_batchsize=[1, 1, 1],\n", |
166 | | - " use_trace=False,\n", |
167 | | - " verify=True,\n", |
| 161 | + " use_trace=True,\n", |
| 162 | + " verify=False,\n", |
168 | 163 | ")\n", |
169 | 164 | "\n", |
170 | 165 | "save_net_with_metadata(torchscript_model, \"segresnet_trt\")\n", |
|
236 | 231 | "\n", |
237 | 232 | "A variable `benchmark_type` is used to specify the type of benchmark to run. To have a fair comparison, each benchmark type should be run after restarting the notebook kernel. `benchmark_type` can be one of the following:\n", |
238 | 233 | "\n", |
239 | | - "- `\"original\"`: benchmark the original model inference (with `amp` enabled).\n", |
| 234 | + "- `\"original\"`: benchmark the original model inference.\n", |
240 | 235 | "- `\"trt\"`: benchmark the TensorRT accelerated model inference.\n", |
241 | | - "- `\"trt_gpu_transforms\"`: benchmark the TensorRT accelerated model inference with GPU transforms.\n", |
242 | | - "- `\"trt_gds_gpu_transforms\"`: benchmark the TensorRT accelerated model inference with GPU data loading and GPU transforms." |
| 236 | + "- `\"trt_gpu_transforms\"`: benchmark the model inference with GPU transforms.\n", |
| 237 | + "- `\"trt_gds_gpu_transforms\"`: benchmark the model inference with GPU data loading and GPU transforms." |
243 | 238 | ] |
244 | 239 | }, |
245 | 240 | { |
246 | 241 | "cell_type": "code", |
247 | | - "execution_count": 3, |
| 242 | + "execution_count": 4, |
248 | 243 | "metadata": {}, |
249 | 244 | "outputs": [], |
250 | 245 | "source": [ |
|
276 | 271 | "from utils import prepare_test_datalist, prepare_model_weights, prepare_tensorrt_model\n", |
277 | 272 | "\n", |
278 | 273 | "root_dir = \".\"\n", |
| 274 | + "torch.backends.cudnn.benchmark = True\n", |
| 275 | + "torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n", |
279 | 276 | "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", |
280 | 277 | "train_files = prepare_test_datalist(root_dir)\n", |
| 278 | + "# since the dataset is too large, the smallest 21 files are used for warm up (1 file) and benchmarking (11 files)\n", |
| 279 | + "train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:21]\n", |
281 | 280 | "weights_path = prepare_model_weights(root_dir=root_dir, bundle_name=\"wholeBody_ct_segmentation\")\n", |
282 | 281 | "trt_model_name = \"model_trt.ts\"\n", |
283 | 282 | "trt_model_path = prepare_tensorrt_model(root_dir, weights_path, trt_model_name)" |
|
292 | 291 | }, |
293 | 292 | { |
294 | 293 | "cell_type": "code", |
295 | | - "execution_count": 5, |
| 294 | + "execution_count": 6, |
296 | 295 | "metadata": {}, |
297 | 296 | "outputs": [], |
298 | 297 | "source": [ |
|
317 | 316 | " return infer_transforms\n", |
318 | 317 | "\n", |
319 | 318 | "\n", |
320 | | - "def get_post_transforms(infer_transforms):\n", |
321 | | - " post_transforms = Compose(\n", |
322 | | - " [\n", |
323 | | - " Activationsd(keys=\"pred\", softmax=True),\n", |
324 | | - " AsDiscreted(keys=\"pred\", argmax=True),\n", |
325 | | - " Invertd(\n", |
326 | | - " keys=\"pred\",\n", |
327 | | - " transform=infer_transforms,\n", |
328 | | - " orig_keys=\"image\",\n", |
329 | | - " nearest_interp=True,\n", |
330 | | - " to_tensor=True,\n", |
331 | | - " ),\n", |
332 | | - " ]\n", |
333 | | - " )\n", |
334 | | - " return post_transforms\n", |
335 | | - "\n", |
336 | | - "\n", |
337 | 319 | "def get_model(device, weights_path, trt_model_path, trt_flag=False):\n", |
338 | 320 | " if not trt_flag:\n", |
339 | 321 | " model = SegResNet(\n", |
|
364 | 346 | }, |
365 | 347 | { |
366 | 348 | "cell_type": "code", |
367 | | - "execution_count": 6, |
| 349 | + "execution_count": 7, |
368 | 350 | "metadata": {}, |
369 | 351 | "outputs": [], |
370 | 352 | "source": [ |
371 | | - "def run_inference(data_list, infer_transforms, post_transforms, model, device, benchmark_type):\n", |
| 353 | + "def run_inference(data_list, infer_transforms, model, device, benchmark_type):\n", |
372 | 354 | " total_time_dict = {}\n", |
373 | 355 | " roi_size = (96, 96, 96)\n", |
374 | 356 | " sw_batch_size = 1\n", |
375 | | - "\n", |
376 | | - " for idx, sample in enumerate(data_list[:5]):\n", |
| 357 | + " \n", |
| 358 | + " for idx, sample in enumerate(data_list[:10]):\n", |
377 | 359 | " start = timer()\n", |
378 | 360 | " data = infer_transforms({\"image\": sample})\n", |
379 | 361 | "\n", |
|
383 | 365 | " if benchmark_type in [\"trt\", \"original\"]\n", |
384 | 366 | " else data[\"image\"].unsqueeze(0)\n", |
385 | 367 | " )\n", |
386 | | - " if benchmark_type == \"original\":\n", |
387 | | - " with torch.autocast(device_type=\"cuda\"):\n", |
388 | | - " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n", |
389 | | - " else:\n", |
390 | | - " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n", |
391 | 368 | "\n", |
392 | | - " data[\"pred\"] = output_image.squeeze(0)\n", |
393 | | - " # data = post_transforms(data)\n", |
| 369 | + " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n", |
| 370 | + " output_image = output_image.cpu()\n", |
394 | 371 | "\n", |
395 | 372 | " end = timer()\n", |
396 | 373 | "\n", |
| 374 | + " print(output_image.mean())\n", |
| 375 | + "\n", |
| 376 | + " del data\n", |
| 377 | + " del input_image\n", |
| 378 | + " del output_image\n", |
| 379 | + " torch.cuda.empty_cache()\n", |
| 380 | + " gc.collect()\n", |
| 381 | + "\n", |
397 | 382 | " sample_name = sample.split(\"/\")[-1]\n", |
398 | 383 | " if idx > 0:\n", |
399 | 384 | " total_time_dict[sample_name] = end - start\n", |
400 | | - "\n", |
| 385 | + " print(end - start)\n", |
401 | 386 | " return total_time_dict" |
402 | 387 | ] |
403 | 388 | }, |
404 | 389 | { |
405 | 390 | "cell_type": "markdown", |
406 | 391 | "metadata": {}, |
407 | 392 | "source": [ |
408 | | - "## Benchmark the end-to-end bundle inference" |
| 393 | + "### Running the Benchmark\n", |
| 394 | + "\n", |
| 395 | + "The cell below will execute the benchmark based on the `benchmark_type` variable.\n", |
| 396 | + "\n", |
| 397 | + "#### Optional: Using the Python Script\n", |
| 398 | + "\n", |
| 399 | + "For convenience, a Python script, [`run_benchmark.py`](./run_benchmark.py), is available to run the benchmark. You can open a terminal and execute the following command to run the benchmark for all benchmark types:\n", |
| 400 | + "\n", |
| 401 | + "\n", |
| 402 | + "```bash\n", |
| 403 | + "for benchmark_type in \"original\" \"trt\" \"trt_gpu_transforms\" \"trt_gds_gpu_transforms\"; do\n", |
| 404 | + " python run_benchmark.py --benchmark_type \"$benchmark_type\"\n", |
| 405 | + "done\n", |
| 406 | + "```" |
409 | 407 | ] |
410 | 408 | }, |
411 | 409 | { |
|
426 | 424 | " gpu_loading_flag = True\n", |
427 | 425 | "\n", |
428 | 426 | "infer_transforms = get_transforms(device, gpu_loading_flag, gpu_transforms_flag)\n", |
429 | | - "post_transforms = get_post_transforms(infer_transforms)\n", |
430 | 427 | "model = get_model(device, weights_path, trt_model_path, trt_flag)\n", |
431 | 428 | "\n", |
432 | | - "total_time_dict = run_inference(train_files, infer_transforms, post_transforms, model, device, benchmark_type)" |
| 429 | + "total_time_dict = run_inference(train_files, infer_transforms, model, device, benchmark_type)\n", |
| 430 | + "\n", |
| 431 | + "df = pd.DataFrame(list(total_time_dict.items()), columns=[\"file_name\", \"time\"])\n", |
| 432 | + "df.to_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"), index=False)" |
| 433 | + ] |
| 434 | + }, |
| 435 | + { |
| 436 | + "cell_type": "markdown", |
| 437 | + "metadata": {}, |
| 438 | + "source": [ |
| 439 | + "## Analyze and Visualize the Results\n", |
| 440 | + "\n", |
| 441 | + "In this section, we will analyze and visualize the results.\n", |
| 442 | + "All cell outputs presented in this section were obtained by a NVIDIA RTX A6000 GPU." |
| 443 | + ] |
| 444 | + }, |
| 445 | + { |
| 446 | + "cell_type": "markdown", |
| 447 | + "metadata": {}, |
| 448 | + "source": [ |
| 449 | + "### Collect Benchmark Results" |
433 | 450 | ] |
434 | 451 | }, |
435 | 452 | { |
436 | 453 | "cell_type": "code", |
437 | | - "execution_count": 8, |
| 454 | + "execution_count": 18, |
438 | 455 | "metadata": {}, |
439 | 456 | "outputs": [], |
440 | 457 | "source": [ |
441 | | - "df = pd.DataFrame(list(total_time_dict.items()), columns=[\"file_name\", \"time\"])\n", |
442 | | - "df.to_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"), index=False)" |
| 458 | + "# collect benchmark results\n", |
| 459 | + "all_df = pd.read_csv(os.path.join(root_dir, f\"time_original.csv\"))\n", |
| 460 | + "all_df.columns = [\"file_name\", \"original_time\"]\n", |
| 461 | + "\n", |
| 462 | + "for benchmark_type in [\"trt\", \"trt_gpu_transforms\", \"trt_gds_gpu_transforms\"]:\n", |
| 463 | + " df = pd.read_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"))\n", |
| 464 | + " df.columns = [\"file_name\", f\"{benchmark_type}_time\"]\n", |
| 465 | + " all_df = pd.merge(all_df, df, on=\"file_name\", how=\"left\")\n", |
| 466 | + "\n", |
| 467 | + "# for each file, add it's size\n", |
| 468 | + "all_df[\"file_size\"] = all_df[\"file_name\"].apply(lambda x: os.path.getsize(os.path.join(root_dir, \"Task03_Liver\", \"imagesTs_nii\", x)))\n", |
| 469 | + "# sort by file size\n", |
| 470 | + "all_df = all_df.sort_values(by=\"file_size\", ascending=True)\n", |
| 471 | + "# convert file size to MB\n", |
| 472 | + "all_df[\"file_size\"] = all_df[\"file_size\"] / 1024 / 1024\n", |
| 473 | + "# get the average time for each benchmark type\n", |
| 474 | + "average_time = all_df.mean(numeric_only=True)\n", |
| 475 | + "del average_time[\"file_size\"]" |
| 476 | + ] |
| 477 | + }, |
| 478 | + { |
| 479 | + "cell_type": "markdown", |
| 480 | + "metadata": {}, |
| 481 | + "source": [ |
| 482 | + "### Visualize Average Inference Time for Each Benchmark Type" |
| 483 | + ] |
| 484 | + }, |
| 485 | + { |
| 486 | + "cell_type": "code", |
| 487 | + "execution_count": null, |
| 488 | + "metadata": {}, |
| 489 | + "outputs": [], |
| 490 | + "source": [ |
| 491 | + "plt.figure(figsize=(10, 6))\n", |
| 492 | + "average_time.plot(kind='bar', color=['skyblue', 'orange', 'green', 'red'])\n", |
| 493 | + "plt.title('Average Inference Time for Each Benchmark Type')\n", |
| 494 | + "plt.xlabel('Benchmark Type')\n", |
| 495 | + "plt.ylabel('Average Time (seconds)')\n", |
| 496 | + "plt.xticks(rotation=45)\n", |
| 497 | + "plt.tight_layout()\n", |
| 498 | + "plt.show()" |
443 | 499 | ] |
444 | 507 | } |
445 | 508 | ], |
446 | 509 | "metadata": { |
|