Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/template_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TemplateArguments:
system: Optional[str] = None # Override the default_system in the template.
max_length: Optional[int] = None

truncation_strategy: Literal['delete', 'left', 'right', None] = None
truncation_strategy: Literal['delete', 'left', 'right', 'split', None] = None
max_pixels: Optional[int] = None
agent_template: Optional[str] = None
norm_bbox: Literal['norm1000', 'none', None] = None
Expand Down
34 changes: 33 additions & 1 deletion swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
default_system: Optional[str] = None,
max_length: Optional[int] = None,
*,
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
max_pixels: Optional[int] = None,
agent_template: Optional[str] = None,
norm_bbox: Literal['norm1000', 'none', None] = None,
Expand Down Expand Up @@ -524,6 +524,17 @@ def encode(self,
else:
raise ValueError(f'task_type: {self.task_type} is not supported.')

if isinstance(encoded, list):
processed_list = []
for sub_encoded in encoded:
for key in list(sub_encoded.keys()):
if sub_encoded[key] is None:
sub_encoded.pop(key)
if not return_length:
sub_encoded.pop('length', None)
processed_list.append(sub_encoded)
return processed_list

if chosen.channel is not None:
encoded['channel'] = chosen.channel

Expand Down Expand Up @@ -1218,6 +1229,27 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
input_ids, labels, loss_scale = self._truncate(
input_ids, labels, loss_scale, truncation_strategy=self.truncation_strategy)
length = self._get_length(input_ids, labels)
elif self.truncation_strategy == 'split':
encoded_chunks = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The 'split' truncation strategy appears to be designed for text-only, decoder-only models, which is great for pre-training. However, it doesn't handle multimodal data or encoder-decoder architectures. Using encoded.copy() performs a shallow copy, which means that for multimodal inputs, each chunk would incorrectly reference the same full list of images/videos/audios. Similarly, for encoder-decoder models, other fields like prompt_input_ids would be copied without being chunked, leading to inconsistencies. To prevent incorrect usage and potential crashes, it's crucial to add checks to ensure this strategy is only used in supported scenarios.

                if self.is_encoder_decoder:
                    raise ValueError("The 'split' truncation strategy is not supported for encoder-decoder models.")
                if inputs.is_multimodal:
                    raise ValueError("The 'split' truncation strategy is not supported for multimodal inputs.")
                encoded_chunks = []

block_size = self.max_length
for i in range(0, length, block_size):
new_encoded = encoded.copy()
chunk_input_ids = input_ids[i:i + block_size]
chunk_labels = labels[i:i + block_size] if labels is not None else None
chunk_loss_scale = loss_scale[i:i + block_size] if loss_scale is not None else None

if chunk_labels is not None and i > 0 and len(chunk_labels) > 0:
chunk_labels = list(chunk_labels)
chunk_labels[0] = -100
Comment on lines +1242 to +1243
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When setting chunk_labels[0] to -100 to prevent loss calculation on the first token of a new chunk, it's good practice to also update the corresponding loss_scale for consistency. If loss_scale is being used, its first element should be set to 0.0 to align with the masked label. This ensures that the loss scaling is correctly handled, especially in scenarios where is_loss_scale_binary is false.

                        chunk_labels = list(chunk_labels)
                        chunk_labels[0] = -100
                        if chunk_loss_scale is not None and len(chunk_loss_scale) > 0:
                            chunk_loss_scale = list(chunk_loss_scale)
                            chunk_loss_scale[0] = 0.0


new_encoded['input_ids'] = chunk_input_ids
new_encoded['labels'] = chunk_labels
new_encoded['loss_scale'] = chunk_loss_scale
new_encoded['length'] = self._get_length(chunk_input_ids, chunk_labels)

encoded_chunks.append(new_encoded)

return encoded_chunks
elif self.truncation_strategy == 'raise':
raise MaxLengthError(f'Current length of row({length}) is larger'
f' than the max_length({self.max_length}).')
Expand Down
Loading