Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cookbook/transformers/sp_fsdp_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
device_type=Platform.get_platform().device_prefix(),
)]

# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing)
# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2.
# In Transformers route, ulysses_size is the total sequence-parallel degree.
device_mesh = DeviceMesh(
device_type='cuda',
device_type=Platform.get_platform().device_prefix(),
mesh=np.arange(4).reshape(2, 2),
mesh_dim_names=('dp', 'fsdp'),
ulysses_size=2,
Expand Down
3 changes: 2 additions & 1 deletion cookbook/transformers/sp_fsdp_dense.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/bin/bash
# To enabele sequence parallelism, please set ulysses_size > 1
# To enable Transformers sequence parallelism, please set ulysses_size > 1.
# ulysses_size is interpreted as the total sequence-parallel degree.
# device_mesh = DeviceMesh(
# device_type="cuda",
# mesh=np.arange(4).reshape(2, 2),
Expand Down
6 changes: 6 additions & 0 deletions src/twinkle/metric/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
return
loss = outputs['loss']
loss_reduction = kwargs.get('loss_reduction', 'mean')
ulysses_size = getattr(self.device_mesh, 'ulysses_size', None) or 1
if loss_reduction == 'sum':
if not isinstance(inputs, list):
inputs = [inputs]
for input in inputs:
# `Transformers` models may use reduction='sum' to average gradients before the optimizer step
labels = input['labels']
self.num_tokens += (labels >= 0).sum().item()
# Sequence-parallel gathered loss is replicated on each ulysses rank, while
# local labels still count only the shard-local tokens. Normalize the loss
# contribution here so metric-side averaging matches the non-SP path.
if ulysses_size > 1:
loss = loss / float(ulysses_size)
grad_norm = kwargs.get('grad_norm')
if grad_norm is not None:
self.grad_norm = grad_norm
Expand Down
Loading
Loading