@@ -1,25 +1,64 @@
-import torch
-import torch.nn as nn
-import time
-from fms.utils.tokenizers import BaseTokenizer
-from fms.utils.generation import generate
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+# Standard
 from typing import Optional, List, Tuple
-import os
-import requests
 import json
+import os
 import random
+import requests
+import time
+
+# Third Party
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from fms.utils.tokenizers import BaseTokenizer
+import torch
+import torch.nn as nn
 
-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn=False, **padding_kwargs):
+
+def warmup_model(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    compile_dynamic_sendnn: bool = False,
+    use_cache: bool = True,
+    **extra_kwargs
+):
     import torch_sendnn
+
+    attention_specific_kwargs = {}
+    attn_name = extra_kwargs["attn_name"]
+    if "paged" in attn_name:
+        from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
+    else:
+        # TODO: Add a unified generation dependent on attn_type
+        from fms.utils.generation import generate
+        attention_specific_kwargs["contiguous_cache"] = True
+
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
-    extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    max_new_tokens_warmup = max_new_tokens
+
+    # adjust inputs depending on attn_type and dynamic shapes
+    _warmup_input_ids = input_ids
+    _extra_kwargs = extra_kwargs
+    _max_new_tokens = max_new_tokens
     if compile_dynamic_sendnn:
-        max_new_tokens_warmup = 2
+        _max_new_tokens = 2
+        # always warmup with batch size 2 when using attn_type=paged
+        if "paged" in attn_name:
+            _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
+                input_ids,
+                **extra_kwargs,
+            )
+
+    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
+
     with torch_sendnn.warmup_mode():
-        generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
+        generate(
+            model,
+            _warmup_input_ids,
+            max_new_tokens=_max_new_tokens,
+            do_sample=False,
+            use_cache=use_cache,
+            extra_kwargs=extra_kwargs,
+            **attention_specific_kwargs,
+        )
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
@@ -35,17 +74,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -65,16 +104,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -94,15 +131,22 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
-    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -111,10 +155,14 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
-    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
+    return __sample_requests(
+        ds,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
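
For orientation, below is a minimal call sketch of the refactored `warmup_model`, assuming a compiled FMS `model` and a padded `input_ids` batch have already been built elsewhere (e.g. via `ids_for_prompt` in this module). The `attn_name` strings are illustrative assumptions; the implementation only checks whether `"paged"` appears in the name.

```python
# Minimal sketch; `model` and `input_ids` are assumed to exist already.
# attn_name values are illustrative; only the substring "paged" matters.

# Non-paged attention: dispatches to fms.utils.generation.generate with
# contiguous_cache=True and only_last_token=True during warmup.
warmup_model(
    model,
    input_ids,
    max_new_tokens=16,
    compile_dynamic_sendnn=True,  # warmup then generates only 2 tokens
    attn_name="sdpa_causal",      # assumed name for a non-paged attention kernel
)

# Paged attention: dispatches to aiu_fms_testing_utils.utils.paged.generate and,
# under dynamic-shape compilation, expands the warmup inputs to batch size 2.
warmup_model(
    model,
    input_ids,
    max_new_tokens=16,
    compile_dynamic_sendnn=True,
    attn_name="paged",
)
```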