Commit adda790

🔧 Add argument validations
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent 5c7c921 commit adda790

1 file changed: 23 additions, 6 deletions


scripts/generate_metrics.py

Lines changed: 23 additions & 6 deletions
@@ -27,6 +27,7 @@
     type=str,
     default=None,
     help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.",
+    required=True,
 )
 parser.add_argument(
     "--model_path",
@@ -37,65 +38,75 @@
     "--model_source",
     type=str,
     help="Source of the checkpoint. E.g. 'meta', 'hf', None",
+    required=False,
 )
 parser.add_argument(
     "--tokenizer",
     type=str,
-    required=True,
     help="Path to the tokenizer (e.g. ~/tokenizer.model)",
+    required=True,
 )
 parser.add_argument(
     "--default_dtype",
     type=str,
     default=None,
     choices=["bf16", "fp16", "fp32"],
     help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format",
+    required=False,
 )
 parser.add_argument(
     "--batch_size",
     type=int,
     default=1,
     help="size of input batch",
+    required=False,
 )
 parser.add_argument(
     "--min_pad_length",
     type=int,
     help="Pad inputs to a minimum specified length. If any prompt is larger than the specified length, padding will be determined by the largest prompt",
     default=0,
+    required=False,
 )
 parser.add_argument(
     "--max_new_tokens",
     type=int,
     help="max number of generated tokens",
     default=100,
+    required=False,
 )
 parser.add_argument(
     "--sharegpt_path",
     type=str,
     help="path to sharegpt data json",
+    required=False,
 )
 parser.add_argument(
     "--output_dir",
     type=str,
     help="output directory",
+    required=True,
 )
 parser.add_argument(
     "--topk_per_token",
     type=int,
     help="top k values per token to generate loss on",
-    default=20
+    default=20,
+    required=False,
 )
 parser.add_argument(
     "--num_test_tokens_per_sequence",
     type=int,
     help="number of tokens in test. For instance, if max_new_tokens=128 and num_test_tokens_per_sequence=256, this means we will generate data over 2 sample prompts. If not set, will be set to max_new_tokens",
-    default=None
+    default=None,
+    required=False,
 )
 parser.add_argument(
     "--extra_get_model_kwargs",
     nargs='*',
     default={},
-    help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..."
+    help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,...",
+    required=False,
 )
 args = parser.parse_args()
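Aside: argparse treats every --flag as optional unless required=True is passed, and required=False is already the default, so most of the additions in this hunk document existing behavior rather than change it. The behavior changes are --variant and --output_dir becoming mandatory (--tokenizer was already required; its required=True is only moved after the help text). A minimal runnable sketch of the pattern, not the full generate_metrics.py parser, with argument names taken from the diff above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--variant",
    type=str,
    default=None,
    help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.",
    required=True,   # parse_args() now exits with an error if --variant is missing
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=1,
    help="size of input batch",
    required=False,  # explicit, but this is already argparse's default for --flags
)

args = parser.parse_args(["--variant", "7b"])
print(args.variant, args.batch_size)  # -> 7b 1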

@@ -129,6 +140,12 @@

 torch.set_grad_enabled(False)

+# As per FMS check https://github.com/foundation-model-stack/foundation-model-stack/blob/ec55d3f4d2a620346a1eb003699db0b0d47e2598/fms/models/__init__.py#L88
+# we need to remove variant if model_arg or model_path is provided
+if args.model_path and args.variant:
+    print("Both variant and model path provided. Removing variant")
+    args.variant = None
+
 # prepare the cuda model
 cuda_model = get_model(
     architecture=args.architecture,
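The new guard reflects the FMS get_model contract linked in the comment: supplying both a variant and an explicit model_path is rejected upstream, so the script silently drops the variant and lets the path win. A stricter alternative, a sketch only and not what this commit does, would surface the conflict at parse time instead:

# Hypothetical variation: fail fast instead of silently dropping the variant.
# parser.error() prints usage and exits with status 2.
if args.model_path and args.variant:
    parser.error("--variant and --model_path are mutually exclusive; pass only one")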
@@ -211,14 +228,14 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
 failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens)

 print("extracted cuda validation information level 0")
-if len(failed_responses) != 0:
+if len(failed_responses) != 0:
     print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)

 def write_csv(l, path, metric):
     with open(path, 'w') as f:
         f.write(f'{metric}\n')
         for t in l:
-            f.write(f"{t[2].item()}\n")
+            f.write(f"{t[2].item()}\n")
         f.close()

 num_test_tokens_per_sequence = args.num_test_tokens_per_sequence
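The two -/+ pairs in this last hunk are whitespace-only fixes; the indentation difference does not survive this page capture, so the removed and added lines display identically. One further observation on write_csv: the trailing f.close() is redundant, since the with block already closes the file on exit. A tidier sketch of the same helper, an inference about intent and not part of this commit:

def write_csv(l, path, metric):
    # one header row with the metric name, then the third element of each tuple
    with open(path, "w") as f:
        f.write(f"{metric}\n")
        for t in l:
            f.write(f"{t[2].item()}\n")
    # no f.close() needed: the context manager closes the file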
