|
109 | 109 | " Invertd,\n", |
110 | 110 | " Activationsd,\n", |
111 | 111 | " AsDiscreted,\n", |
112 | | - " Compose\n", |
| 112 | + " Compose,\n", |
113 | 113 | ")\n", |
114 | 114 | "from monai.inferers import sliding_window_inference\n", |
115 | 115 | "from monai.networks.nets import SegResNet\n", |
|
316 | 316 | "\n", |
317 | 317 | " return infer_transforms\n", |
318 | 318 | "\n", |
| 319 | + "\n", |
319 | 320 | "def get_post_transforms(infer_transforms):\n", |
320 | 321 | " post_transforms = Compose(\n", |
321 | 322 | " [\n", |
|
332 | 333 | " )\n", |
333 | 334 | " return post_transforms\n", |
334 | 335 | "\n", |
| 336 | + "\n", |
335 | 337 | "def get_model(device, weights_path, trt_model_path, trt_flag=False):\n", |
336 | 338 | " if not trt_flag:\n", |
337 | 339 | " model = SegResNet(\n", |
|
376 | 378 | " data = infer_transforms({\"image\": sample})\n", |
377 | 379 | "\n", |
378 | 380 | " with torch.no_grad():\n", |
379 | | - " input_image = data[\"image\"].unsqueeze(0).to(device) if benchmark_type in [\"trt\", \"original\"] else data[\"image\"].unsqueeze(0)\n", |
| 381 | + " input_image = (\n", |
| 382 | + " data[\"image\"].unsqueeze(0).to(device)\n", |
| 383 | + " if benchmark_type in [\"trt\", \"original\"]\n", |
| 384 | + " else data[\"image\"].unsqueeze(0)\n", |
| 385 | + " )\n", |
380 | 386 | " if benchmark_type == \"original\":\n", |
381 | 387 | " with torch.autocast(device_type=\"cuda\"):\n", |
382 | 388 | " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n", |
383 | 389 | " else:\n", |
384 | 390 | " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n", |
385 | | - " \n", |
| 391 | + "\n", |
386 | 392 | " data[\"pred\"] = output_image.squeeze(0)\n", |
387 | 393 | " # data = post_transforms(data)\n", |
388 | | - " \n", |
| 394 | + "\n", |
389 | 395 | " end = timer()\n", |
390 | 396 | "\n", |
391 | 397 | " sample_name = sample.split(\"/\")[-1]\n", |
|
0 commit comments