Skip to content

Commit b18373e

Browse files
committed
🐛 Fix default for sharegpt dataset path
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent adda790 commit b18373e

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ def __download_file(url, filename):
3737
try:
3838
response = requests.get(url, stream=True)
3939
response.raise_for_status()
40-
40+
4141
with open(filename, 'wb') as file:
4242
for chunk in response.iter_content(chunk_size=8192):
4343
file.write(chunk)
4444
print(f"Successfully downloaded {filename}")
45-
45+
4646
except requests.exceptions.RequestException as e:
4747
print(f"An error occurred: {e}")
4848

4949
def __sample_requests(
50-
prompt_list: List[str],
50+
prompt_list: List[str],
5151
num_requests: int,
5252
tokenizer: BaseTokenizer,
5353
prompt_length_min: int = 32,
@@ -67,15 +67,15 @@ def __sample_requests(
6767
# Tokenize the prompts and completions.
6868
prompt = prompt_list[i]
6969
prompt_token_ids = ids_for_prompt(prompt, tokenizer)
70-
70+
7171
prompt_len = len(prompt_token_ids)
7272
if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
7373
# Prune too short or too long sequences.
7474
continue
7575
filtered_dataset.append((prompt, prompt_len))
76-
76+
7777
return filtered_dataset
78-
78+
7979

8080

8181
def sample_sharegpt_requests(
@@ -86,7 +86,7 @@ def sample_sharegpt_requests(
8686
prompt_length_max: int = 64,
8787
seed: Optional[int] = None
8888
) -> List[Tuple[str, int]]:
89-
if not os.path.exists(dataset_path):
89+
if not dataset_path or not os.path.exists(dataset_path):
9090
print("downloading share-gpt dataset as it does not exist")
9191
__download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path)
9292

@@ -96,15 +96,15 @@ def sample_sharegpt_requests(
9696
# Filter out the conversations with less than 2 turns.
9797
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
9898
dataset = [data["conversations"][0]["value"] for data in dataset]
99-
99+
100100
return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
101101

102102
def sample_squad_v2_qa_requests(
103103
dataset_path: str,
104-
num_requests: int,
105-
tokenizer: BaseTokenizer,
106-
prompt_length_min: int = 32,
107-
prompt_length_max: int = 64,
104+
num_requests: int,
105+
tokenizer: BaseTokenizer,
106+
prompt_length_min: int = 32,
107+
prompt_length_max: int = 64,
108108
seed: Optional[int] = None
109109
) -> List[Tuple[str, int]]:
110110
from datasets import load_dataset
@@ -113,10 +113,10 @@ def sample_squad_v2_qa_requests(
113113
ds = load_dataset(dataset_path)['train']
114114
else:
115115
ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
116-
117-
116+
117+
118118
ds = [f"{data['context']}\n{data['question']}" for data in ds]
119119

120120
return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
121-
121+
122122

0 commit comments

Comments
 (0)