@@ -37,17 +37,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -67,15 +67,15 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-
+
     return filtered_dataset
-
+
 
 
 def sample_sharegpt_requests(
@@ -86,7 +86,7 @@ def sample_sharegpt_requests(
     prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
-    if not os.path.exists(dataset_path):
+    if not dataset_path or not os.path.exists(dataset_path):
         print("downloading share-gpt dataset as it does not exist")
         __download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path)
 
@@ -96,15 +96,15 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
+
     return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -113,10 +113,10 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
+
+
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
     return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+
 
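
Note on the one behavioral change in this diff: the new guard in `sample_sharegpt_requests` short-circuits before touching the filesystem, so an empty or `None` `dataset_path` now falls through to the download branch instead of `os.path.exists(None)` raising a `TypeError`. A minimal self-contained sketch of the same predicate (the helper name `needs_download` is invented here for illustration and is not part of the diff):

```python
import os

def needs_download(dataset_path):
    # Mirrors the new guard: a falsy path (None or "") or a missing file
    # both trigger a fresh download of the dataset.
    return not dataset_path or not os.path.exists(dataset_path)

assert needs_download(None)          # os.path.exists(None) would raise TypeError
assert needs_download("")            # empty path now downloads instead of silently failing the check
assert not needs_download(__file__)  # an existing file skips the download
```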