45 changes: 45 additions & 0 deletions keras/src/backend/jax/distribution_lib.py
@@ -1,6 +1,7 @@
"""Utilities for distribution strategy with JAX backend."""

import jax
import jax.lax as lax
import numpy as np

from keras.src.backend.common import global_state
@@ -212,6 +213,50 @@ def process_id():
return jax.process_index()


def all_reduce(x, op="sum", axis_name="model"):
"""Reduces a tensor across a device mesh axis using a collective.

Args:
x: The tensor to reduce.
op: The reduction operation. "sum" or "mean".
axis_name: The name of the mesh axis to reduce over.

Returns:
The reduced tensor.
"""
if op == "sum":
return lax.psum(x, axis_name=axis_name)
elif op == "mean":
return lax.pmean(x, axis_name=axis_name)
else:
raise ValueError(
f"Unsupported reduction operation: {op}. "
"Supported options are 'sum' and 'mean'."
)
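
A minimal usage sketch (illustrative only, not part of this diff): all_reduce is meant to run inside a parallel context that binds the mesh axis name, for example under jax.pmap on a host with several local devices:

    import jax
    import numpy as np
    from keras.src.backend.jax import distribution_lib as backend_dlib

    def summed(x):
        # Sums the per-device shards over the "model" axis.
        return backend_dlib.all_reduce(x, op="sum", axis_name="model")

    shards = np.ones((jax.device_count(), 4), dtype="float32")
    out = jax.pmap(summed, axis_name="model")(shards)
    # Every row of out equals jax.device_count().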


def all_gather(x, axis, axis_name="model"):
"""Gathers and concatenates tensors from all devices across a mesh axis.

This function assumes it is called within a `pjit` context. It takes
the local shard `x` from each device along the `axis_name` of the mesh
and concatenates them along the specified tensor `axis` to form a
single, larger tensor that is then replicated on all participating devices.

Args:
x (jax.Array): The input JAX array (tensor) shard on the local device.
axis (int): The tensor axis along which to concatenate the gathered
shards.
axis_name (str, optional): The name of the mesh axis to gather
from. Defaults to 'model'.

Returns:
jax.Array: The full, gathered JAX array, which is identical across
all devices participating in the gather.
"""
return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)
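
As a hedged sketch of the semantics described above (one process, each device holding a distinct shard), the gathered result is replicated on every participating device:

    import jax
    import numpy as np
    from keras.src.backend.jax import distribution_lib as backend_dlib

    def gathered(x):
        # Concatenates the per-device shards along tensor axis 0.
        return backend_dlib.all_gather(x, axis=0, axis_name="model")

    n = jax.device_count()
    shards = np.arange(n, dtype="float32").reshape(n, 1)  # one value per device
    out = jax.pmap(gathered, axis_name="model")(shards)
    # out[i] == [0., 1., ..., n - 1.] on every device i.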


def _to_backend_device(device_name):
if isinstance(device_name, jax.Device):
return device_name
44 changes: 44 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -441,6 +441,50 @@ def test_distribute_data_input(self):
for shard in result.addressable_shards:
self.assertEqual(shard.data.shape, (3, 4))

def test_all_reduce(self):
devices = jax.devices()
num_devices = len(devices)
input_data = np.ones((num_devices, 2), dtype="float32")
Collaborator: input_data needs to be sharded across the devices to make this test valid.


def sum_fn(x):
return backend_dlib.all_reduce(x, op="sum", axis_name="batch")

result_sum = jax.pmap(sum_fn, axis_name="batch")(input_data)
Collaborator: Shouldn't the pmap be part of the all_reduce implementation? Same for "mean".


expected_sum = np.full((num_devices, 2), num_devices, dtype="float32")
self.assertAllClose(result_sum, expected_sum)

def mean_fn(x):
return backend_dlib.all_reduce(x, op="mean", axis_name="batch")

result_mean = jax.pmap(mean_fn, axis_name="batch")(input_data)

self.assertAllClose(result_mean, input_data)

with self.assertRaisesRegex(
ValueError, "Unsupported reduction operation"
):
backend_dlib.all_reduce(input_data[0], op="max", axis_name="batch")
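
One possible way to address the reviewer's note above (a hypothetical sketch, not what the test currently does) is to place each row of input_data on a different device explicitly before invoking the collective:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
    sharding = NamedSharding(mesh, PartitionSpec("batch"))
    # One row of input_data per device along the "batch" mesh axis.
    sharded_input = jax.device_put(input_data, sharding)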

def test_all_gather(self):
devices = jax.devices()
num_devices = len(devices)

input_data = np.arange(num_devices, dtype="float32").reshape(
num_devices, 1, 1
Collaborator: input_data needs to be sharded across the devices to make this test valid.

)

def gather_fn(x):
return backend_dlib.all_gather(x, axis=0, axis_name="batch")

results = jax.pmap(gather_fn, axis_name="batch")(input_data)

expected_gathered = np.arange(num_devices, dtype="float32").reshape(
num_devices, 1
)
for i in range(num_devices):
self.assertAllClose(results[i], expected_gathered)


class ShardingCaptureLayer(layers.Layer):
def __init__(self, **kwargs):
200 changes: 200 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig.py
@@ -0,0 +1,200 @@
import functools

from keras.src import layers
from keras.src.backend import distribution_lib
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import (
split_tensor_for_parallelism,
)


def analyze_dense_layer(layer):
"""Classifies a Dense layer based on its input/output dimensions.

This function uses a heuristic to determine if a Dense layer acts as an
'up_projection' (expansion), a 'down_projection' (contraction), or a
standard 'dense' layer. This classification is used to determine the
appropriate sharding strategy (e.g., column-parallel vs row-parallel).

Args:
layer: The Keras Dense layer instance to analyze.

Returns:
str: One of 'up_projection', 'down_projection', or 'dense'.
"""
input_dim = None
output_dim = None

kernel = getattr(layer, "kernel", getattr(layer, "_kernel", None))
if kernel is not None:
if len(kernel.shape) == 2:
input_dim = kernel.shape[0]
output_dim = kernel.shape[1]

if output_dim is None and hasattr(layer, "units"):
output_dim = layer.units

if (
input_dim is None
and hasattr(layer, "input_shape")
and layer.input_shape
and len(layer.input_shape) > 1
):
input_dim = layer.input_shape[-1]

if input_dim is None or output_dim is None:
return "dense"

expansion_threshold = 1.5
is_expansion = output_dim > input_dim * expansion_threshold
is_contraction = input_dim > output_dim * expansion_threshold

if is_expansion:
return "up_projection"
elif is_contraction:
return "down_projection"
else:
return "dense"


def _reduce_sum(x):
"""Performs an all-reduce sum operation across the 'model' mesh axis.

Args:
x: The input tensor to reduce.

Returns:
The reduced tensor, summed across all devices in the model axis.
"""
return distribution_lib.all_reduce(x, op="sum", axis_name="model")


def _gather(x, axis):
"""Performs an all-gather operation across the 'model' mesh axis.

Args:
x: The input tensor shard to gather.
axis: The axis along which to concatenate the gathered parts.

Returns:
The gathered tensor, concatenated along the specified axis.
"""
return distribution_lib.all_gather(x, axis=axis, axis_name="model")


def _get_layer_path(layer):
"""Retrieves the unique hierarchical path of a layer.

This utilizes `layer.path` (available in Keras 3+) which provides a
globally unique identifier based on the model structure (e.g.,
'model/dense_1'). Falls back to `layer.name` if the path is unavailable.

Args:
layer: The Keras layer instance.

Returns:
str: The unique path string for the layer.
"""
return getattr(layer, "path", layer.name)


def _apply_layer_sharding_rules(layer, device_count, state_rules, output_rules):
"""Applies sharding rules to a single layer based on its type.

This function populates `state_rules` and `output_rules` with strategies
specific to the layer class (e.g., Dense, EinsumDense, Embedding). It
determines how weights should be partitioned (state rules) and how outputs
should be synchronized (output rules).

Args:
layer: The Keras layer instance to configure.
device_count: The number of devices available for tensor parallelism.
state_rules: A dictionary mapping variable paths to sharding functions.
Updated in-place.
output_rules: A dictionary mapping layer paths to output communication
functions. Updated in-place.
"""

def split_rule(dim):
return functools.partial(
split_tensor_for_parallelism, device_count=device_count, dim=dim
)

def gather_rule(axis):
return functools.partial(_gather, axis=axis)

layer_path = _get_layer_path(layer)

if isinstance(layer, layers.Dense):
mlp_type = analyze_dense_layer(layer)

if mlp_type == "up_projection":
state_rules[layer.kernel.path] = split_rule(dim=1)
if layer.use_bias:
state_rules[layer.bias.path] = split_rule(dim=0)
output_rules[layer_path] = {0: gather_rule(axis=-1)}

elif mlp_type == "down_projection":
state_rules[layer.kernel.path] = split_rule(dim=0)
output_rules[layer_path] = {0: _reduce_sum}

else:
state_rules[layer.kernel.path] = split_rule(dim=1)
if layer.use_bias:
state_rules[layer.bias.path] = split_rule(dim=0)
output_rules[layer_path] = {0: gather_rule(axis=-1)}

elif isinstance(layer, layers.EinsumDense):
if "attention_output" in layer.name: # Use name check as heuristic
state_rules[layer.kernel.path] = split_rule(dim=0)
output_rules[layer_path] = {0: _reduce_sum}
else:
state_rules[layer.kernel.path] = split_rule(dim=1)
if hasattr(layer, "bias") and layer.bias is not None:
state_rules[layer.bias.path] = split_rule(dim=0)
output_rules[layer_path] = {0: gather_rule(axis=-1)}

elif (
isinstance(layer, (layers.Embedding,))
or "Embedding" in layer.__class__.__name__
):
if hasattr(layer, "weights"):
Collaborator: weights is a property of Layer, it cannot be missing, remove this if.

found_embedding = False
for weight in layer.weights:
if "embedding" in weight.name or "weight" in weight.name:
state_rules[weight.path] = split_rule(dim=1)
found_embedding = True

if found_embedding:
output_rules[layer_path] = {0: lambda x: x}
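
As a hedged illustration of what this private helper populates for a single up-projection Dense layer (the layer name and device count here are hypothetical, and the exact rule keys depend on how the layer is built):

    from keras import layers
    from keras.src.distribution.tensor_parallel import autoconfig

    layer = layers.Dense(2048, name="ffn_up")
    layer.build((None, 512))

    state_rules, output_rules = {}, {}
    autoconfig._apply_layer_sharding_rules(layer, 4, state_rules, output_rules)
    # state_rules[layer.kernel.path] splits the kernel along dim 1 (column-parallel),
    # state_rules[layer.bias.path] splits the bias along dim 0, and the layer's
    # entry in output_rules maps output 0 to an all-gather along the last axis.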


def get_default_config(model, device_ids):
"""Generates a default tensor parallelism configuration for a model.

This function traverses the model's layer hierarchy and
automatically generates a `LayoutMap`. This map contains:
1. `state_rules`: How to shard the weights of supported layers
across the specified devices.
2. `output_rules`: How to synchronize or gather the outputs of
these layers during the forward pass.

Args:
model: The Keras model to configure.
device_ids: A list of device identifiers to be used
for distribution.

Returns:
LayoutMap: A configuration object containing `state_rules` and
`output_rules` for tensor parallelism.
"""
device_count = len(device_ids)
state_rules = {}
output_rules = {}

for layer in model._flatten_layers(recursive=True, include_self=True):
_apply_layer_sharding_rules(
layer, device_count, state_rules, output_rules
)

return LayoutMap(state_rules=state_rules, output_rules=output_rules)
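
A usage sketch (illustrative; assumes the JAX backend with several devices and a small functional model):

    import keras
    from keras import layers
    from keras.src.distribution.tensor_parallel.autoconfig import (
        get_default_config,
    )

    inputs = keras.Input(shape=(512,))
    x = layers.Dense(2048, activation="relu", name="ffn_up")(inputs)
    outputs = layers.Dense(512, name="ffn_down")(x)
    model = keras.Model(inputs, outputs)

    device_ids = keras.distribution.list_devices()
    layout_map = get_default_config(model, device_ids)
    # layout_map.state_rules: variable path -> split function
    # layout_map.output_rules: layer path -> output collective (gather or reduce-sum)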