
Commit 75fe897

fixed merge conflicts

Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>

2 parents: 462122b + 1a77f63

22 files changed (+2237 −328 lines)

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@ aiu-fms-testing-utils.egg-info
 */**/*.pyc
 .vscode
 aiu-fms-testing-utils.egg-info
+export_deeprt
+export_dtcompiler
+*.egg-info

aiu_fms_testing_utils/testing/validation.py

Lines changed: 11 additions & 5 deletions
@@ -2,7 +2,6 @@
 from typing import List, Tuple, Callable, MutableMapping, Any, Optional
 
 import torch
-from fms.utils.generation import generate
 from aiu_fms_testing_utils.utils import ids_for_prompt
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
@@ -188,11 +187,19 @@ def load_validation_information(validation_path, validation_files_type, batch_si
 
     return ValidationInfo(validation_info)
 
-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **padding_kwargs):
+def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **extra_kwargs):
     max_seq_len = model.config.max_expected_seq_len
+    attention_specific_kwargs = {}
+    if "paged" in extra_kwargs["attn_name"]:
+        from aiu_fms_testing_utils.utils.paged import generate
+    else:
+        # TODO: Add a unified generation dependent on attn_type
+        from fms.utils.generation import generate
+        attention_specific_kwargs["contiguous_cache"] = True
+        attention_specific_kwargs["max_seq_len"] = max_seq_len
 
     # Add only_last_token optimization
-    extra_generation_kwargs = {**padding_kwargs}
+    extra_generation_kwargs = {**extra_kwargs}
     if only_last_token:
         extra_generation_kwargs["only_last_token"] = only_last_token
     if attn_algorithm is not None:
@@ -204,12 +211,11 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         max_new_tokens=max_new_tokens,
         use_cache=True,
         do_sample=False,
-        max_seq_len=max_seq_len,
         post_iteration_hook=post_iteration_hook,
         eos_token_id=eos_token_id,
         timing=timing,
-        contiguous_cache=True,
         extra_kwargs=extra_generation_kwargs,
+        **attention_specific_kwargs
     )
 
     if timing != "":
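
For reference, a minimal caller sketch (not part of this commit) of how the reworked extract_validation_information is driven: the attention name now travels through **extra_kwargs, and any value containing "paged" routes generation to aiu_fms_testing_utils.utils.paged.generate. The attn_name value, the model object, and the None hook below are illustrative assumptions, not repo API beyond what the diff itself shows.

    # Hypothetical caller sketch; values flagged below are assumptions.
    import torch
    from aiu_fms_testing_utils.testing.validation import extract_validation_information

    def collect_validation_info(model, prompt_ids: torch.Tensor):
        # "sdpa_causal" is an illustrative non-paged name; anything containing
        # "paged" would import the paged generate instead.
        return extract_validation_information(
            model,
            prompt_ids,
            max_new_tokens=8,
            post_iteration_hook=None,  # normally a logits-extraction hook from this module
            only_last_token=True,
            attn_name="sdpa_causal",   # collected into **extra_kwargs
        )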

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 79 additions & 31 deletions
@@ -1,25 +1,64 @@
-import torch
-import torch.nn as nn
-import time
-from fms.utils.tokenizers import BaseTokenizer
-from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+# Standard
 from typing import Optional, List, Tuple
-import os
-import requests
 import json
+import os
 import random
+import requests
+import time
+
+# Third Party
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from fms.utils.tokenizers import BaseTokenizer
+import torch
+import torch.nn as nn
 
-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, **padding_kwargs):
+
+def warmup_model(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    compile_dynamic_sendnn: bool = False,
+    use_cache: bool = True,
+    **extra_kwargs
+):
     import torch_sendnn
+    attention_specific_kwargs = {}
+    attn_name = extra_kwargs["attn_name"]
+    if "paged" in attn_name:
+        from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
+    else:
+        # TODO: Add a unified generation dependent on attn_type
+        from fms.utils.generation import generate
+        attention_specific_kwargs["contiguous_cache"] = True
+
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
-    extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    max_new_tokens_warmup = max_new_tokens
+
+    # adjust inputs depending on attn_type and dynamic shapes
+    _warmup_input_ids = input_ids
+    _extra_kwargs = extra_kwargs
+    _max_new_tokens = max_new_tokens
     if compile_dynamic_sendnn:
-        max_new_tokens_warmup = 2
+        _max_new_tokens = 2
+        # always warmup with batch size 2 when using attn_type=paged
+        if "paged" in attn_name:
+            _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
+                input_ids,
+                **extra_kwargs,
+            )
+
+    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
+
     with torch_sendnn.warmup_mode():
-        generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+        generate(
+            model,
+            _warmup_input_ids,
+            max_new_tokens=_max_new_tokens,
+            do_sample=False,
+            use_cache=use_cache,
+            extra_kwargs=extra_kwargs,
+            **attention_specific_kwargs,
+        )
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
@@ -35,17 +74,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str], 
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -65,16 +104,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
            continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -94,15 +131,22 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
-    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int, 
-    tokenizer: BaseTokenizer, 
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -111,10 +155,14 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
-    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
+    return __sample_requests(
+        ds,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
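
Similarly, a hypothetical warmup call (not part of this commit): warmup_model now requires an "attn_name" entry among its extra kwargs, warms up with only two decode steps when compile_dynamic_sendnn is set, and, for paged attention, adjusts the warmup inputs to batch size 2 via adjust_inputs_to_batch. The attention name and padding kwargs below are placeholders, and torch_sendnn must be importable at call time.

    # Hypothetical driver sketch under the assumptions stated above.
    import torch
    from aiu_fms_testing_utils.utils import warmup_model

    def warmup(model, prompt_ids: torch.Tensor, padding_kwargs: dict):
        warmup_model(
            model,
            prompt_ids,
            max_new_tokens=8,
            compile_dynamic_sendnn=True,   # warmup then runs only 2 decode steps
            use_cache=True,
            attn_name="spyre_paged_attn",  # placeholder; any name containing "paged" takes the paged path
            **padding_kwargs,              # e.g. mask/position kwargs from the padding helper
        )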

aiu_fms_testing_utils/utils/aiu_setup.py

Lines changed: 54 additions & 0 deletions
@@ -1,4 +1,6 @@
+import argparse
 import os
+import torch
 
 # ==============================================================
 # Common utilities
@@ -67,3 +69,55 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False
         dprint(f"Detected running via torchrun")
 
     aiu_setup(rank, world_size)
+
+
+# ==============================================================
+# Environment variables utilities
+# ==============================================================
+def set_aiu_env_vars(args: argparse.Namespace) -> None:
+    """Set necessary environment variables for AIU"""
+
+    if not args.compile_dynamic:
+        _target_cache_size = max(
+            int(args.max_new_tokens * 2),
+            int(args.min_pad_length * 2.5),
+            int(args.fixed_prompt_length * 2.5),
+        )
+        _prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length))
+        if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"):
+            if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit:
+                _prev = torch._dynamo.config.accumulated_cache_size_limit
+                torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size
+                dprint(
+                    "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit "
+                    f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} "
+                    f"to accomodate prompt size of {_prompt_size} and decode tokens of "
+                    f"{args.max_new_tokens}"
+                )
+
+        if _target_cache_size > torch._dynamo.config.cache_size_limit:
+            _prev = torch._dynamo.config.cache_size_limit
+            torch._dynamo.config.cache_size_limit = _target_cache_size
+            dprint(
+                f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to "
+                f"{torch._dynamo.config.cache_size_limit} to accomodate prompt size of "
+                f"{_prompt_size} and decode tokens of {args.max_new_tokens}"
+            )
+
+        torch._dynamo.config.assume_static_by_default = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+    # os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") # CONFIRM IF THIS IS NEEDE
+
+    if not args.is_encoder:
+        os.environ.setdefault("COMPILATION_MODE", "offline_decoder")
+
+    if args.device_type == "aiu-senulator":
+        os.environ["FLEX_COMPUTE"] = "SENULATOR"
+        os.environ["FLEX_DEVICE"] = "MOCK"
+    else:
+        if "AIU_WORLD_RANK_0" not in os.environ:
+            print("must set AIU_WORLD_RANK_0")
+            exit()
+        os.environ.setdefault("FLEX_COMPUTE", "SENTIENT")
+        os.environ.setdefault("FLEX_DEVICE", "PF")  # will use VF eventually
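
A usage sketch for the new helper (not part of this commit): set_aiu_env_vars only reads attributes from the parsed arguments, so a Namespace carrying the fields referenced in the diff is enough to exercise it. The field values below are illustrative defaults; the real scripts populate them via their argument parser.

    import argparse
    from aiu_fms_testing_utils.utils.aiu_setup import set_aiu_env_vars

    # Illustrative defaults; adjust to match the script's own argument parser.
    args = argparse.Namespace(
        compile_dynamic=False,        # False triggers the dynamo cache-size adjustments
        max_new_tokens=128,
        min_pad_length=64,
        fixed_prompt_length=0,
        is_encoder=False,             # decoder path sets COMPILATION_MODE=offline_decoder
        device_type="aiu-senulator",  # mock device; real AIU requires AIU_WORLD_RANK_0
    )
    set_aiu_env_vars(args)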
