Skip to content

Conversation

@Ssukriti
Copy link

Currently on main, we generate all valid prompts for all programs in a large list called valid_prompts. After which, we start validating programs using the list of prompts. This can cause large memory consumption if size of prompts is large. We were running into out of memory with 128k size prompts.

This PR makes the generation of prompts a iterator instead. Prompts will be generated for a program and the program will be validated, before going to the next program. If prompts could not be generated for a program, the program validation is skipped just like today.

This helps save a lot of memory. Users will see change that prompt extraction for all programs will not happen upfront, and will happen as needed for a program

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@Ssukriti Ssukriti marked this pull request as ready for review November 17, 2025 20:52
@Ssukriti
Copy link
Author

I tested with 32*32 RAG factoid dataset on granite model and got comparable logs as the main branch

Kept refactor to the minimal for the memory fix, as there is another branch with a whole refactor

@Ssukriti Ssukriti requested a review from JRosenkranz November 17, 2025 20:54
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
(
def get_program_prompt_list():
if custom_shape:
prompt_found = 0
Copy link
Author

@Ssukriti Ssukriti Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since its hard to see git diff changes made in this PR:,
change 1 - use prompt_found flag as we are yielding instead of storing in list

pad_multiple=pad_multiple,
)
prompt_found = 1
yield (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change2: yield instead of list, flag set before yield

)
]
break
if prompt_found:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change3: see flag instead of length of list

)
valid_prompts.append(
(
used_keys.add(program_seq_key[0])
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change 4: used_keys.add(program_seq_key[0]) before yield and then yield

input_ids,
extra_kwargs,
sample_key,
) in get_program_prompt_list():
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change 5: call function to yield instead of list

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@Ssukriti
Copy link
Author

rebased on latest main and tested

@JRosenkranz
Copy link
Contributor

bot:test
TEST_FILE=test_scripts.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants