1 change: 1 addition & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -179,6 +179,7 @@ def preprocess_batch(
phase: PhaseType,
iteration: int,
metrics: dict | None = None,
+ extra_kwargs: dict[str, typing.Any] | None = None,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass
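A minimal, hypothetical sketch (not part of this diff, signature simplified for illustration) of how an implementation might consume the new `extra_kwargs` argument, folding the caller-supplied entries into the per-micro-batch kwargs; the runner change below shows the caller side.

import typing

import torch


def preprocess_batch(
    tokens: torch.Tensor,
    metrics: dict | None = None,
    extra_kwargs: dict[str, typing.Any] | None = None,
) -> list[tuple[torch.Tensor, dict]]:
    # Start from the caller-provided context (grad_output, micro_batch, ...).
    kwargs: dict[str, typing.Any] = {} if extra_kwargs is None else dict(extra_kwargs)
    # One (input, kwargs) pair per micro-batch; a single micro-batch here.
    return [(tokens, kwargs)]


micro_batch_data = preprocess_batch(
    torch.zeros(2, 8, dtype=torch.int64),
    extra_kwargs={"grad_output": 1.0, "micro_batch": 0},
)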
14 changes: 7 additions & 7 deletions fast_llm/engine/schedule/runner.py
@@ -339,15 +339,15 @@ def _preprocess_data(
phase=context.phase,
iteration=context.iteration,
metrics=context.metrics,
+ extra_kwargs={
+ "grad_output": grad_output,
+ "micro_batch": micro_batch,
+ "num_micro_batches": batch_config.sequential_micro_batches,
+ "micro_batch_splits": batch_config.micro_batch_splits,
+ },
)
for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data):
- kwargs.update(
- grad_output=grad_output,
- micro_batch=micro_batch,
- micro_batch_split=micro_batch_split,
- num_micro_batches=batch_config.sequential_micro_batches,
- micro_batch_splits=batch_config.micro_batch_splits,
- )
+ kwargs.update(micro_batch_split=micro_batch_split)
data_index = context.schedule.get_data_index(micro_batch, micro_batch_split)
if self._stages_owned[0]:
context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_
4 changes: 2 additions & 2 deletions fast_llm/functional/autograd.py
@@ -62,8 +62,8 @@ def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.T

class AuxiliaryLoss(torch.autograd.Function):
@staticmethod
- def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor:  # noqa
- ctx.grad = torch.full_like(aux_loss, grad)
+ def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | None) -> torch.Tensor:  # noqa
+ ctx.grad = None if grad is None else torch.full_like(aux_loss, grad)
return input_

@staticmethod
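A self-contained sketch of how the `grad: float | None` case might behave end to end. The backward pass is not shown in this diff, so the assumption here is that a `None` stored in `ctx.grad` simply contributes no gradient for the auxiliary loss.

import torch


class AuxiliaryLossSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | None) -> torch.Tensor:
        ctx.grad = None if grad is None else torch.full_like(aux_loss, grad)
        return input_

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Pass the incoming gradient through unchanged; the auxiliary loss only
        # receives a gradient when a scaling factor was provided.
        return grad_output, ctx.grad, None


x = torch.ones(2, requires_grad=True)
aux_loss = x.pow(2).sum()
out = AuxiliaryLossSketch.apply(x, aux_loss, None)  # aux_loss tracked, no extra gradient
out.sum().backward()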
75 changes: 25 additions & 50 deletions fast_llm/layers/attention/attention.py
@@ -3,7 +3,7 @@
import torch

from fast_llm.core.distributed import set_generator
- from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
+ from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op
from fast_llm.engine.base_model.config import ResourceUsageConfig
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.initialization import init_normal_
@@ -12,7 +12,6 @@
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs
from fast_llm.layers.attention.preprocessing import preprocess_for_varlen
- from fast_llm.layers.block.config import BlockDimNames
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.tensor import TensorMeta
@@ -113,7 +112,7 @@ def __init__(
CompositeTensorDim("value", (head_group_dim, head_size_dim)),
),
)
- dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim))
+ self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim))

self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power)

@@ -152,7 +151,7 @@ def __init__(

# Output.
self.dense = self._config.dense_layer.get_layer(
- dense_dim,
+ self._dense_dim,
hidden_dim,
default_weight_initialization=init_normal_(std=self._hidden_size**-0.5),
default_add_bias=self._config.add_linear_biases,
@@ -163,22 +162,13 @@

# Debug dims
self._query_dims = (
- BlockDimNames.batch,
- BlockDimNames.sequence_q,
CompositeTensorDim("heads", (head_group_dim, group_heads_dim)),
head_size_dim,
)
self._kv_dims = (
- BlockDimNames.batch,
- BlockDimNames.sequence_q,
head_group_dim,
head_size_dim,
)
- self._context_dims = (
- BlockDimNames.batch,
- BlockDimNames.sequence_q,
- dense_dim,
- )

def _attn_backup(
self,
@@ -269,7 +259,7 @@ def _attn_flash(
)

def _query_key_value_forward(
- self, input_: torch.Tensor, sequence_first: bool
+ self, input_: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]:
key_value, key_value_context = self.key_value.forward_only(input_)

@@ -292,10 +282,7 @@ def _query_key_value_forward(
if handle:
handle.wait()

- if self._sequence_data_parallel_dim.group and not sequence_first:
- key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1)
-
- context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first}
+ context = {"query": query_context, "key_value": key_value_context}
return query, key_value, context

def _query_key_value_backward(
@@ -305,7 +292,7 @@ def _query_key_value_backward(
key_value_grad, handle = reduce_scatter_op(
key_value_grad,
group=self._sequence_data_parallel_dim.group,
- dim=1 - context["sequence_first"],
+ dim=0,
async_op=True,
)

@@ -331,15 +318,19 @@ def _forward(
losses: dict[str, typing.Any] | None = None,
metrics: dict[str, typing.Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
- sequence_first = kwargs[AttentionKwargs.sequence_first]
- query, key_value = self._query_key_value(input_, sequence_first)
+ query, key_value = self._query_key_value(input_)

+ # Separate the batch and sequence dimensions
+ token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim])
+ token_shape = tuple(dim.size for dim in token_dims)
+ query = query.unflatten(0, token_shape)
+ key_value = key_value.unflatten(0, token_shape)

# TODO: Move the rest to function.

if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None:
- assert sequence_first
# Clear the lists so tensors can be de-allocated
- key_value = torch.cat((past_key_values.pop(0), key_value), dim=0)
+ key_value = torch.cat((past_key_values.pop(0), key_value), dim=1)

if (presents := kwargs.get(AttentionKwargs.presents)) is not None:
# Return the presents as a leaf tensors so the gradients from later micro-sequences
@@ -348,26 +339,15 @@
# Manually add the gradients from later micro-sequences.
key_value = AttachGrad.apply(key_value, present)

- if self._sequence_data_parallel_dim.group:
- key_value = (
- key_value[: kwargs[AttentionKwargs.sequence_k_dim].size]
- if sequence_first
- else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size]
- )
-
- if sequence_first:
- # TODO: Optimize (is contiguous avoidable?)
- query = query.transpose(0, 1).contiguous()
- key_value = key_value.transpose(0, 1).contiguous()
-
+ key_value = key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size]
key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1)

query = query.view(*query.shape[:2], self._local_heads, self._config.head_size)
key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size)
value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size)

self._debug(query, "query_rotary_input", self._query_dims, kwargs)
self._debug(key, "key_rotary_input", self._kv_dims, kwargs)
self._debug(query, "query_rotary_input", token_dims + self._query_dims, kwargs)
self._debug(key, "key_rotary_input", token_dims + self._kv_dims, kwargs)
query, key = self._rotary(query, key, kwargs)

with set_generator(self._distributed.tp_generator):
@@ -379,22 +359,17 @@
else:
raise NotImplementedError(self._implementation)

self._debug(query, "query", self._query_dims, kwargs)
self._debug(key, "key", self._kv_dims, kwargs)
self._debug(value, "value", self._kv_dims, kwargs)
self._debug(input_, "context", self._context_dims, kwargs)
self._debug(query, "query", token_dims + self._query_dims, kwargs)
self._debug(key, "key", token_dims + self._kv_dims, kwargs)
self._debug(value, "value", token_dims + self._kv_dims, kwargs)
self._debug(input_, "context", token_dims + (self._dense_dim,), kwargs)

- if sequence_first:
- # TODO: Optimize (is contiguous avoidable? Transpose dense output?)
- input_ = input_.transpose(0, 1).contiguous()
- out, bias = self.dense(input_)
- self._debug(out, None, kwargs.get(AttentionKwargs.hidden_dims), kwargs)
+ out, bias = self.dense(input_.flatten(0, 1))
+ self._debug(out, None, token_dims + (self._hidden_dim,), kwargs)
return out, bias

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
- batch_dim: TensorDim = kwargs[AttentionKwargs.hidden_dims][1 if kwargs[AttentionKwargs.sequence_first] else 0]

+ # Using this one since `hidden_dims` may be sequence-tensor-parallel, and attention is not.
+ batch_dim: TensorDim = kwargs[AttentionKwargs.batch_dim]
sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim]
sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim]

@@ -435,7 +410,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
partly_out_of_window = max(sequence_k - fully_out_of_window - self._config.window_size, 0)
attention_compute -= (partly_out_of_window * (partly_out_of_window + 1) * attn_compute_base) // 2

- dense_input = TensorMeta.from_dims((batch_dim, sequence_q_dim, self._context_dims[-1]))
+ dense_input = TensorMeta.from_dims((*input_.dims[:-1], self._dense_dim))

# TODO: Add marginal compute? (ex. softmax)
return sum(
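An illustrative, standalone sketch (plain PyTorch, arbitrary sizes) of the flatten/unflatten pattern this file now relies on: tokens arrive as a flat (batch * sequence, hidden) tensor, are unflattened into separate batch and sequence dimensions for attention, then flattened back before the dense projection.

import torch

batch, sequence_q, hidden = 2, 4, 8
token_shape = (batch, sequence_q)

hidden_states = torch.randn(batch * sequence_q, hidden)
query = hidden_states.unflatten(0, token_shape)  # (batch, sequence_q, hidden)
context = query  # stand-in for the attention output
out = context.flatten(0, 1)  # back to (batch * sequence_q, hidden) for the dense layer
assert out.shape == hidden_states.shape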
23 changes: 5 additions & 18 deletions fast_llm/layers/block/block.py
@@ -47,7 +47,7 @@ def __call__(
if bias is not None:
assert tensor is not None
tensor = tensor + bias
- meta = self._get_meta(tensor, name, dims, kwargs)
+ meta = self._get_meta(tensor, name, dims)

if output_hidden_state:
kwargs[BlockKwargs.hidden_states][name] = (meta, tensor)
@@ -68,34 +68,21 @@ def __call__(
"",
tensor,
level=level,
- meta=self._get_meta(tensor, name + f"{name}.grad", dims, kwargs),
+ meta=self._get_meta(tensor, name + f"{name}.grad", dims),
**logging_kwargs,
)

def _get_meta(
self,
tensor: torch.Tensor | None,
name: str,
- dims: tuple[TensorDim | str, ...] | None,
- kwargs: dict[str, typing.Any],
- ) -> TensorMeta | None:
- if tensor is None:
- return None
+ dims: tuple[TensorDim | str | None, ...] | None,
+ ) -> TensorMeta:
if dims is None:
dims = tuple(f"dim_{i}" for i in range(tensor.ndim))
- hidden_dims = {}
- if BlockKwargs.hidden_dims in kwargs:
- for dim in kwargs[BlockKwargs.hidden_dims]:
- hidden_dims[dim.name] = dim
- if BlockKwargs.sequence_q_dim in kwargs:
- hidden_dims[kwargs[BlockKwargs.sequence_q_dim].name] = kwargs[BlockKwargs.sequence_q_dim]
return TensorMeta.from_dims(
tuple(
- (
- dim
- if isinstance(dim, TensorDim)
- else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i))
- )
+ (dim if isinstance(dim, TensorDim) else TensorDim(f"dim_{i}" if dim is None else dim, tensor.size(i)))
for i, dim in enumerate(dims)
),
tensor_name=name,
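A small sketch of the new dim-naming fallback in `_get_meta`, using plain tuples instead of fast_llm's `TensorDim`/`TensorMeta`: entries that are `None` (or a missing `dims` argument) fall back to a positional `dim_{i}` label paired with the tensor's size.

import torch


def name_dims(tensor: torch.Tensor, dims: tuple[str | None, ...] | None = None) -> tuple[tuple[str, int], ...]:
    if dims is None:
        dims = tuple(None for _ in range(tensor.ndim))
    return tuple(
        (f"dim_{i}" if dim is None else dim, tensor.size(i))
        for i, dim in enumerate(dims)
    )


print(name_dims(torch.zeros(2, 3), ("batch", None)))  # (('batch', 2), ('dim_1', 3))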
5 changes: 3 additions & 2 deletions fast_llm/layers/block/config.py
@@ -31,10 +31,11 @@ class BlockDimNames:


class BlockKwargs:
sequence_first = "sequence_first"
hidden_dims = "hidden_dims"
batch_dim = "batch_dim"
sequence_q_dim = "sequence_q_dim"
sequence_k_dim = "sequence_k_dim"
token_dim = "token_dim"
hidden_token_dim = "hidden_token_dim"
# TODO: These are confusing
sequence_length = "sequence_length"
sequence_lengths = "sequence_lengths"