diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index ffffbed50..de64d905a 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -179,6 +179,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + extra_kwargs: dict[str, typing.Any] | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 4d31324fe..4a6f3b3cb 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -339,15 +339,15 @@ def _preprocess_data( phase=context.phase, iteration=context.iteration, metrics=context.metrics, + extra_kwargs={ + "grad_output": grad_output, + "micro_batch": micro_batch, + "num_micro_batches": batch_config.sequential_micro_batches, + "micro_batch_splits": batch_config.micro_batch_splits, + }, ) for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): - kwargs.update( - grad_output=grad_output, - micro_batch=micro_batch, - micro_batch_split=micro_batch_split, - num_micro_batches=batch_config.sequential_micro_batches, - micro_batch_splits=batch_config.micro_batch_splits, - ) + kwargs.update(micro_batch_split=micro_batch_split) data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py index 3e8e31cea..586f833b3 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/autograd.py @@ -62,8 +62,8 @@ def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.T class AuxiliaryLoss(torch.autograd.Function): @staticmethod - def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa - ctx.grad = torch.full_like(aux_loss, grad) + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | None) -> torch.Tensor: # noqa + ctx.grad = None if grad is None else torch.full_like(aux_loss, grad) return input_ @staticmethod diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index d6eab0eb2..859bafea2 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim +from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ @@ -12,7 +12,6 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta @@ -113,7 +112,7 @@ def __init__( CompositeTensorDim("value", (head_group_dim, head_size_dim)), ), ) - dense_dim = CompositeTensorDim("dense", (head_group_dim, 
group_heads_dim, head_size_dim)) + self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power) @@ -152,7 +151,7 @@ def __init__( # Output. self.dense = self._config.dense_layer.get_layer( - dense_dim, + self._dense_dim, hidden_dim, default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), default_add_bias=self._config.add_linear_biases, @@ -163,22 +162,13 @@ def __init__( # Debug dims self._query_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), head_size_dim, ) self._kv_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, head_group_dim, head_size_dim, ) - self._context_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, - dense_dim, - ) def _attn_backup( self, @@ -269,7 +259,7 @@ def _attn_flash( ) def _query_key_value_forward( - self, input_: torch.Tensor, sequence_first: bool + self, input_: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: key_value, key_value_context = self.key_value.forward_only(input_) @@ -292,10 +282,7 @@ def _query_key_value_forward( if handle: handle.wait() - if self._sequence_data_parallel_dim.group and not sequence_first: - key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) - - context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} + context = {"query": query_context, "key_value": key_value_context} return query, key_value, context def _query_key_value_backward( @@ -305,7 +292,7 @@ def _query_key_value_backward( key_value_grad, handle = reduce_scatter_op( key_value_grad, group=self._sequence_data_parallel_dim.group, - dim=1 - context["sequence_first"], + dim=0, async_op=True, ) @@ -331,15 +318,19 @@ def _forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[AttentionKwargs.sequence_first] - query, key_value = self._query_key_value(input_, sequence_first) + query, key_value = self._query_key_value(input_) + + # Separate the batch and sequence dimensions + token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim]) + token_shape = tuple(dim.size for dim in token_dims) + query = query.unflatten(0, token_shape) + key_value = key_value.unflatten(0, token_shape) # TODO: Move the rest to function. if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: - assert sequence_first # Clear the lists so tensors can be de-allocated - key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) + key_value = torch.cat((past_key_values.pop(0), key_value), dim=1) if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences @@ -348,26 +339,15 @@ def _forward( # Manually add the gradients from later micro-sequences. key_value = AttachGrad.apply(key_value, present) - if self._sequence_data_parallel_dim.group: - key_value = ( - key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] - if sequence_first - else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] - ) - - if sequence_first: - # TODO: Optimize (is contiguous avoidable?) 
- query = query.transpose(0, 1).contiguous() - key_value = key_value.transpose(0, 1).contiguous() - + key_value = key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) - self._debug(query, "query_rotary_input", self._query_dims, kwargs) - self._debug(key, "key_rotary_input", self._kv_dims, kwargs) + self._debug(query, "query_rotary_input", token_dims + self._query_dims, kwargs) + self._debug(key, "key_rotary_input", token_dims + self._kv_dims, kwargs) query, key = self._rotary(query, key, kwargs) with set_generator(self._distributed.tp_generator): @@ -379,22 +359,17 @@ def _forward( else: raise NotImplementedError(self._implementation) - self._debug(query, "query", self._query_dims, kwargs) - self._debug(key, "key", self._kv_dims, kwargs) - self._debug(value, "value", self._kv_dims, kwargs) - self._debug(input_, "context", self._context_dims, kwargs) + self._debug(query, "query", token_dims + self._query_dims, kwargs) + self._debug(key, "key", token_dims + self._kv_dims, kwargs) + self._debug(value, "value", token_dims + self._kv_dims, kwargs) + self._debug(input_, "context", token_dims + (self._dense_dim,), kwargs) - if sequence_first: - # TODO: Optimize (is contiguous avoidable? Transpose dense output?) - input_ = input_.transpose(0, 1).contiguous() - out, bias = self.dense(input_) - self._debug(out, None, kwargs.get(AttentionKwargs.hidden_dims), kwargs) + out, bias = self.dense(input_.flatten(0, 1)) + self._debug(out, None, token_dims + (self._hidden_dim,), kwargs) return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - batch_dim: TensorDim = kwargs[AttentionKwargs.hidden_dims][1 if kwargs[AttentionKwargs.sequence_first] else 0] - - # Using this one since `hidden_dims` may be sequence-tensor-parallel, and attention is not. + batch_dim: TensorDim = kwargs[AttentionKwargs.batch_dim] sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim] sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim] @@ -435,7 +410,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c partly_out_of_window = max(sequence_k - fully_out_of_window - self._config.window_size, 0) attention_compute -= (partly_out_of_window * (partly_out_of_window + 1) * attn_compute_base) // 2 - dense_input = TensorMeta.from_dims((batch_dim, sequence_q_dim, self._context_dims[-1])) + dense_input = TensorMeta.from_dims((*input_.dims[:-1], self._dense_dim)) # TODO: Add marginal compute? (ex. 
softmax) return sum( diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index a1942cab1..dc7334b45 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -47,7 +47,7 @@ def __call__( if bias is not None: assert tensor is not None tensor = tensor + bias - meta = self._get_meta(tensor, name, dims, kwargs) + meta = self._get_meta(tensor, name, dims) if output_hidden_state: kwargs[BlockKwargs.hidden_states][name] = (meta, tensor) @@ -68,7 +68,7 @@ def __call__( "", tensor, level=level, - meta=self._get_meta(tensor, name + f"{name}.grad", dims, kwargs), + meta=self._get_meta(tensor, name + f"{name}.grad", dims), **logging_kwargs, ) @@ -76,26 +76,13 @@ def _get_meta( self, tensor: torch.Tensor | None, name: str, - dims: tuple[TensorDim | str, ...] | None, - kwargs: dict[str, typing.Any], - ) -> TensorMeta | None: - if tensor is None: - return None + dims: tuple[TensorDim | str | None, ...] | None, + ) -> TensorMeta: if dims is None: dims = tuple(f"dim_{i}" for i in range(tensor.ndim)) - hidden_dims = {} - if BlockKwargs.hidden_dims in kwargs: - for dim in kwargs[BlockKwargs.hidden_dims]: - hidden_dims[dim.name] = dim - if BlockKwargs.sequence_q_dim in kwargs: - hidden_dims[kwargs[BlockKwargs.sequence_q_dim].name] = kwargs[BlockKwargs.sequence_q_dim] return TensorMeta.from_dims( tuple( - ( - dim - if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) - ) + (dim if isinstance(dim, TensorDim) else TensorDim(f"dim_{i}" if dim is None else dim, tensor.size(i))) for i, dim in enumerate(dims) ), tensor_name=name, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index fd76d36cb..a1b600445 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -31,10 +31,11 @@ class BlockDimNames: class BlockKwargs: - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" + batch_dim = "batch_dim" sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" + token_dim = "token_dim" + hidden_token_dim = "hidden_token_dim" # TODO: These are confusing sequence_length = "sequence_length" sequence_lengths = "sequence_lengths" diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 8f6e360fd..dd19c1086 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -126,143 +126,76 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> torch.Tensor: if isinstance(input_, TensorMeta): - dims = kwargs[BlockKwargs.hidden_dims] + dims = input_.dims if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator - self._debug(None, "begin", kwargs.get(BlockKwargs.hidden_dims), kwargs) + hidden_dims = (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim) + self._debug(None, "begin", hidden_dims, kwargs) fw_input = input_ hidden_states = self.norm_1(input_) - self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, "norm_1", hidden_dims, kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs, metrics=metrics) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses, metrics) + self._debug(hidden_states.detach(), "mixer_output", hidden_dims, kwargs, bias=bias) + if 
self._config.distillation_model is not None and self.training: + if bias is not None: + hidden_states = hidden_states + bias + bias = None + hidden_states = self._activation_distillation_loss(hidden_states, kwargs, losses, metrics) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) - self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(input_, "mixer_residual", hidden_dims, kwargs) hidden_states = self.norm_2(input_) - self._debug(hidden_states, "norm_2", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, "norm_2", hidden_dims, kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - self._debug(hidden_states, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, None, hidden_dims, kwargs) if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metrics): - """ - Maybe apply activation distillation loss and setup backward hooks. - """ - mixer_output = hidden_states if bias is None else hidden_states + bias - - # Teacher: output mixer activations via _debug interface - self._debug(mixer_output.detach(), "mixer_output", kwargs.get(BlockKwargs.hidden_dims), kwargs) - - # Student gets teacher activations and computes the activation-level loss. - activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets) - key = f"{self.module_name}.mixer_output" - if ( - activation_targets is not None - and self.training - and (teacher_output := activation_targets.pop(key, None)) is not None - ): - # Compare student mixer output with the teacher's stored activation and accumulate the loss. - teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) - Assert.eq(teacher_tensor.shape, mixer_output.shape) - # TODO: un-scaled loss for reporting? Average loss over layers? - # L2 loss - activation_loss_factor = self._config.distillation_loss_weight - # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. 
- - # Handle possible padding by using pre-computed activation mask - sequence_first = kwargs.get(BlockKwargs.sequence_first, False) - activation_mask = kwargs.get(BlockKwargs.activation_mask) - - if activation_mask is not None: - # Use pre-computed activation mask (bool tensor where True = valid token) - mask = activation_mask.to(dtype=mixer_output.dtype) - if sequence_first: - # (batch, sequence) -> (sequence, batch) - mask = mask.T - - # Compute masked L2 loss: norm over hidden dim, then apply mask - per_token_loss = torch.norm( - mixer_output - teacher_tensor, p=2, dim=-1 - ) # (batch, sequence) or (sequence, batch) - - # Slice mask to match per_token_loss shape (for sequence parallelism) - # When sequence_tensor_parallel is enabled, per_token_loss only has local sequence length - if mask.shape != per_token_loss.shape: - # Calculate the sequence offset for this rank using the hidden_dims parallel rank - hidden_dims = kwargs.get(BlockKwargs.hidden_dims) - seq_dim_idx = 0 if sequence_first else 1 - hidden_seq_dim = hidden_dims[seq_dim_idx] if hidden_dims else None - - if hidden_seq_dim and hidden_seq_dim.parallel_dim: - # Use the rank from the actual parallel dimension used by hidden states - local_seq_length = per_token_loss.shape[0] if sequence_first else per_token_loss.shape[1] - seq_offset = hidden_seq_dim.parallel_dim.rank * local_seq_length - else: - seq_offset = 0 - - if sequence_first: - # mask: (sequence, batch), per_token_loss: (local_sequence, batch) - mask = mask[seq_offset : seq_offset + per_token_loss.shape[0], :] - else: - # mask: (batch, sequence), per_token_loss: (batch, local_sequence) - mask = mask[:, seq_offset : seq_offset + per_token_loss.shape[1]] - - masked_loss = per_token_loss * mask - local_loss_sum = torch.sum(masked_loss) - total_count = int(mask.sum().item()) - else: - # No activation_mask available, compute loss on all tokens - per_token_loss = torch.norm( - mixer_output - teacher_tensor, p=2, dim=-1 - ) # (batch, sequence) or (sequence, batch) - local_loss_sum = torch.sum(per_token_loss) - # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden) - # In either case, dims 0 and 1 are batch and sequence - total_count = mixer_output.shape[0] * mixer_output.shape[1] - - # All-reduce across tensor-parallel group if sequence-parallel is enabled - if self._sequence_parallel and self._distributed.tensor_group is not None: - all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM) - if activation_mask is not None: - # Different ranks may have different amounts of padding - total_count_tensor = torch.tensor(total_count, device=mixer_output.device, dtype=torch.int64) - all_reduce(total_count_tensor, group=self._distributed.tensor_group, op=ReduceOp.SUM) - total_count = int(total_count_tensor.item()) - else: - # All ranks contribute the same count - total_count *= self._distributed.tensor_group.size() - - activation_loss = local_loss_sum / total_count - scaled_activation_loss = activation_loss_factor * activation_loss - - # Backward hooks - hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0) - bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None - # Logging - if losses is not None and self._distillation_loss_name in losses: - losses[self._distillation_loss_name].append(activation_loss.detach()) - # Per-layer metrics - if metrics is not None: - metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach() - - # If using stochastic 
mixer, also log per-mixer-type activation distillation loss - from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer - - if isinstance(self.mixer, StochasticMixer): - selected_mixer = self.mixer._last_selected_mixer - metrics[f"{self.module_name}/activation_distillation_loss/{selected_mixer}"] = ( - activation_loss.detach() - ) - return hidden_states, bias + def _activation_distillation_loss(self, hidden_states, kwargs, losses, metrics): + Assert.incl( + mixer_output_name := f"{self.module_name}.mixer_output", + reference_hidden_states := kwargs[f"reference_{self._config.distillation_model}_hidden_states"], + ) + teacher_hidden_states = reference_hidden_states.pop(mixer_output_name) + + # L2 loss + per_token_loss = torch.norm(hidden_states - teacher_hidden_states, dim=-1, dtype=torch.float32) + if (activation_mask := kwargs.get(BlockKwargs.activation_mask)) is not None: + per_token_loss = per_token_loss * activation_mask + loss = torch.mean(per_token_loss) + + # All-reduce across tensor-parallel group if sequence-parallel is enabled + if self._sequence_parallel and self._distributed.tensor_group is not None: + all_reduce(loss, group=self._distributed.tensor_group, op=ReduceOp.AVG) + + scaled_activation_loss = self._config.distillation_loss_weight * loss + + # Backward hook + hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, kwargs.get(BlockKwargs.grad_output)) + + # Logging + if losses is not None and self._distillation_loss_name in losses: + losses[self._distillation_loss_name].append(loss.detach()) + + if metrics is not None: + metrics[f"{self.module_name}/activation_distillation_loss"] = loss.detach() + + # If using stochastic mixer, also log per-mixer-type activation distillation loss + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + if isinstance(self.mixer, StochasticMixer): + metrics[f"{self.module_name}/activation_distillation_loss/{self.mixer._last_selected_mixer}"] = ( + loss.detach() + ) + return hidden_states def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(normalization, bias_dropout_add) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 413a88ed6..13ba79a7a 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -12,7 +12,6 @@ from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType @@ -94,11 +93,8 @@ def _forward( return TensorMeta.from_dims(input_.dims[:-1] + (self._output_dim,), "MLP output"), None hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - logit_dims = ( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,) - if BlockKwargs.hidden_dims in kwargs - else None - ) + hidden_token_dim = kwargs[BlockKwargs.hidden_token_dim] + logit_dims = (hidden_token_dim, self._top_expert_dim) self._debug(logits, "Router logits", logit_dims, kwargs) # Apply z_loss if applicable @@ -130,7 +126,7 @@ def _forward( self._debug(top_experts, "router_top_experts", logit_dims, kwargs) out = self._mlp_forward(hidden_states, scores, top_experts).view_as(input_) # noqa - self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(out, None, (hidden_token_dim, self._hidden_dim), kwargs) return out, None def _forward_dropless( @@ -241,24 +237,14 @@ def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - if kwargs[AttentionKwargs.sequence_first]: - sequence_dim, batch_dim, hidden_dim = input_.dims - else: - batch_dim, sequence_dim, hidden_dim = input_.dims - - # Applying the tokens per expert on the batch dim so the super() call works as intended. - moe_batch_dim = TensorDim( - f"moe_{batch_dim.name}", batch_dim.global_size * self._config.experts_per_token, batch_dim.parallel_dim + token_dim, hidden_dim = input_.dims + # Applying the tokens per expert on the token dim so the super() call works as intended. 
+ moe_token_dim = TensorDim( + f"moe_{token_dim.name}", token_dim.global_size * self._config.experts_per_token, token_dim.parallel_dim + ) + moe_input = TensorMeta.from_dims( + (moe_token_dim, hidden_dim), tensor_name=f"moe_{input_.tensor_name}", dtype=input_.dtype ) - - if kwargs[AttentionKwargs.sequence_first]: - dims = sequence_dim, moe_batch_dim, hidden_dim - else: - dims = moe_batch_dim, sequence_dim, hidden_dim - - # Also adjust the dtype in case of full-precision residual - moe_input = TensorMeta.from_dims(dims, tensor_name=f"moe_{input_.tensor_name}", dtype=input_.dtype) - return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 88c86c8aa..1048f7c2a 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -9,7 +9,6 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -85,16 +84,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c if config.hardware and self._config.recompute_level.recompute_layer_1 else config ) - - # Get the layer 2 input dims, accounting for ordering and possible sequence-parallelism. - # TODO: Don't rely on kwargs dimensions. - if kwargs[AttentionKwargs.sequence_first]: - dims = (kwargs[AttentionKwargs.sequence_q_dim], input_.dims[1], self._intermediate_2_dim) - else: - dims = (input_.dims[0], kwargs[AttentionKwargs.sequence_q_dim], self._intermediate_2_dim) # Also adjust the dtype in case of full-precision residual layer_2_input = TensorMeta.from_dims( - dims, tensor_name="intermediate_1", dtype=self._distributed_config.compute_dtype.torch + (input_.dims[0], self._intermediate_2_dim), + tensor_name="intermediate_1", + dtype=self._distributed_config.compute_dtype.torch, ) # TODO: Add marginal compute? (ex. activation, gate + up) @@ -141,6 +135,5 @@ def _forward( bias = self.layer_2.bias if self._parallel_dim.group else None # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) # to let _debug infer dims from actual tensor shape - dims = None if self._output_dim != self._hidden_dim else kwargs.get(BlockKwargs.hidden_dims) - self._debug(out, None, dims, kwargs, bias=bias) + self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs, bias=bias) return out, bias diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9ba1f3433..e3446bba6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -146,8 +146,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - # TODO: Option to chose whether to split in batch or sequence dimension? 
- # (Currently split merged batch and sequence, depends on `sequence_first`) cross_entropy_splits: int = Field( default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", @@ -274,14 +272,6 @@ class LanguageModelConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - sequence_first: bool | None = Field( - default=None, - desc="Override the default dimension ordering", - doc="By default, the hidden states are stored with dimensions (batch, sequence, ...), as it makes attention more efficient." - " However, some settings such as sequence-tensor/data/pipelineo-parallel instead require the ordering (sequence, batch, ...)." - " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", - hint=FieldHint.testing, - ) @property def layer_class(self) -> "type[LanguageModel]": diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24c..c6df8f62b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -8,6 +8,7 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs @@ -28,11 +29,6 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType - # Preprocessing - _rotary_embedding_frequencies: torch.Tensor - _position_ids: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - def __init__( self, config: ConfigType, @@ -84,7 +80,7 @@ def _forward( token_ids: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool, - embedding_map: tuple[torch.Tensor, torch.Tensor] | None, + embedding_map: torch.Tensor, ) -> torch.Tensor: Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group @@ -102,7 +98,7 @@ def _forward( if self._sequence_parallel: input_ = gather(input_, group=group, dim=0) # Out-of-place equivalent of `embeddings[embedding_map] += input_` - embeddings = embeddings.index_put(embedding_map, input_[: embedding_map[0].size(0)], accumulate=True) + embeddings = embeddings.index_put((embedding_map,), input_[: embedding_map.size(0)], accumulate=True) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) @@ -122,7 +118,7 @@ def _forward( if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: - embeddings = embeddings * token_mask.unsqueeze(2) + embeddings = embeddings * token_mask.unsqueeze(-1) if input_ is not None: # TODO: Accumulate redundant with masking? 
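
With the batch and sequence axes merged into a single token dimension, `embedding_map` is now a single index tensor rather than a pair of (batch, sequence) index tensors, so the encoder outputs are scattered into the flattened embeddings with a one-dimensional `index_put`. A minimal sketch of that accumulation step, using illustrative sizes and plain tensors rather than the model's real buffers:

```python
import torch

hidden_size, num_tokens, num_patches = 8, 16, 5
embeddings = torch.zeros(num_tokens, hidden_size)        # flattened (tokens, hidden) embeddings
patch_outputs = torch.randn(num_patches, hidden_size)    # encoder outputs to scatter in
embedding_map = torch.tensor([2, 2, 5, 9, 9])            # target token index for each patch

# Out-of-place equivalent of `embeddings[embedding_map] += patch_outputs`:
# with accumulate=True, duplicated indices are summed instead of overwriting each other.
embeddings = embeddings.index_put((embedding_map,), patch_outputs, accumulate=True)
```
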
@@ -131,12 +127,12 @@ def _forward( input_ = gather(input_, group=group, dim=0) embeddings_ = embeddings.new_zeros(embeddings.shape[0] * group.size(), *embeddings.shape[1:]) embeddings_ = embeddings_.index_put( - embedding_map, input_[: embedding_map[0].size(0)], accumulate=True + (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) embeddings = embeddings + split(embeddings_, group=group, dim=0) else: embeddings = embeddings.index_put( - embedding_map, input_[: embedding_map[0].size(0)], accumulate=True + (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) with set_generator( @@ -154,7 +150,7 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[LanguageModelKwargs.hidden_dims], + (kwargs[LanguageModelKwargs.hidden_token_dim], self._hidden_dim), tensor_name=f"{self.module_name} output", dtype=self._residual_dtype, ) @@ -167,8 +163,6 @@ def forward( # TODO: Support multiple encoders. # TODO: Support pipeline-parallel. token_ids = kwargs.get(LanguageModelKwargs.token_ids) - # Drop the placeholder batch dimension, remove patch padding. - input_ = input_.squeeze(int(kwargs[LanguageModelKwargs.sequence_first])) out = self._forward( input_, @@ -178,7 +172,7 @@ def forward( kwargs.get(LanguageModelKwargs.mask_inputs), embedding_map, ) - self._debug(out, None, kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) + self._debug(out, None, (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: @@ -188,29 +182,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, kwargs: dict[str, typing.Any]) -> None: if not self._config.position_embeddings.enabled: return - self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], self._distributed.device) - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if not self._config.cross_document_position_embeddings: - position_ids = torch.stack( - [ - torch.cat([torch.arange(x) for x in sample_lens]) - for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] - ] - ).to(self._distributed.device, dtype=torch.int64) - position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[LanguageModelKwargs.sequence_first]: - position_ids = position_ids.transpose(0, 1) - kwargs[LanguageModelKwargs.position_ids] = position_ids + # TODO: Move to data preprocessing. 
+ if self._config.cross_document_position_embeddings: + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + kwargs[LanguageModelKwargs.position_ids] = torch.arange( + sequence_k - sequence_q, sequence_k, device=self._distributed.device, dtype=torch.int64 + ).repeat(kwargs[LanguageModelKwargs.batch_dim].size) else: - kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ - sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) - - def _create_position_embeddings(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - Assert.leq(sequence_length, self._config.num_position_embeddings) - self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) + preprocess_for_varlen(kwargs, self._distributed.device, return_position_ids=True) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 144074ca5..c5bf9ff9b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -7,7 +7,6 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ @@ -15,8 +14,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.block import Block, BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LM_HEAD_LOSS_NAME, @@ -34,13 +32,15 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](Block[ConfigType]): +class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](BlockBase[ConfigType]): + heads: "list[LanguageModelHead]" + @abc.abstractmethod def get_output_weights(self) -> list[torch.Tensor]: pass -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType]): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType], Block): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) 
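
`LanguageModelHeadBase` now derives from `BlockBase` and declares a `heads` attribute, so a plain head and a multi-token-prediction container expose the same interface: a single head reports itself as a one-element list (see the `heads` property later in this file), while `MultiTokenPrediction` subclasses the same base in `multi_token_prediction.py`. A minimal sketch of the pattern with simplified stand-in classes (`HeadBase`, `Head`, `MultiHead` are illustrative names, not the real Fast-LLM bases):

```python
import abc
import torch

class HeadBase(torch.nn.Module, abc.ABC):
    heads: "list[Head]"  # every head-like layer exposes its flat list of prediction heads

    @abc.abstractmethod
    def get_output_weights(self) -> list[torch.Tensor]: ...

class Head(HeadBase):
    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        super().__init__()
        self.output_weights = torch.nn.Parameter(torch.empty(vocab_size, hidden_size))

    def get_output_weights(self) -> list[torch.Tensor]:
        return [self.output_weights]

    @property
    def heads(self) -> "list[Head]":
        return [self]  # a single head is its own one-element list, for compatibility with MTP

class MultiHead(HeadBase):
    def __init__(self, heads: list[Head]) -> None:
        super().__init__()
        self.heads = torch.nn.ModuleList(heads)  # one head per prediction distance

    def get_output_weights(self) -> list[torch.Tensor]:
        return [weight for head in self.heads for weight in head.get_output_weights()]
```
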
@@ -99,19 +99,21 @@ def __init__( loss_configs = ( self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()} ) - self._losses = [ - loss_config.get_layer( - distributed_config, - self._get_full_loss_name(name), - self._prediction_distance, - self._prediction_heads, - self._vocab_parallel, - self._config.cross_entropy_splits, - self._config.logits_scale_factor, - self._loss_coefficient, - ) - for name, loss_config in loss_configs.items() - ] + self.losses = torch.nn.ModuleList( + [ + loss_config.get_layer( + distributed_config, + self._get_full_loss_name(name), + self._prediction_distance, + self._prediction_heads, + self._vocab_parallel, + self._config.cross_entropy_splits, + self._config.logits_scale_factor, + self._loss_coefficient, + ) + for name, loss_config in loss_configs.items() + ] + ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (loss) @@ -168,7 +170,12 @@ def _forward_backward( ln_output = self.final_norm(input_) # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. - self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) + self._debug( + ln_output, + "final_norm", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), + kwargs, + ) loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach(), kwargs, losses) if ln_output_grad is None: return loss, None @@ -185,18 +192,7 @@ def _logits_loss_forward_backward( if not self.training: logits, _ = self._logits_loss_forward_backward_partial(input_, kwargs, return_logits=True) - # TODO: Make a proper way of returning the model output. 
- logits = logits.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - logits = gather_op(logits, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - logits = gather_op( - logits, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = ( - logits.detach() - ) + self._debug(logits, "logits", (kwargs[LanguageModelKwargs.hidden_token_dim], self._vocab_dim), kwargs) return None, None input_ = input_.flatten(0, -2) @@ -230,7 +226,7 @@ def _logits_loss_forward_backward( total_loss = sum( (loss_.weight / self._config.cross_entropy_splits) * loss_dict[loss_.name] - for loss_ in self._losses + for loss_ in self.losses if loss_.weight != 0.0 and loss_.name in loss_dict ) @@ -240,7 +236,7 @@ def _logits_loss_forward_backward( if all_losses_dict is not None: all_losses_dict[self._total_loss_name].append(total_loss) - if len(self._losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._losses): + if len(self.losses) > 1 or any(loss_.weight != 1.0 for loss_ in self.losses): for name, loss_value in loss_dict.items(): if self._config.cross_entropy_splits != 1: loss_value /= self._config.cross_entropy_splits @@ -265,24 +261,19 @@ def _logits_loss_forward_backward_partial( group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q - if LanguageModelKwargs.hidden_dims in kwargs: - batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] - dims = ( - (sequence_dim, batch_dim, self._vocab_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, self._vocab_dim) - ) - else: - dims = None - self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) + self._debug( + logits, + f"logits{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) if return_logits: return logits, None losses, grad = {}, None - for loss in self._losses: + for loss in self.losses: # losses are returned unscaled but the grads are already scaled loss_value, grad = loss.forward_backward( logits, @@ -304,7 +295,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, dtype=DataType.float32, ) - for loss in self._losses + for loss in self.losses ), ] @@ -316,6 +307,6 @@ def _total_loss_name(self) -> str: return self._get_full_loss_name(LM_HEAD_LOSS_NAME) @property - def heads(self): + def heads(self) -> "list[LanguageModelHead]": # For compatibility with MTP. 
return [self] diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 766a5ed54..8dc88c4a1 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -11,7 +11,7 @@ from fast_llm.utils import Assert -class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType]): +class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, config: ConfigType, @@ -62,19 +62,17 @@ def _prepare_target( split_index: int = 0, *, multi_token_format: bool = False, + sequence_parallel: bool = True, ) -> torch.Tensor | None: # MTP shift if multi_token_format and self._prediction_heads > 1: - sequence_first: bool = kwargs[LanguageModelLossKwargs.sequence_first] - sequence_q_length = target.size(1 - sequence_first) + 1 - self._prediction_heads - target_slice = slice(self._prediction_distance, self._prediction_distance + sequence_q_length) - target = target[target_slice] if sequence_first else target[:, target_slice] - - # Flatten the batch and sequence dimensions. - target = target.flatten(0, 1) + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + target = target.unflatten( + 0, (kwargs[LanguageModelKwargs.batch_dim].size, sequence_q + self._prediction_heads - 1) + )[:, self._prediction_distance : self._prediction_distance + sequence_q].flatten(0, 1) # Get the local chunk. - if self._sequence_parallel: + if sequence_parallel and self._sequence_parallel: target = split_op(target, self._parallel_dim.group, 0) # Get the chunk for the current split. @@ -104,7 +102,13 @@ def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index) def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0): - return self._prepare_target(kwargs[f"{reference_model}_logits"], kwargs, split_index) + assert self._prediction_distance == 0 + Assert.incl( + logits_name := self.module_name.rsplit(".", 2)[0] + f".logits", + reference_hidden_states := kwargs[f"reference_{reference_model}_hidden_states"], + ) + # The logits are already sequence-parallel if needed, we don't want to split again. 
+ return self._prepare_target(reference_hidden_states[logits_name], kwargs, split_index, sequence_parallel=False) def loss_forward_backward( diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index ad3395a0f..5efe2d836 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -7,12 +7,12 @@ from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig +from fast_llm.layers.language_model.head import LanguageModelHeadBase -class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](BlockBase[ConfigType]): +class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](LanguageModelHeadBase[ConfigType]): _config: ConfigType def __init__( diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index c7a2c1c59..5e721d424 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -12,7 +12,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import GatedDeltaNetConfig @@ -301,16 +300,12 @@ def _forward( - """ - sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension # TODO: fuse soome of the reshapes into rearranges hidden_states = input_ projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) - if sequence_first: - projected_states_qkvz = projected_states_qkvz.transpose(0, 1) - projected_states_ba = projected_states_ba.transpose(0, 1) batch_size, sequence_length = projected_states_qkvz.shape[:2] @@ -371,8 +366,6 @@ def _forward( core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) - if sequence_first: - core_attn_out = core_attn_out.transpose(0, 1) output = self.out_proj(core_attn_out) return output diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 94cde7d5f..07ca3a997 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -11,7 +11,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig @@ -229,7 +228,6 @@ def _forward( """ Same as in gdn, the idea is to always do forward pass in a packed way, whcih is required for varlen support. 
""" - sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ # TODO: can be made more efficeint by rearranging hidden states directly and only once @@ -239,11 +237,6 @@ def _forward( k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - if sequence_first: - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - batch_size, sequence_length, _ = q.size() q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) @@ -257,8 +250,6 @@ def _forward( v = self._apply_conv(v, self.v_conv, seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - if sequence_first: - g_kernel = g_kernel.transpose(0, 1) g_kernel = self._reshape_heads(g_kernel) g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) @@ -268,8 +259,6 @@ def _forward( q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) - if sequence_first: - beta = beta.transpose(0, 1) beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md @@ -290,14 +279,10 @@ def _forward( g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim g_out = self._reshape_heads(g_out) - if sequence_first: - g_out = g_out.transpose(0, 1) attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) attn_out = self.norm(attn_out, g_out) attn_out = rearrange(attn_out, "b s h d -> b s (h d)") - if sequence_first: - attn_out = attn_out.transpose(0, 1) attn_out = self.o_proj(attn_out) return attn_out diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 81b82d08e..fd6255e6c 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -166,16 +166,11 @@ def _forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, local_inner_projection) - inner_projection = self.in_proj(input_) - dt = self.dt_proj(self.dt_in_proj(input_)) - # Standardize to (batch, sequence, local_inner_projection) - if kwargs[BlockKwargs.sequence_first]: - inner_projection = inner_projection.transpose(0, 1) - dt = dt.transpose(0, 1) - - sequence_length = inner_projection.size(1) + sequence_length = kwargs[BlockKwargs.sequence_q_dim].size + token_shape = (kwargs[BlockKwargs.batch_dim].size, kwargs[BlockKwargs.sequence_q_dim].size) + # inner_projection : (local_tokens, hidden) -> (batch, sequence, local_inner_projection) + inner_projection = self.in_proj(input_).unflatten(0, token_shape) + dt = self.dt_proj(self.dt_in_proj(input_)).unflatten(0, token_shape) z, x, b, c = torch.split( inner_projection, @@ -245,13 +240,10 @@ def _forward( # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] - if kwargs[BlockKwargs.sequence_first]: - # TODO: Is contiguous needed? 
- y = y.transpose(0, 1).contiguous() - # (batch/sequence, sequence/batch, local_heads * state) - # -> (batch/local_sequence, local_sequence/batch, hidden) - out, bias = self.out_proj(y) - self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + # (batch, sequence, local_heads * state) + # -> (local_tokens, hidden) + out, bias = self.out_proj(y.flatten(0, 1)) + self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/vision/embeddings.py b/fast_llm/layers/vision/embeddings.py index 2076f72e5..0b0434f56 100644 --- a/fast_llm/layers/vision/embeddings.py +++ b/fast_llm/layers/vision/embeddings.py @@ -6,7 +6,6 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionKwargs @@ -60,17 +59,13 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[VisionKwargs.hidden_dims], + (kwargs[VisionKwargs.hidden_token_dim], self._hidden_dim), tensor_name="Patch convolution output", dtype=self._residual_dtype, ) if self._sequence_parallel: input_ = split(input_, group=self._parallel_dim.group, dim=0) - out = ( - self.normalization(self.patch_embeddings(input_.flatten(1))) - .unsqueeze(int(kwargs[AttentionKwargs.sequence_first])) - .to(self._residual_dtype) - ) - self._debug(out, None, kwargs.get(VisionKwargs.hidden_dims), kwargs) + out = self.normalization(self.patch_embeddings(input_.flatten(1))).to(self._residual_dtype) + self._debug(out, None, (kwargs.get(VisionKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index a418c3fb5..387610a46 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -135,20 +135,20 @@ def _inner_forward( # The transformers will save the present keys and values to this list. kwargs[AttentionKwargs.presents] = [] - kwargs["global_logits"] = True - self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. - if kwargs[AttentionKwargs.sequence_first]: - logits = kwargs["logits"].transpose(0, 1) - else: - logits = kwargs["logits"] + # TODO: Handle MTP. 
+ logits_meta, logits = kwargs[AttentionKwargs.hidden_states]["head.logits"] + logits, _ = logits_meta.local_to_global(logits) + logits = logits.unflatten( + 0, (kwargs[AttentionKwargs.batch_dim].global_size, kwargs[AttentionKwargs.sequence_q_dim].global_size) + ) if output_hidden_states: hidden_states = { key: tensor if meta is None else meta.local_to_global(tensor)[0] - for key, (meta, tensor) in kwargs["hidden_states"].items() + for key, (meta, tensor) in kwargs[AttentionKwargs.hidden_states].items() } else: hidden_states = None diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd2932984..cabcdc489 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -1,3 +1,4 @@ +import functools import logging import re import typing @@ -72,35 +73,29 @@ def preprocess_meta( micro_sequence_length, self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) - hidden_sequence_q_dim = ( - TensorDim( - BlockDimNames.sequence_q_tp, - micro_sequence_length, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), + token_dim = TensorDim( + "token", + batch_dim.global_size * sequence_q_dim.global_size, + self._distributed_config.get_distributed_dim(DistributedDimNames.data), + ) + # The token dimension as appears in hidden states, i.e. with possible sequence-tensor-parallel split. + hidden_token_dim = ( + ( + "token_tp", + token_dim.global_size, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), ) if self._distributed_config.sequence_tensor_parallel - else sequence_q_dim - ) - - need_sequence_first = hidden_sequence_q_dim.size != sequence_length - if self._config.sequence_first is None: - sequence_first = need_sequence_first - else: - sequence_first = self._config.sequence_first - assert not (need_sequence_first and not sequence_first) - - hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, self._hidden_dim) - if sequence_first - else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) + else token_dim ) common_kwargs = { LanguageModelKwargs.phase: phase, - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.hidden_dims: hidden_dims, AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.batch_dim: batch_dim, AttentionKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.token_dim: token_dim, + AttentionKwargs.hidden_token_dim: hidden_token_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -122,7 +117,7 @@ def preprocess_meta( sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( - hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 + (token_dim,), tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 ) kwargs = { @@ -131,16 +126,18 @@ def preprocess_meta( } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( - hidden_dims[:2], tensor_name="labels", dtype=torch.int64 + (token_dim,), tensor_name="labels", dtype=torch.int64 ) reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - AttentionKwargs.sequence_first, AttentionKwargs.sequence_length, + AttentionKwargs.batch_dim, AttentionKwargs.sequence_q_dim, AttentionKwargs.sequence_k_dim, + AttentionKwargs.token_dim, + AttentionKwargs.hidden_token_dim, ): Assert.eq(reference_kwargs_[key], 
kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -158,6 +155,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + extra_kwargs: dict[str, typing.Any] | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup @@ -167,44 +165,18 @@ def preprocess_batch( if preprocessed_meta is None: preprocessed_meta = self.preprocess_meta(batch, phase) - distillation_models = self._config.decoder.get_reference_models() - # TODO: Support multiple distillation models? - assert len(distillation_models) <= 1 - reference_logits = [{} for _ in preprocessed_meta] + reference_preprocessed_batches = {} for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] - - # Set output_hidden_states in reference metadata before preprocessing if needed for distillation - if name in distillation_models: - reference_output_hidden_states = [r"decoder\.\d+\.mixer_output$"] - for _, ref_kwargs_meta in reference_preprocessed_meta: - ref_kwargs_meta[BlockKwargs.output_hidden_states] = [ - re.compile(pattern) for pattern in reference_output_hidden_states - ] - - reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( + reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration, ) - # TODO: Do things work with >1? - Assert.eq(len(reference_batch), len(preprocessed_meta), 1) - for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): - reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - if BlockKwargs.hidden_states in reference_kwargs and reference_kwargs[BlockKwargs.hidden_states]: - # Extract activations from hidden_states dict (stored by _debug method) - # Format: {layer_name: (meta, tensor), ...} - activations = { - layer_name: tensor - for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() - } - reference_logits[i][f"{name}_activations"] = activations - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): @@ -217,50 +189,66 @@ def preprocess_batch( pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - # Create activation mask for activation distillation - # This mask should: - # - Be 0 on padding tokens (added at the end when documents aren't truncated) - # - Be 1 on image placeholder tokens (token value -100 but not padding) - # - Be 1 on all other valid tokens (ignores loss-masking-spans) - # - # Note: Padding is added as a separate document with all tokens = -100 - # We detect padding by checking if all tokens in a document segment are -100 - activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool) - - for sample_index, sample_lengths in enumerate(cropped_tokens.lengths): - # Iterate through documents in this sample - pos = 0 - for doc_length in sample_lengths: - # Check if this document is padding (all tokens are -100) - doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length] - is_padding_doc = torch.all(doc_tokens == -100).item() - - if is_padding_doc: - # This is a padding document, mask it out - activation_mask[sample_index, pos : pos + doc_length] = False - - pos += 
@@ -217,50 +189,66 @@ def preprocess_batch(
             pasts = presents
             presents = None if i == len(preprocessed_meta) - 1 else []
 
-            # Create activation mask for activation distillation
-            # This mask should:
-            # - Be 0 on padding tokens (added at the end when documents aren't truncated)
-            # - Be 1 on image placeholder tokens (token value -100 but not padding)
-            # - Be 1 on all other valid tokens (ignores loss-masking-spans)
-            #
-            # Note: Padding is added as a separate document with all tokens = -100
-            # We detect padding by checking if all tokens in a document segment are -100
-            activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool)
-
-            for sample_index, sample_lengths in enumerate(cropped_tokens.lengths):
-                # Iterate through documents in this sample
-                pos = 0
-                for doc_length in sample_lengths:
-                    # Check if this document is padding (all tokens are -100)
-                    doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length]
-                    is_padding_doc = torch.all(doc_tokens == -100).item()
-
-                    if is_padding_doc:
-                        # This is a padding document, mask it out
-                        activation_mask[sample_index, pos : pos + doc_length] = False
-
-                    pos += doc_length
-
             kwargs: dict[str, typing.Any] = {
                 **kwargs_meta,
                 AttentionKwargs.past_key_values: pasts,
                 AttentionKwargs.presents: presents,
                 BlockKwargs.iteration: iteration,
                 AttentionKwargs.sequence_lengths: cropped_tokens.lengths,
-                BlockKwargs.activation_mask: activation_mask,
                 AttentionKwargs.device: self._distributed.device,
+                BlockKwargs.output_hidden_states: [],
                 BlockKwargs.hidden_states: {},
-                **reference_logits[i],
             }
+            if extra_kwargs is not None:
+                Assert.empty(kwargs.keys() & extra_kwargs.keys())
+                kwargs.update(extra_kwargs)
+
+            # TODO: Simplify, check more carefully if needed.
+            if self._decoder_reference_models:
+                # Create activation mask for activation distillation
+                # This mask should:
+                # - Be 0 on padding tokens (added at the end when documents aren't truncated)
+                # - Be 1 on image placeholder tokens (token value -100 but not padding)
+                # - Be 1 on all other valid tokens (ignores loss-masking-spans)
+                #
+                # Note: Padding is added as a separate document with all tokens = -100
+                # We detect padding by checking if all tokens in a document segment are -100
+                activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool)
+
+                for sample_index, sample_lengths in enumerate(cropped_tokens.lengths):
+                    # Iterate through documents in this sample
+                    pos = 0
+                    for doc_length in sample_lengths:
+                        # Check if this document is padding (all tokens are -100)
+                        doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length]
+                        is_padding_doc = torch.all(doc_tokens == -100).item()
+
+                        if is_padding_doc:
+                            # This is a padding document, mask it out
+                            activation_mask[sample_index, pos : pos + doc_length] = False
+
+                        pos += doc_length
+
+                kwargs[BlockKwargs.activation_mask] = activation_mask.flatten()
+
+            for name, reference_model in self._reference_models.items():
+                reference_tokens, reference_kwargs = reference_preprocessed_batches[name][i]
+                if name in self._decoder_reference_models:
+                    # TODO: Get the actual names
+                    reference_kwargs[BlockKwargs.output_hidden_states].append(
+                        re.compile(r"decoder\.\d+\.mixer_output$")
+                    )
-            # Add activation-distillation targets
-            assert len(distillation_models) <= 1
-            for distillation_model in distillation_models:
-                teacher_key = f"{distillation_model}_activations"
-                if teacher_key in reference_logits[i]:
-                    kwargs[BlockKwargs.activation_distillation_targets] = reference_logits[i].pop(teacher_key)
+                reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
 
-            if phase != PhaseType.inference:
+                kwargs[f"reference_{name}_hidden_states"] = {
+                    layer_name: tensor
+                    for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items()
+                }
+
+            if phase == PhaseType.inference:
+                kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$"))
+            else:
                 labels_begin = tokens_begin + 1
                 labels_end = tokens_end + self._config.head.max_prediction_distance
                 labels = batch.tokens.crop(labels_begin, labels_end).tokens
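The hunk above routes schedule-level values (grad_output, micro-batch indices) into preprocessing through the new extra_kwargs argument and asserts that they do not collide with keys preprocessing already set. A stand-alone sketch of that guard in plain Python (the helper name is made up; only the disjoint-merge idea comes from the hunk above, which uses the Fast-LLM Assert helper instead):

    import typing

    def merge_disjoint(
        kwargs: dict[str, typing.Any], extra_kwargs: dict[str, typing.Any] | None
    ) -> dict[str, typing.Any]:
        # Refuse to silently overwrite anything preprocessing already set.
        if extra_kwargs is not None:
            overlap = kwargs.keys() & extra_kwargs.keys()
            if overlap:
                raise ValueError(f"extra_kwargs would overwrite existing keys: {sorted(overlap)}")
            kwargs.update(extra_kwargs)
        return kwargs

    # Example: schedule-level extras layered on top of per-micro-batch kwargs.
    merged = merge_disjoint({"iteration": 10}, {"grad_output": 1.0, "micro_batch": 0})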
@@ -273,18 +261,13 @@ def preprocess_batch(
                         loss_mask[sample_index, begin:end] = False
                     labels = torch.where(loss_mask, labels, -100)
 
+                labels = labels.flatten(0, 1)
+                kwargs[LanguageModelKwargs.labels] = labels
+
                 if self._config.head.get_reference_models():
                     # loss masks only used for distillation currently
                     # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders
                     kwargs[LanguageModelKwargs.loss_mask] = labels >= 0
-                kwargs[LanguageModelKwargs.labels] = (
-                    labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels
-                ).contiguous()
-                if LanguageModelKwargs.loss_mask in kwargs and kwargs[AttentionKwargs.sequence_first]:
-                    kwargs[LanguageModelKwargs.loss_mask] = (
-                        kwargs[LanguageModelKwargs.loss_mask].transpose(0, 1).contiguous()
-                    )
-
                 if batch.chosen_spans is not None:
                     kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges
@@ -293,11 +276,7 @@ def preprocess_batch(
                         labels_begin, labels_end
                     ).ranges
 
-            tokens = (
-                cropped_tokens.tokens.transpose(0, 1)
-                if kwargs[AttentionKwargs.sequence_first]
-                else cropped_tokens.tokens
-            ).contiguous()
+            tokens = cropped_tokens.tokens.flatten(0, 1)
 
             self.preprocess(kwargs)
             preprocessed.append((tokens, kwargs))
@@ -310,6 +289,19 @@ def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]
             output_weights.insert(0, self.embeddings.word_embeddings_weight)
         return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {}
 
+    @functools.cached_property
+    def _decoder_reference_models(self) -> set[str]:
+        out = self._config.decoder.get_reference_models()
+        Assert.leq(out, self._reference_models.keys())
+        Assert.leq(len(out), 1)
+        return out
+
+    @functools.cached_property
+    def _head_reference_models(self) -> set[str]:
+        out = self._config.head.get_reference_models()
+        Assert.leq(out, self._reference_models.keys())
+        return out
+
 
 class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
     # TODO: Can we drop class?
diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py
index 890d5760e..e90bd4d89 100644
--- a/fast_llm/models/multimodal/model.py
+++ b/fast_llm/models/multimodal/model.py
@@ -9,7 +9,6 @@
 from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType
 from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.layers.attention.config import AttentionKwargs
-from fast_llm.layers.block.config import BlockDimNames, BlockKwargs
 from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.vision.config import VisionKwargs
 from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel
@@ -97,54 +96,48 @@ def preprocess_meta(
             # TODO: What about sequence data?
             batch_data_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data)
-            micro_sequence_length = tokens.global_shape.numel()
-
-            batch_and_sequence_q_dim = PatchSequenceTensorDim(
-                BlockDimNames.sequence_q,
-                micro_sequence_length,
+            token_dim = PatchSequenceTensorDim(
+                "token",
+                kwargs[VisionKwargs.token_dim].global_size,
                 self._distributed_config.get_distributed_dim(DistributedDimNames.data),
                 batch_data_dim,
             )
-            hidden_batch_and_sequence_q_dim = (
+            hidden_token_dim = (
                 PatchSequenceTensorDim(
-                    BlockDimNames.sequence_q_tp,
-                    micro_sequence_length,
+                    "token_tp",
+                    token_dim.global_size,
                     self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data),
                     batch_data_dim,
                 )
                 if self._distributed_config.sequence_tensor_parallel
-                else batch_and_sequence_q_dim
+                else token_dim
             )
 
             # These are used by the model (preprocessing) and shouldn't see the batch-parallel dim.
             sequence_q_dim = TensorDim(
-                BlockDimNames.sequence_q,
-                micro_sequence_length,
+                "sequence_q",
+                token_dim.global_size,
                 self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data),
             )
-            sequence_k_dim = TensorDim(BlockDimNames.sequence_k, micro_sequence_length)
+            TensorDim("sequence_k", token_dim.global_size)
 
             image_patches = TensorMeta.from_dims(
                 (
                     # We combine the batch and sequence dims to allow for variable sequence lengths.
                     # Gives the same result, assuming we disable cross-image attention (TODO: Enforce)
-                    batch_and_sequence_q_dim,
+                    token_dim,
                     # TODO: Relate to tensor dims in patch convolution.
                     TensorDim("input_channels", self._config.vision_encoder.embeddings.input_channels),
                     TensorDim("patch_height", self._config.vision_encoder.embeddings.patch_height),
                     TensorDim("patch_width", self._config.vision_encoder.embeddings.patch_width),
                 )
             )
-            # Use vision encoder's internal hidden dim (for embeddings/encoder), not the output dim (for adapter)
-            hidden_dims = (
-                (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._vision_hidden_dim)
-                if (sequence_first := kwargs[LanguageModelKwargs.sequence_first])
-                else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._vision_hidden_dim)
-            )
 
             kwargs[self._vision_encoder_namespace] = {
-                VisionKwargs.sequence_first: sequence_first,
-                VisionKwargs.sequence_k_dim: sequence_k_dim,
-                VisionKwargs.sequence_q_dim: sequence_q_dim,
-                VisionKwargs.hidden_dims: hidden_dims,
+                VisionKwargs.sequence_length: kwargs[VisionKwargs.sequence_length],
+                VisionKwargs.batch_dim: scalar_dim,
+                VisionKwargs.sequence_q_dim: token_dim,
+                VisionKwargs.sequence_k_dim: token_dim,
+                VisionKwargs.token_dim: token_dim,
+                VisionKwargs.hidden_token_dim: hidden_token_dim,
             }
 
             preprocessed_meta.append((image_patches, kwargs))
@@ -159,9 +152,10 @@ def preprocess_batch(
         phase: PhaseType,
         iteration: int,
         metrics: dict | None = None,
+        extra_kwargs: dict[str, typing.Any] | None = None,
     ) -> list[tuple[torch.Tensor, dict]]:
         preprocessed = super().preprocess_batch(
-            batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics
+            batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics, extra_kwargs=extra_kwargs
         )
         # TODO: Support micro-sequences.
         assert len(preprocessed) == 1, "Micro-sequences not supported for MultiModalModel."
@@ -194,22 +188,17 @@ def preprocess_batch(
             VisionKwargs.sequence_lengths: [cropped_image_patches.lengths + [pad_size]],
             VisionKwargs.sequence_length: sequence_length,
             VisionKwargs.device: self._distributed.device,
-            BlockKwargs.output_hidden_states: kwargs.get(BlockKwargs.output_hidden_states, []),
-            BlockKwargs.hidden_states: kwargs[BlockKwargs.hidden_states],
+            VisionKwargs.output_hidden_states: kwargs.get(VisionKwargs.output_hidden_states, []),
+            VisionKwargs.hidden_states: kwargs[VisionKwargs.hidden_states],
         }
 
         # We need to modify `local_unpadded_size` directly in `preprocessed_meta` since it's the one used by the engine.
         # Unsafe, but only needed for testing.
         # TODO: Doesn't work with gradient accumulation (only sees the last value).
-        hidden_batch_and_sequence_q_dim = kwargs[self._vision_encoder_namespace][VisionKwargs.hidden_dims][
-            0 if kwargs[self._vision_encoder_namespace][VisionKwargs.sequence_first] else 1
-        ]
-        assert isinstance(hidden_batch_and_sequence_q_dim, PatchSequenceTensorDim)
         PatchSequenceTensorDim.local_unpadded_size = cropped_image_patches.patches.size(0)
 
         kwargs[LanguageModelKwargs.embedding_map] = (
-            (cropped_image_patches.token_map, cropped_image_patches.sample_map)
-            if kwargs[LanguageModelKwargs.sequence_first]
-            else (cropped_image_patches.sample_map, cropped_image_patches.token_map)
+            cropped_image_patches.sample_map * kwargs[VisionKwargs.sequence_q_dim].size
+            + cropped_image_patches.token_map
         )
 
         super().preprocess(kwargs)
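With the flat token dimension, the multimodal embedding map above collapses the old (sample, token) index pair into a single flat index, sample * sequence_q + token. A small sketch of the equivalence (illustrative values and tensor names; only the indexing arithmetic is taken from the hunk above):

    import torch

    batch_size, sequence_q, hidden = 2, 5, 8
    hidden_states = torch.zeros(batch_size, sequence_q, hidden)

    # Per-patch placement: which sample and which token position each image patch maps to.
    sample_map = torch.tensor([0, 0, 1])
    token_map = torch.tensor([1, 2, 4])
    patch_embeddings = torch.randn(3, hidden)

    # Old layout: index with the (sample, token) pair.
    hidden_states[sample_map, token_map] = patch_embeddings

    # New layout: one flat index into the flattened (batch * sequence_q) token dimension.
    flat = torch.zeros(batch_size * sequence_q, hidden)
    embedding_map = sample_map * sequence_q + token_map
    flat[embedding_map] = patch_embeddings

    assert torch.equal(flat.unflatten(0, (batch_size, sequence_q)), hidden_states)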
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index ee3e0e2e1..b1a922099 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -6,6 +6,7 @@
 import torch
 
 from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.layers.attention.config import AttentionKwargs
 from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs
 from fast_llm.layers.language_model.head import LanguageModelHead
@@ -28,7 +29,6 @@ class LMHeadTestConfig:
     logits_scale_factor: float = 1.0
     compute_dtype: DataType = DataType.float32
     full_precision_residual: bool = False
-    sequence_first: bool = False
     loss_masking: bool = False
     prediction_heads: int = 1
     tied_embedding_weight: bool = False
@@ -88,22 +88,15 @@ def get_config(self) -> GPTModelConfig:
     def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         input_ = torch.randn(
-            (
-                (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE)
-                if self.sequence_first
-                else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE)
-            ),
+            (BATCH_SIZE * SEQUENCE_LENGTH, HIDDEN_SIZE),
             dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch),
             device=device,
             requires_grad=True,
         )
-        label_shape = (
-            (SEQUENCE_LENGTH + self.prediction_heads - 1, BATCH_SIZE)
-            if self.sequence_first
-            else (BATCH_SIZE, SEQUENCE_LENGTH + self.prediction_heads - 1)
-        )
+        label_shape = (BATCH_SIZE * (SEQUENCE_LENGTH + self.prediction_heads - 1),)
         kwargs: dict[str, typing.Any] = {
-            AttentionKwargs.sequence_first: self.sequence_first,
+            AttentionKwargs.batch_dim: TensorDim("batch", BATCH_SIZE),
+            AttentionKwargs.sequence_q_dim: TensorDim("sequence_q", SEQUENCE_LENGTH),
             AttentionKwargs.grad_output: 1.0,
         }
         if self.loss_masking:
@@ -122,11 +115,13 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
 
         if self.distillation_loss is not False:
             assert self.prediction_heads == 1
-            kwargs[f"distillation_logits"] = torch.randn(
-                input_.shape[:-1] + (VOCAB_SIZE,),
-                dtype=input_.dtype,
-                device=device,
-            )
+            kwargs[f"reference_distillation_hidden_states"] = {
+                "head.logits": torch.randn(
+                    input_.shape[:-1] + (VOCAB_SIZE,),
+                    dtype=input_.dtype,
+                    device=device,
+                )
+            }
         return input_, kwargs
 
     def get_reference_outputs(
@@ -153,28 +148,25 @@ def get_reference_outputs(
         losses = {}
         if self.actual_label_loss is not False:
-            if self.sequence_first:
-                labels = kwargs[LanguageModelKwargs.labels][
-                    head._prediction_distance : head._prediction_distance + logits.size(0)
-                ]
-            else:
-                labels = kwargs[LanguageModelKwargs.labels][
-                    :, head._prediction_distance : head._prediction_distance + logits.size(1)
+            labels = (
+                kwargs[LanguageModelKwargs.labels]
+                .view(BATCH_SIZE, (SEQUENCE_LENGTH + self.prediction_heads - 1))[
+                    :, head._prediction_distance : head._prediction_distance + SEQUENCE_LENGTH
                 ]
-            label_loss = torch.nn.functional.cross_entropy(
-                logits.flatten(0, -2), labels.flatten(), reduction="none"
-            ).mean()
+                .flatten()
+            )
+            label_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none").mean()
             losses["label"] = label_loss.detach()
             total_loss = total_loss + float(self.actual_label_loss) * label_loss
 
         if self.distillation_loss is not False:
             distillation_loss = torch.nn.functional.cross_entropy(
-                logits.flatten(0, -2),
-                torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1),
+                logits,
+                torch.softmax(kwargs[f"reference_distillation_hidden_states"]["head.logits"], -1),
                 reduction="none",
             )
             if LanguageModelKwargs.loss_mask in kwargs:
-                distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask].flatten()
+                distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask]
             distillation_loss = distillation_loss.mean()
             losses["distillation"] = distillation_loss.detach()
             total_loss = total_loss + float(self.distillation_loss) * distillation_loss
@@ -220,7 +212,6 @@ def _add_configs(base_name: str, **kwargs):
 _add_configs("default")
 _add_configs("bfloat16", compute_dtype=DataType.bfloat16)
 _add_configs("full_precision_residual", full_precision_residual=True)
-_add_configs("sequence_first", sequence_first=True)
 _add_configs("logit_scaling", logits_scale_factor=5.0)
 _add_configs("tied_embedding_weight", tied_embedding_weight=True)
 _add_configs("multi_token_prediction", prediction_heads=2)
@@ -240,7 +231,7 @@ def _add_configs(base_name: str, **kwargs):
         for _lm_head_test_config in _lm_head_test_configs
     ],
 )
-def test_lm_head(test_config):
+def test_lm_head(test_config: LMHeadTestConfig):
     model_config = test_config.get_config()
     model, distributed = get_base_model(model_config)
     input_, kwargs = test_config.get_inputs()
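For multi-token prediction, the reference computation above recovers each head's labels from the flattened label tensor by viewing it as (batch, sequence + heads - 1) and shifting by the head's prediction distance. A small sketch of that slicing (illustrative sizes; only the view/shift/flatten arithmetic comes from the hunk above):

    import torch

    batch_size, sequence_length, prediction_heads = 2, 4, 2
    # Flattened labels, one extra position per additional prediction head.
    labels = torch.arange(batch_size * (sequence_length + prediction_heads - 1))

    for prediction_distance in range(prediction_heads):
        # A head with distance d reads labels starting d positions later.
        head_labels = labels.view(batch_size, sequence_length + prediction_heads - 1)[
            :, prediction_distance : prediction_distance + sequence_length
        ].flatten()
        assert head_labels.shape == (batch_size * sequence_length,)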
""" - hidden_size = 32 - hidden_dim = TensorDim("hidden", hidden_size) + hidden_dim = TensorDim("hidden", hidden_size := 32) distributed = Distributed( distributed_config := DistributedConfig(compute_dtype=DataType.float16, use_cuda=torch.cuda.is_available()) ) @@ -68,20 +67,19 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): kwargs = { BlockKwargs.device: distributed.device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), } kwargs_packed = { **kwargs, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_length: seq_len, + BlockKwargs.batch_dim: TensorDim("", batch_size), BlockKwargs.sequence_q_dim: TensorDim("", seq_len), BlockKwargs.sequence_k_dim: TensorDim("", seq_len), } mixer.preprocess(kwargs_packed) - out_packed, context = stage.forward(hidden_states, kwargs_packed) + out_packed, context = stage.forward(hidden_states.flatten(0, 1), kwargs_packed) stage.backward(torch.ones_like(out_packed), context) names, parameters = zip(*list(mixer.named_parameters())) @@ -97,14 +95,15 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): **kwargs, BlockKwargs.sequence_lengths: [[seq_len_]], BlockKwargs.sequence_length: seq_len_, + BlockKwargs.batch_dim: TensorDim("", 1), BlockKwargs.sequence_q_dim: TensorDim("", seq_len_), BlockKwargs.sequence_k_dim: TensorDim("", seq_len_), } mixer.preprocess(kwargs_seq) - out, context = stage.forward(seq.unsqueeze(0), kwargs_seq) + out, context = stage.forward(seq, kwargs_seq) stage.backward(torch.ones_like(out), context) out_refs.append(out) - out_ref = torch.cat(out_refs, dim=1).view_as(out_packed) + out_ref = torch.cat(out_refs, dim=0).view_as(out_packed) Assert.rms_close_relative(out_packed, out_ref, 1e-3, 1e-4) diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index 8c131dfa7..cdf2295e0 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -220,13 +220,7 @@ def test_all_padding_sample(self): labels = kwargs[LanguageModelKwargs.labels] # Get labels for sample 1 (all should be -100) - # Handle sequence_first dimension ordering - if labels.shape[0] > labels.shape[1]: - # sequence_first=True: shape is (seq, batch) - sample1_labels = labels[:, 1] - else: - # sequence_first=False: shape is (batch, seq) - sample1_labels = labels[1, :] + sample1_labels = labels[8:] assert torch.all(sample1_labels == -100), f"All labels in padding sample should be -100, got {sample1_labels}" diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f08e9a488..bd5a92720 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -123,14 +123,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=1, compare_config=_fp16_compare, ), - # Sequence-first baseline - DistributedTestingConfig( - name="sf", - compare="simple", - config_args=["model.base_model.sequence_first=True"], - num_gpus=1, - compare_config=_compare_layer_mismatch, - ), # Cross-entropy splits. DistributedTestingConfig( name="ce4", @@ -171,14 +163,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=1, compare_config=_compare_layer_match, ), - # Sequence-first gradient accumulation baseline. 
-    DistributedTestingConfig(
-        name="df4_sf",
-        compare="simple",
-        config_args=["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"],
-        num_gpus=1,
-        compare_config=_compare_layer_mismatch,
-    ),
 ]
 
 SINGLE_GPU_TESTING_CONFIGS = {config.name: config for config in _SINGLE_GPU_TESTING_CONFIGS}
@@ -221,7 +205,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Sequence-data-parallel
     DistributedTestingConfig(
         name="sdp2",
-        compare="sf",
+        compare="simple",
         config_args=["model.distributed.sequence_data_parallel=2"],
         num_gpus=2,
         compare_config=_compare_layer_match,
     ),
@@ -238,7 +222,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Simple sequence-tensor-parallel
     DistributedTestingConfig(
         name="stp2",
-        compare="sf",
+        compare="simple",
         config_args=[
             "model.distributed.tensor_parallel=2",
             "model.distributed.sequence_tensor_parallel=True",
@@ -260,7 +244,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Cross-entropy splits
     DistributedTestingConfig(
         name="stp2_ce4",
-        compare="sf",
+        compare="simple",
        config_args=[
             "model.distributed.tensor_parallel=2",
             "model.distributed.sequence_tensor_parallel=True",
@@ -274,7 +258,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Simple
     DistributedTestingConfig(
         name="dp2_stp2",
-        compare="sf",
+        compare="simple",
         config_args=[
             "model.distributed.tensor_parallel=2",
             "model.distributed.sequence_tensor_parallel=True",
@@ -285,7 +269,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Breadth-first micro-batches
     DistributedTestingConfig(
         name="sdp2_stp2_bf4",
-        compare="df4_sf",
+        compare="df4",
         config_args=[
             "model.distributed.sequence_data_parallel=2",
             "model.distributed.tensor_parallel=2",
@@ -298,7 +282,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
     # Sequence-data-parallel
     DistributedTestingConfig(
         name="sdp2_stp2",
-        compare="sf",
+        compare="simple",
         config_args=[
             "model.distributed.sequence_data_parallel=2",
             "model.distributed.tensor_parallel=2",
@@ -358,10 +342,10 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon
         compare_config=_compare_layer_match,
     ),
     # ===== 2d configs (Tensor + Pipeline)
-    # Simple [sf, mb]
+    # Simple [mb]
     DistributedTestingConfig(
         name="stp2_pp2s1_bf4",
-        compare="df4_sf",
+        compare="df4",
         config_args=[
             "model.distributed.tensor_parallel=2",
             "model.distributed.sequence_tensor_parallel=True",
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index a4f28d14c..7b41c1f50 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -314,6 +314,7 @@ def update_and_add_testing_config(
         ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
         ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
     },
+    compare_factor=1.5,
 )
 
 update_and_add_testing_config(
@@ -333,6 +334,7 @@ def update_and_add_testing_config(
         ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant,
         ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
     },
+    compare_factor=1.5,
 )
 
 update_and_add_testing_config(
@@ -360,6 +362,7 @@ def update_and_add_testing_config(
         ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant,
         ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
     },
+    compare_factor=1.0,
 )
 
 del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"]
@@ -394,6 +397,7 @@ def update_and_add_testing_config(
         ModelTestingGroup.megatron: ModelTestingGroupAction.normal,
         ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
     },
+    compare_factor=1.0,
 )
 
 update_and_add_testing_config(