diff --git a/beginner_source/ddp_series_multigpu.rst b/beginner_source/ddp_series_multigpu.rst
index ef6549d4de..f0f67e5930 100644
--- a/beginner_source/ddp_series_multigpu.rst
+++ b/beginner_source/ddp_series_multigpu.rst
@@ -202,6 +202,7 @@ Running the distributed training job
 Here's what the code looks like:

 .. code-block:: python
+
    def main(rank, world_size, total_epochs, save_every):
        ddp_setup(rank, world_size)
        dataset, model, optimizer = load_train_objs()
@@ -218,7 +219,6 @@ Here's what the code looks like:
    mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)

-
 Further Reading
 ---------------