
Commit 4fe1cb9

fix runtime error with one gpu (#53)

`device_map="balanced_low_0"` requires more than one GPU and raised a runtime error on single-GPU machines; the script now defaults to `device_map="auto"` and only switches to `balanced_low_0` when the distributed world size is greater than 1.

1 parent 6b9b96f

File tree

1 file changed: +13, -1 lines

bloom-inference-scripts/bloom-accelerate-inference.py

Lines changed: 13 additions & 1 deletion
@@ -5,6 +5,7 @@
 import time
 
 import torch
+import torch.distributed as dist
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
@@ -57,9 +58,20 @@ def print_rank0(*msg):
     dtype = torch.int8
 
 kwargs = dict(
-    device_map="balanced_low_0",
+    device_map="auto",
 )
 
+def get_world_size() -> int:
+    if dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+# balanced_low_0 - because it allows a larger batch size with multiple GPUs
+if get_world_size() > 1:
+    kwargs["device_map"] = "balanced_low_0"
+
+
 if infer_dtype == "int8":
     print_rank0("Using `load_in_8bit=True` to use quantized model")
     kwargs["load_in_8bit"] = True
