diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py index 5e1aeff330..2ff2e2ca6b 100644 --- a/swift/llm/argument/base_args/template_args.py +++ b/swift/llm/argument/base_args/template_args.py @@ -32,7 +32,7 @@ class TemplateArguments: system: Optional[str] = None # Override the default_system in the template. max_length: Optional[int] = None - truncation_strategy: Literal['delete', 'left', 'right', None] = None + truncation_strategy: Literal['delete', 'left', 'right', 'split', None] = None max_pixels: Optional[int] = None agent_template: Optional[str] = None norm_bbox: Literal['norm1000', 'none', None] = None diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index fce20eb7d2..d51e352df0 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -63,7 +63,7 @@ def __init__( default_system: Optional[str] = None, max_length: Optional[int] = None, *, - truncation_strategy: Literal['raise', 'left', 'right'] = 'raise', + truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise', max_pixels: Optional[int] = None, agent_template: Optional[str] = None, norm_bbox: Literal['norm1000', 'none', None] = None, @@ -524,6 +524,17 @@ def encode(self, else: raise ValueError(f'task_type: {self.task_type} is not supported.') + if isinstance(encoded, list): + processed_list = [] + for sub_encoded in encoded: + for key in list(sub_encoded.keys()): + if sub_encoded[key] is None: + sub_encoded.pop(key) + if not return_length: + sub_encoded.pop('length', None) + processed_list.append(sub_encoded) + return processed_list + if chosen.channel is not None: encoded['channel'] = chosen.channel @@ -1218,6 +1229,27 @@ def _encode_truncated(self, inputs: StdTemplateInputs): input_ids, labels, loss_scale = self._truncate( input_ids, labels, loss_scale, truncation_strategy=self.truncation_strategy) length = self._get_length(input_ids, labels) + elif self.truncation_strategy == 'split': + encoded_chunks = [] + block_size = self.max_length + for i in range(0, length, block_size): + new_encoded = encoded.copy() + chunk_input_ids = input_ids[i:i + block_size] + chunk_labels = labels[i:i + block_size] if labels is not None else None + chunk_loss_scale = loss_scale[i:i + block_size] if loss_scale is not None else None + + if chunk_labels is not None and i > 0 and len(chunk_labels) > 0: + chunk_labels = list(chunk_labels) + chunk_labels[0] = -100 + + new_encoded['input_ids'] = chunk_input_ids + new_encoded['labels'] = chunk_labels + new_encoded['loss_scale'] = chunk_loss_scale + new_encoded['length'] = self._get_length(chunk_input_ids, chunk_labels) + + encoded_chunks.append(new_encoded) + + return encoded_chunks elif self.truncation_strategy == 'raise': raise MaxLengthError(f'Current length of row({length}) is larger' f' than the max_length({self.max_length}).')