Refactor get valid prompts - for memory optimization #170
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
I tested with the 32*32 RAG factoid dataset on a granite model and got comparable logs. I kept the refactor to the minimum needed for the memory fix, as there is another branch with a whole refactor.
(
def get_program_prompt_list():
    if custom_shape:
        prompt_found = 0
Since it's hard to see the git diff changes made in this PR, here is a summary:
Change 1: use a prompt_found flag, as we are now yielding instead of storing in a list.
    pad_multiple=pad_multiple,
)
prompt_found = 1
yield (
Change 2: yield instead of building a list; the flag is set before the yield.
)
]
break
if prompt_found:
Change 3: check the flag instead of the length of the list.
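A minimal sketch of the flag pattern from changes 1 to 3, not the code in this PR: once items are yielded rather than collected, there is no list whose length can be checked, so a boolean records whether anything was produced. The names `first_fitting_prompt`, `candidates`, `max_len` and `log` are hypothetical and only serve the illustration.

```python
from typing import Iterator, List


def first_fitting_prompt(
    candidates: List[str], max_len: int, log: List[str]
) -> Iterator[str]:
    """Yield the first candidate that fits, tracking success with a flag."""
    prompt_found = False  # replaces the old "is the list non-empty?" check
    for text in candidates:
        if len(text) <= max_len:
            prompt_found = True  # set before the yield
            yield text
            break
    # This runs once the consumer asks for the next item after the yield.
    if prompt_found:
        log.append("found")
    else:
        log.append("skipped")


log: List[str] = []
print(list(first_fitting_prompt(["aaaa", "bb", "c"], 2, log)), log)
```

Note that the code after the loop only runs when the consumer keeps iterating past the yielded item, which a plain `for` loop or `list()` does.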
)
valid_prompts.append(
(
used_keys.add(program_seq_key[0])
Change 4: call used_keys.add(program_seq_key[0]) before the yield, and then yield.
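To illustrate why the ordering in change 4 matters, here is a small sketch under assumed names (`emit` and `record_before_yield` are hypothetical, not from the PR). A generator pauses at `yield`, so bookkeeping placed after the `yield` is not visible to the consumer until it requests the next item.

```python
from typing import Iterator, List, Set


def emit(
    keys: List[str], used_keys: Set[str], record_before_yield: bool
) -> Iterator[str]:
    """Yield each key, recording it either before or after the yield."""
    for key in keys:
        if record_before_yield:
            used_keys.add(key)  # visible to the consumer right away
            yield key
        else:
            yield key
            used_keys.add(key)  # deferred until the generator is resumed


seen_early: Set[str] = set()
next(emit(["a", "b"], seen_early, record_before_yield=True))

seen_late: Set[str] = set()
next(emit(["a", "b"], seen_late, record_before_yield=False))

print(seen_early, seen_late)
```

After pulling a single item, `seen_early` already contains `"a"`, while `seen_late` is still empty.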
input_ids,
extra_kwargs,
sample_key,
) in get_program_prompt_list():
Change 5: iterate over the generator function instead of over the list.
Rebased on latest main and tested.
bot:test
Currently on main, we generate all valid prompts for all programs up front in a large list called valid_prompts, and only then start validating programs using that list. This can consume a lot of memory when the prompts are large; we were running out of memory with 128k-size prompts. This PR makes prompt generation an iterator instead: prompts are generated for one program and that program is validated before moving on to the next program. If prompts cannot be generated for a program, its validation is skipped, just as it is today.
This saves a lot of memory. The user-visible change is that prompt extraction no longer happens up front for all programs; it happens as needed, one program at a time.
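For readers who want to see the interleaving, here is a rough sketch of the idea rather than the actual implementation. `build_prompt`, `iter_program_prompts`, `validate_all` and the `events` log are hypothetical stand-ins, and a program with size 0 is assumed to mean "no prompt could be generated".

```python
from typing import Dict, Iterator, List, Tuple


def build_prompt(program_id: str, size: int, events: List[str]) -> List[int]:
    """Hypothetical stand-in for prompt extraction; records when it runs."""
    events.append(f"build {program_id}")
    return [0] * size


def iter_program_prompts(
    programs: Dict[str, int], events: List[str]
) -> Iterator[Tuple[str, List[int]]]:
    """Lazily yield one (program, prompt) pair at a time."""
    for program_id, size in programs.items():
        if size <= 0:
            # Assumption: no prompt could be generated, so skip validation.
            events.append(f"skip {program_id}")
            continue
        yield program_id, build_prompt(program_id, size, events)


def validate_all(programs: Dict[str, int], events: List[str]) -> List[str]:
    """Validate each program right after its prompt is produced."""
    validated = []
    for program_id, prompt in iter_program_prompts(programs, events):
        events.append(f"validate {program_id}")
        validated.append(program_id)
        # The prompt can be released here, before the next one is built.
    return validated


events: List[str] = []
validated = validate_all({"p0": 4, "p1": 0, "p2": 2}, events)
print(validated)
print(events)
```

The event log shows each build immediately followed by its validation, so at most one program's prompt needs to be held at a time, instead of all of them as with the list.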