diff --git a/aiu_fms_testing_utils/scripts/inference.py b/aiu_fms_testing_utils/scripts/inference.py
index 3ec33f0e..b66b3b4c 100644
--- a/aiu_fms_testing_utils/scripts/inference.py
+++ b/aiu_fms_testing_utils/scripts/inference.py
@@ -257,6 +257,12 @@
     default=0,
     help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group",
 )
+parser.add_argument(
+    "--numa",
+    action="store_true",
+    help="NUMA aware task distribution (requires distributed option)",
+)
+
 args = parser.parse_args()
 
 attention_map = {
@@ -327,6 +333,22 @@
     dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
+    if args.numa:
+        try:
+            from numa import info, memory, schedule
+
+            # Evenly partition the distributed ranks across the configured NUMA nodes
+            numa_num_nodes = info.get_num_configured_nodes()
+            numa_world_size = dist.get_world_size()
+            numa_size_per_node = numa_world_size // numa_num_nodes
+            numa_rank = dist.get_rank()
+            numa_node = numa_rank // numa_size_per_node
+            # Pin this process to its NUMA node and prefer node-local memory allocation
+            schedule.run_on_nodes(numa_node)
+            memory.set_local_alloc()
+            dprint(f"NUMA: process {numa_rank} set to node {numa_node}")
+        except ImportError:
+            dprint("NUMA is not available on this machine; please install the libnuma libraries")
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
 
 if args.device_type == "cuda":