Adding Tensor_layout for Tensor parallelism for Autosharding #21792
base: master
@@ -441,6 +441,50 @@ def test_distribute_data_input(self):

```python
        for shard in result.addressable_shards:
            self.assertEqual(shard.data.shape, (3, 4))

    def test_all_reduce(self):
        devices = jax.devices()
        num_devices = len(devices)
        input_data = np.ones((num_devices, 2), dtype="float32")

        def sum_fn(x):
            return backend_dlib.all_reduce(x, op="sum", axis_name="batch")

        result_sum = jax.pmap(sum_fn, axis_name="batch")(input_data)
```
**Collaborator:** Shouldn't the […]? Same for "mean".
```python
        expected_sum = np.full((num_devices, 2), num_devices, dtype="float32")
        self.assertAllClose(result_sum, expected_sum)

        def mean_fn(x):
            return backend_dlib.all_reduce(x, op="mean", axis_name="batch")

        result_mean = jax.pmap(mean_fn, axis_name="batch")(input_data)

        self.assertAllClose(result_mean, input_data)

        with self.assertRaisesRegex(
            ValueError, "Unsupported reduction operation"
        ):
            backend_dlib.all_reduce(input_data[0], op="max", axis_name="batch")

    def test_all_gather(self):
        devices = jax.devices()
        num_devices = len(devices)

        input_data = np.arange(num_devices, dtype="float32").reshape(
            num_devices, 1, 1
        )

        def gather_fn(x):
            return backend_dlib.all_gather(x, axis=0, axis_name="batch")

        results = jax.pmap(gather_fn, axis_name="batch")(input_data)

        expected_gathered = np.arange(num_devices, dtype="float32").reshape(
            num_devices, 1
        )
        for i in range(num_devices):
            self.assertAllClose(results[i], expected_gathered)


class ShardingCaptureLayer(layers.Layer):
    def __init__(self, **kwargs):
```
@@ -0,0 +1,200 @@

```python
import functools

from keras.src import layers
from keras.src.backend import distribution_lib
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import (
    split_tensor_for_parallelism,
)


def analyze_dense_layer(layer):
    """Classifies a Dense layer based on its input/output dimensions.

    This function uses a heuristic to determine if a Dense layer acts as an
    'up_projection' (expansion), a 'down_projection' (contraction), or a
    standard 'dense' layer. This classification is used to determine the
    appropriate sharding strategy (e.g., column-parallel vs row-parallel).

    Args:
        layer: The Keras Dense layer instance to analyze.

    Returns:
        str: One of 'up_projection', 'down_projection', or 'dense'.
    """
    input_dim = None
    output_dim = None

    kernel = getattr(layer, "kernel", getattr(layer, "_kernel", None))
    if kernel is not None:
        if len(kernel.shape) == 2:
            input_dim = kernel.shape[0]
            output_dim = kernel.shape[1]

    if output_dim is None and hasattr(layer, "units"):
        output_dim = layer.units

    if (
        input_dim is None
        and hasattr(layer, "input_shape")
        and layer.input_shape
        and len(layer.input_shape) > 1
    ):
        input_dim = layer.input_shape[-1]

    if input_dim is None or output_dim is None:
        return "dense"

    expansion_threshold = 1.5
    is_expansion = output_dim > input_dim * expansion_threshold
    is_contraction = input_dim > output_dim * expansion_threshold

    if is_expansion:
        return "up_projection"
    elif is_contraction:
        return "down_projection"
    else:
        return "dense"
```
```python
def _reduce_sum(x):
    """Performs an all-reduce sum operation across the 'model' mesh axis.

    Args:
        x: The input tensor to reduce.

    Returns:
        The reduced tensor, summed across all devices in the model axis.
    """
    return distribution_lib.all_reduce(x, op="sum", axis_name="model")


def _gather(x, axis):
    """Performs an all-gather operation across the 'model' mesh axis.

    Args:
        x: The input tensor shard to gather.
        axis: The axis along which to concatenate the gathered parts.

    Returns:
        The gathered tensor, concatenated along the specified axis.
    """
    return distribution_lib.all_gather(x, axis=axis, axis_name="model")


def _get_layer_path(layer):
    """Retrieves the unique hierarchical path of a layer.

    This utilizes `layer.path` (available in Keras 3+) which provides a
    globally unique identifier based on the model structure (e.g.,
    'model/dense_1'). Falls back to `layer.name` if the path is unavailable.

    Args:
        layer: The Keras layer instance.

    Returns:
        str: The unique path string for the layer.
    """
    return getattr(layer, "path", layer.name)
```
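On the JAX backend these thin wrappers presumably correspond to the standard collectives; a minimal sketch (assumption: the `jax.lax` primitives shown here stand in for the backend `distribution_lib` calls, run under `pmap` with a "model" axis):

```python
import jax
import jax.numpy as jnp

def reduce_sum_jax(x):
    # Sum the per-device shards across the "model" axis.
    return jax.lax.psum(x, axis_name="model")

def gather_jax(x, axis):
    # Concatenate the per-device shards along `axis` (tiled=True concatenates
    # instead of stacking a new leading axis).
    return jax.lax.all_gather(x, axis_name="model", axis=axis, tiled=True)

n = jax.local_device_count()
shards = jnp.arange(n * 2, dtype="float32").reshape(n, 2)
summed = jax.pmap(reduce_sum_jax, axis_name="model")(shards)
gathered = jax.pmap(lambda x: gather_jax(x, axis=0), axis_name="model")(shards)
```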
```python
def _apply_layer_sharding_rules(layer, device_count, state_rules, output_rules):
    """Applies sharding rules to a single layer based on its type.

    This function populates `state_rules` and `output_rules` with strategies
    specific to the layer class (e.g., Dense, EinsumDense, Embedding). It
    determines how weights should be partitioned (state rules) and how outputs
    should be synchronized (output rules).

    Args:
        layer: The Keras layer instance to configure.
        device_count: The number of devices available for tensor parallelism.
        state_rules: A dictionary mapping variable paths to sharding functions.
            Updated in-place.
        output_rules: A dictionary mapping layer paths to output communication
            functions. Updated in-place.
    """

    def split_rule(dim):
        return functools.partial(
            split_tensor_for_parallelism, device_count=device_count, dim=dim
        )

    def gather_rule(axis):
        return functools.partial(_gather, axis=axis)

    layer_path = _get_layer_path(layer)

    if isinstance(layer, layers.Dense):
        mlp_type = analyze_dense_layer(layer)

        if mlp_type == "up_projection":
            state_rules[layer.kernel.path] = split_rule(dim=1)
            if layer.use_bias:
                state_rules[layer.bias.path] = split_rule(dim=0)
            output_rules[layer_path] = {0: gather_rule(axis=-1)}

        elif mlp_type == "down_projection":
            state_rules[layer.kernel.path] = split_rule(dim=0)
            output_rules[layer_path] = {0: _reduce_sum}

        else:
            state_rules[layer.kernel.path] = split_rule(dim=1)
            if layer.use_bias:
                state_rules[layer.bias.path] = split_rule(dim=0)
            output_rules[layer_path] = {0: gather_rule(axis=-1)}

    elif isinstance(layer, layers.EinsumDense):
        if "attention_output" in layer.name:  # Use name check as heuristic
            state_rules[layer.kernel.path] = split_rule(dim=0)
            output_rules[layer_path] = {0: _reduce_sum}
        else:
            state_rules[layer.kernel.path] = split_rule(dim=1)
            if hasattr(layer, "bias") and layer.bias is not None:
                state_rules[layer.bias.path] = split_rule(dim=0)
            output_rules[layer_path] = {0: gather_rule(axis=-1)}

    elif (
        isinstance(layer, (layers.Embedding,))
        or "Embedding" in layer.__class__.__name__
    ):
        if hasattr(layer, "weights"):
            found_embedding = False
            for weight in layer.weights:
                if "embedding" in weight.name or "weight" in weight.name:
                    state_rules[weight.path] = split_rule(dim=1)
                    found_embedding = True

            if found_embedding:
                output_rules[layer_path] = {0: lambda x: x}
```
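To make the choice of rules above concrete: splitting a kernel along dim=1 (column parallel) produces disjoint output features that are recombined with an all-gather, while splitting along dim=0 (row parallel) produces partial sums that must be all-reduced. A small NumPy sketch with two simulated devices and hypothetical shapes:

```python
import numpy as np

x = np.random.randn(3, 4).astype("float32")  # activations
w = np.random.randn(4, 6).astype("float32")  # Dense kernel

# Column parallel (kernel split along dim=1, as for "up_projection"):
# each device computes a slice of the output features; concatenating them
# (all-gather along the last axis) recovers the full result.
col_parts = [x @ w[:, :3], x @ w[:, 3:]]
np.testing.assert_allclose(
    np.concatenate(col_parts, axis=-1), x @ w, rtol=1e-5, atol=1e-6
)

# Row parallel (kernel split along dim=0, as for "down_projection"):
# each device sees a slice of the input features and produces a partial
# product; summing them (all-reduce) recovers the full result.
row_parts = [x[:, :2] @ w[:2, :], x[:, 2:] @ w[2:, :]]
np.testing.assert_allclose(sum(row_parts), x @ w, rtol=1e-5, atol=1e-6)
```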
```python
def get_default_config(model, device_ids):
    """Generates a default tensor parallelism configuration for a model.

    This function traverses the model's layer hierarchy and
    automatically generates a `LayoutMap`. This map contains:
    1. `state_rules`: How to shard the weights of supported layers
       across the specified devices.
    2. `output_rules`: How to synchronize or gather the outputs of
       these layers during the forward pass.

    Args:
        model: The Keras model to configure.
        device_ids: A list of device identifiers to be used
            for distribution.

    Returns:
        LayoutMap: A configuration object containing `state_rules` and
            `output_rules` for tensor parallelism.
    """
    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}

    for layer in model._flatten_layers(recursive=True, include_self=True):
        _apply_layer_sharding_rules(
            layer, device_count, state_rules, output_rules
        )

    return LayoutMap(state_rules=state_rules, output_rules=output_rules)
```
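A minimal usage sketch (assumptions: a hypothetical two-layer MLP, placeholder device identifiers since only `len(device_ids)` is used here, and that the returned `LayoutMap` exposes `state_rules`/`output_rules` attributes):

```python
import keras
from keras import layers

model = keras.Sequential(
    [
        keras.Input(shape=(128,)),
        layers.Dense(512, name="up"),    # classified as "up_projection"
        layers.Dense(128, name="down"),  # classified as "down_projection"
    ]
)

layout_map = get_default_config(model, device_ids=["gpu:0", "gpu:1"])
# state_rules maps variable paths to split functions; output_rules maps layer
# paths to communication ops (gather for column parallel, reduce-sum for row
# parallel).
print(list(layout_map.state_rules.keys()))
print(list(layout_map.output_rules.keys()))
```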
**Comment:** `input_data` needs to be sharded across the devices to make this test valid.
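A minimal sketch of how that could look in `test_all_reduce` (assumptions: JAX backend, single process, and `jax.lax.psum` standing in for the PR's `backend_dlib.all_reduce`): the input is placed shard-by-shard on the local devices before the `pmap` call.

```python
import jax
import numpy as np

devices = jax.devices()
num_devices = len(devices)
input_data = np.ones((num_devices, 2), dtype="float32")

# Explicitly place one slice of the leading axis on each device.
sharded_input = jax.device_put_sharded(
    [input_data[i] for i in range(num_devices)], devices
)

def sum_fn(x):
    # Stand-in for backend_dlib.all_reduce(x, op="sum", axis_name="batch").
    return jax.lax.psum(x, axis_name="batch")

result_sum = jax.pmap(sum_fn, axis_name="batch")(sharded_input)
```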