-
Notifications
You must be signed in to change notification settings - Fork 30
Refactor get valid prompts - for memory optimization #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -523,129 +523,129 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: | |
| for v in program_map.values(): | ||
| random.Random(42).shuffle(v) | ||
|
|
||
|
|
||
| # select prompts that fit the batch size criteria | ||
| valid_prompts = [] | ||
| if custom_shape: | ||
| for program_criteria_seq, valid_prompt_shapes in program_map.items(): | ||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| if valid_prompt_shape == custom_shape: | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| pad_multiple=pad_multiple, | ||
| ) | ||
| valid_prompts = [ | ||
| ( | ||
| def get_program_prompt_list(): | ||
| if custom_shape: | ||
| prompt_found = 0 | ||
| for program_criteria_seq, valid_prompt_shapes in program_map.items(): | ||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| if valid_prompt_shape == custom_shape: | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| pad_multiple=pad_multiple, | ||
| ) | ||
| prompt_found = 1 | ||
| yield ( | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. change 2: yield each result instead of building a list; the prompt_found flag is set before the yield |
||
| program_criteria_seq[0].program_id, | ||
| custom_shape, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) | ||
| ] | ||
| break | ||
| if prompt_found: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. change 3: check the prompt_found flag instead of the length of the list |
||
| break | ||
| if len(valid_prompts) > 0: | ||
| break | ||
| else: | ||
| for program_info in programs: | ||
| program_id = program_info.program_id | ||
| batch_size_limit = program_info.batch_size_limit | ||
| batch_size_limit_type = program_info.batch_size_limit_type | ||
| prompt_length_limit = program_info.prompt_length_limit | ||
| prompt_length_limit_type = program_info.prompt_length_limit_type | ||
|
|
||
| filtered_program_map = program_map | ||
| if program_id.isnumeric(): | ||
| filtered_program_map = { | ||
| k: v | ||
| for k, v in program_map.items() | ||
| if k[0] == program_criteria_list[int(program_id)] | ||
| } | ||
| used_keys = set() | ||
| # for each program, we need to check if we have a shape that satisfies the --programs request | ||
| for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): | ||
| # if ? or numeric => we need to check if we have found at least one valid key to stop | ||
| if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: | ||
| break | ||
| # if * => we need to see if we have found the first key to see if we should skip | ||
| elif program_id == "*" and program_seq_key[0] in used_keys: | ||
| continue | ||
|
|
||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| # make sure the criteria for batch limit and prompt limit is satisfied | ||
| # eval is safe here because we have limited what type and limit can be before | ||
| else: | ||
| for program_info in programs: | ||
| program_id = program_info.program_id | ||
| batch_size_limit = program_info.batch_size_limit | ||
| batch_size_limit_type = program_info.batch_size_limit_type | ||
| prompt_length_limit = program_info.prompt_length_limit | ||
| prompt_length_limit_type = program_info.prompt_length_limit_type | ||
|
|
||
| filtered_program_map = program_map | ||
| if program_id.isnumeric(): | ||
| filtered_program_map = { | ||
| k: v | ||
| for k, v in program_map.items() | ||
| if k[0] == program_criteria_list[int(program_id)] | ||
| } | ||
| used_keys = set() | ||
| # for each program, we need to check if we have a shape that satisfies the --programs request | ||
| for program_seq_key, valid_prompt_shapes in filtered_program_map.items(): | ||
| # if ? or numeric => we need to check if we have found at least one valid key to stop | ||
| if (program_id == "?" or program_id.isnumeric()) and len(used_keys) > 0: | ||
| break | ||
| # if * => we need to see if we have found the first key to see if we should skip | ||
| elif program_id == "*" and program_seq_key[0] in used_keys: | ||
| continue | ||
|
|
||
| for valid_prompt_shape in valid_prompt_shapes: | ||
| # make sure the criteria for batch limit and prompt limit is satisfied | ||
| # eval is safe here because we have limited what type and limit can be before | ||
|
|
||
| batch_check = eval( | ||
| f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" | ||
| ) | ||
| prompt_check = eval( | ||
| f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" | ||
| ) | ||
| if batch_check and prompt_check: | ||
| # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length | ||
| # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning | ||
| # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| if ( | ||
| args.enforce_homogeneous_prompt_programs | ||
| or args.prefill_chunk_size > 0 | ||
| ): | ||
| # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length | ||
| tkv_cutoff = ( | ||
| 1 << (valid_prompt_shape[1].bit_length() - 1) | ||
| if args.enforce_homogeneous_prompt_programs | ||
| else pad_multiple | ||
| ) | ||
|
|
||
| batch_check = eval( | ||
| f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}" | ||
| ) | ||
| prompt_check = eval( | ||
| f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}" | ||
| ) | ||
| if batch_check and prompt_check: | ||
| # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length | ||
| # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning | ||
| # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user | ||
| enforce_sizes = [valid_prompt_shape[1]] | ||
| if ( | ||
| args.enforce_homogeneous_prompt_programs | ||
| or args.prefill_chunk_size > 0 | ||
| ): | ||
| # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length | ||
| tkv_cutoff = ( | ||
| 1 << (valid_prompt_shape[1].bit_length() - 1) | ||
| if args.enforce_homogeneous_prompt_programs | ||
| else pad_multiple | ||
| ) | ||
|
|
||
| possible_seq_lengths = [ | ||
| _ | ||
| for _ in range( | ||
| tkv_cutoff, valid_prompt_shape[1], pad_multiple | ||
| possible_seq_lengths = [ | ||
| _ | ||
| for _ in range( | ||
| tkv_cutoff, valid_prompt_shape[1], pad_multiple | ||
| ) | ||
| ] | ||
| # favor sequences that are close to the valid prompt length | ||
| possible_seq_lengths.reverse() | ||
| # add the valid prompt size to the end since it will already exist in the above enforce_sizes | ||
| possible_seq_lengths = possible_seq_lengths + [ | ||
| valid_prompt_shape[1] | ||
| ] | ||
| enforce_sizes = enforce_sizes + list( | ||
| itertools.islice( | ||
| itertools.cycle(possible_seq_lengths), | ||
| valid_prompt_shape[0] - 1, | ||
| ) | ||
| ) | ||
| ] | ||
| # favor sequences that are close to the valid prompt length | ||
| possible_seq_lengths.reverse() | ||
| # add the valid prompt size to the end since it will already exist in the above enforce_sizes | ||
| possible_seq_lengths = possible_seq_lengths + [ | ||
| valid_prompt_shape[1] | ||
| ] | ||
| enforce_sizes = enforce_sizes + list( | ||
| itertools.islice( | ||
| itertools.cycle(possible_seq_lengths), | ||
| valid_prompt_shape[0] - 1, | ||
| try: | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size) | ||
| ) | ||
| ) | ||
| try: | ||
| input_ids, extra_kwargs, sample_key = __prepare_inputs( | ||
| valid_prompt_shape[0], | ||
| valid_prompt_shape[1], | ||
| tokenizer, | ||
| enforce_sizes=enforce_sizes, | ||
| pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size) | ||
| ) | ||
| valid_prompts.append( | ||
| ( | ||
| used_keys.add(program_seq_key[0]) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. change 4: call used_keys.add(program_seq_key[0]) before the yield, then yield the tuple |
||
| yield ( | ||
| program_seq_key[0], | ||
| valid_prompt_shape, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) | ||
| ) | ||
| used_keys.add(program_seq_key[0]) | ||
| break | ||
| except ValueError: | ||
| dprint( | ||
| f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" | ||
| ) | ||
|
|
||
| if len(used_keys) == 0 and local_rank == 0: | ||
| dprint( | ||
| f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" | ||
| ) | ||
|
|
||
| break | ||
| except ValueError: | ||
| dprint( | ||
| f"No valid sample exists in dataset for this input shape {valid_prompt_shape}" | ||
| ) | ||
|
|
||
| if len(used_keys) == 0 and local_rank == 0: | ||
| dprint( | ||
| f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}" | ||
| ) | ||
|
|
||
|
|
||
| # metric calculator based on the cross-entropy and mean diff for each decode step | ||
|
|
@@ -664,7 +664,13 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): | |
|
|
||
| failed_cases = [] | ||
| # for each program and valid prompt (batch size, sequence length) | ||
| for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: | ||
| for ( | ||
| program_id, | ||
| valid_prompt, | ||
| input_ids, | ||
| extra_kwargs, | ||
| sample_key, | ||
| ) in get_program_prompt_list(): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. change 5: iterate over the generator function instead of a pre-built list |
||
| extra_kwargs["attn_name"] = ATTN_NAME | ||
| if ( | ||
| "granite-3.3-8b-instruct" in model_variant | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's hard to see the git diff changes made in this PR:
change 1 - use a prompt_found flag, as we are yielding results instead of storing them in a list