-
Notifications
You must be signed in to change notification settings - Fork 1k
feat: Add split truncation_strategy #6541
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,7 +63,7 @@ def __init__( | |
| default_system: Optional[str] = None, | ||
| max_length: Optional[int] = None, | ||
| *, | ||
| truncation_strategy: Literal['raise', 'left', 'right'] = 'raise', | ||
| truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise', | ||
| max_pixels: Optional[int] = None, | ||
| agent_template: Optional[str] = None, | ||
| norm_bbox: Literal['norm1000', 'none', None] = None, | ||
|
|
@@ -524,6 +524,17 @@ def encode(self, | |
| else: | ||
| raise ValueError(f'task_type: {self.task_type} is not supported.') | ||
|
|
||
| if isinstance(encoded, list): | ||
| processed_list = [] | ||
| for sub_encoded in encoded: | ||
| for key in list(sub_encoded.keys()): | ||
| if sub_encoded[key] is None: | ||
| sub_encoded.pop(key) | ||
| if not return_length: | ||
| sub_encoded.pop('length', None) | ||
| processed_list.append(sub_encoded) | ||
| return processed_list | ||
|
|
||
| if chosen.channel is not None: | ||
| encoded['channel'] = chosen.channel | ||
|
|
||
|
|
@@ -1218,6 +1229,27 @@ def _encode_truncated(self, inputs: StdTemplateInputs): | |
| input_ids, labels, loss_scale = self._truncate( | ||
| input_ids, labels, loss_scale, truncation_strategy=self.truncation_strategy) | ||
| length = self._get_length(input_ids, labels) | ||
| elif self.truncation_strategy == 'split': | ||
| encoded_chunks = [] | ||
| block_size = self.max_length | ||
| for i in range(0, length, block_size): | ||
| new_encoded = encoded.copy() | ||
| chunk_input_ids = input_ids[i:i + block_size] | ||
| chunk_labels = labels[i:i + block_size] if labels is not None else None | ||
| chunk_loss_scale = loss_scale[i:i + block_size] if loss_scale is not None else None | ||
|
|
||
| if chunk_labels is not None and i > 0 and len(chunk_labels) > 0: | ||
| chunk_labels = list(chunk_labels) | ||
| chunk_labels[0] = -100 | ||
|
Comment on lines
+1242
to
+1243
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When setting `chunk_labels[0] = -100`, also set the first entry of `chunk_loss_scale` to `0.0`, as in the suggested change below:
```python
chunk_labels = list(chunk_labels)
chunk_labels[0] = -100
if chunk_loss_scale is not None and len(chunk_loss_scale) > 0:
    chunk_loss_scale = list(chunk_loss_scale)
    chunk_loss_scale[0] = 0.0
```
|
||
|
|
||
| new_encoded['input_ids'] = chunk_input_ids | ||
| new_encoded['labels'] = chunk_labels | ||
| new_encoded['loss_scale'] = chunk_loss_scale | ||
| new_encoded['length'] = self._get_length(chunk_input_ids, chunk_labels) | ||
|
|
||
| encoded_chunks.append(new_encoded) | ||
|
|
||
| return encoded_chunks | ||
| elif self.truncation_strategy == 'raise': | ||
| raise MaxLengthError(f'Current length of row({length}) is larger' | ||
| f' than the max_length({self.max_length}).') | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The 'split' truncation strategy appears to be designed for text-only, decoder-only models, which is great for pre-training. However, it doesn't handle multimodal data or encoder-decoder architectures. Using `encoded.copy()` performs a shallow copy, which means that for multimodal inputs, each chunk would incorrectly reference the same full list of images/videos/audios. Similarly, for encoder-decoder models, other fields like `prompt_input_ids` would be copied without being chunked, leading to inconsistencies. To prevent incorrect usage and potential crashes, it's crucial to add checks to ensure this strategy is only used in supported scenarios.