diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index 1407c008910e..26d04a15eac3 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -1,6 +1,7 @@ """Utilities for distribution strategy with JAX backend.""" import jax +import jax.lax as lax import numpy as np from keras.src.backend.common import global_state @@ -212,6 +213,50 @@ def process_id(): return jax.process_index() +def all_reduce(x, op="sum", axis_name="model"): + """Reduces a tensor across a device mesh axis using a collective. + + Args: + x: The tensor to reduce. + op: The reduction operation. "sum" or "mean". + axis_name: The name of the mesh axis to reduce over. + + Returns: + The reduced tensor. + """ + if op == "sum": + return lax.psum(x, axis_name=axis_name) + elif op == "mean": + return lax.pmean(x, axis_name=axis_name) + else: + raise ValueError( + f"Unsupported reduction operation: {op}. " + "Supported options are 'sum' and 'mean'." + ) + + +def all_gather(x, axis, axis_name="model"): + """Gathers and concatenates tensors from all devices across a mesh axis. + + This function assumes it is called inside a parallel context that + defines the named mesh axis, such as `jax.pmap` or `pjit`. It takes + the local shard `x` from each device along the `axis_name` of the mesh + and concatenates the shards along the specified tensor `axis` to form a + single, larger tensor that is then replicated on all participating + devices. + + Args: + x: The input JAX array (tensor) shard on the local device. + axis: The tensor axis along which to concatenate the gathered + shards. + axis_name: The name of the mesh axis to gather from. Defaults + to 'model'. + + Returns: + The full, gathered JAX array, which is identical across all + devices participating in the gather.
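+
+    Example:
+
+    A minimal sketch mirroring the accompanying unit test; it assumes the
+    call runs under `jax.pmap` with a named axis, and `local_shards` is a
+    placeholder for a host array whose leading dimension equals the
+    device count.
+
+    ```python
+    def gather_fn(x):
+        return all_gather(x, axis=0, axis_name="batch")
+
+    full = jax.pmap(gather_fn, axis_name="batch")(local_shards)
+    # `full` holds the concatenation of every device's shard, replicated
+    # on each participating device.
+    ```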
+ """ + return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True) + + def _to_backend_device(device_name): if isinstance(device_name, jax.Device): return device_name diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py index 3ee3a2bc91b7..25fd3e65da7f 100644 --- a/keras/src/backend/jax/distribution_lib_test.py +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -441,6 +441,50 @@ def test_distribute_data_input(self): for shard in result.addressable_shards: self.assertEqual(shard.data.shape, (3, 4)) + def test_all_reduce(self): + devices = jax.devices() + num_devices = len(devices) + input_data = np.ones((num_devices, 2), dtype="float32") + + def sum_fn(x): + return backend_dlib.all_reduce(x, op="sum", axis_name="batch") + + result_sum = jax.pmap(sum_fn, axis_name="batch")(input_data) + + expected_sum = np.full((num_devices, 2), num_devices, dtype="float32") + self.assertAllClose(result_sum, expected_sum) + + def mean_fn(x): + return backend_dlib.all_reduce(x, op="mean", axis_name="batch") + + result_mean = jax.pmap(mean_fn, axis_name="batch")(input_data) + + self.assertAllClose(result_mean, input_data) + + with self.assertRaisesRegex( + ValueError, "Unsupported reduction operation" + ): + backend_dlib.all_reduce(input_data[0], op="max", axis_name="batch") + + def test_all_gather(self): + devices = jax.devices() + num_devices = len(devices) + + input_data = np.arange(num_devices, dtype="float32").reshape( + num_devices, 1, 1 + ) + + def gather_fn(x): + return backend_dlib.all_gather(x, axis=0, axis_name="batch") + + results = jax.pmap(gather_fn, axis_name="batch")(input_data) + + expected_gathered = np.arange(num_devices, dtype="float32").reshape( + num_devices, 1 + ) + for i in range(num_devices): + self.assertAllClose(results[i], expected_gathered) + class ShardingCaptureLayer(layers.Layer): def __init__(self, **kwargs): diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..7691d01a7b64 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,200 @@ +import functools + +from keras.src import layers +from keras.src.backend import distribution_lib +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import ( + split_tensor_for_parallelism, +) + + +def analyze_dense_layer(layer): + """Classifies a Dense layer based on its input/output dimensions. + + This function uses a heuristic to determine if a Dense layer acts as an + 'up_projection' (expansion), a 'down_projection' (contraction), or a + standard 'dense' layer. This classification is used to determine the + appropriate sharding strategy (e.g., column-parallel vs row-parallel). + + Args: + layer: The Keras Dense layer instance to analyze. + + Returns: + str: One of 'up_projection', 'down_projection', or 'dense'. 
+ """ + input_dim = None + output_dim = None + + kernel = getattr(layer, "kernel", getattr(layer, "_kernel", None)) + if kernel is not None: + if len(kernel.shape) == 2: + input_dim = kernel.shape[0] + output_dim = kernel.shape[1] + + if output_dim is None and hasattr(layer, "units"): + output_dim = layer.units + + if ( + input_dim is None + and hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + + if input_dim is None or output_dim is None: + return "dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "dense" + + +def _reduce_sum(x): + """Performs an all-reduce sum operation across the 'model' mesh axis. + + Args: + x: The input tensor to reduce. + + Returns: + The reduced tensor, summed across all devices in the model axis. + """ + return distribution_lib.all_reduce(x, op="sum", axis_name="model") + + +def _gather(x, axis): + """Performs an all-gather operation across the 'model' mesh axis. + + Args: + x: The input tensor shard to gather. + axis: The axis along which to concatenate the gathered parts. + + Returns: + The gathered tensor, concatenated along the specified axis. + """ + return distribution_lib.all_gather(x, axis=axis, axis_name="model") + + +def _get_layer_path(layer): + """Retrieves the unique hierarchical path of a layer. + + This utilizes `layer.path` (available in Keras 3+) which provides a + globally unique identifier based on the model structure (e.g., + 'model/dense_1'). Falls back to `layer.name` if the path is unavailable. + + Args: + layer: The Keras layer instance. + + Returns: + str: The unique path string for the layer. + """ + return getattr(layer, "path", layer.name) + + +def _apply_layer_sharding_rules(layer, device_count, state_rules, output_rules): + """Applies sharding rules to a single layer based on its type. + + This function populates `state_rules` and `output_rules` with strategies + specific to the layer class (e.g., Dense, EinsumDense, Embedding). It + determines how weights should be partitioned (state rules) and how outputs + should be synchronized (output rules). + + Args: + layer: The Keras layer instance to configure. + device_count: The number of devices available for tensor parallelism. + state_rules: A dictionary mapping variable paths to sharding functions. + Updated in-place. + output_rules: A dictionary mapping layer paths to output communication + functions. Updated in-place. 
+ """ + + def split_rule(dim): + return functools.partial( + split_tensor_for_parallelism, device_count=device_count, dim=dim + ) + + def gather_rule(axis): + return functools.partial(_gather, axis=axis) + + layer_path = _get_layer_path(layer) + + if isinstance(layer, layers.Dense): + mlp_type = analyze_dense_layer(layer) + + if mlp_type == "up_projection": + state_rules[layer.kernel.path] = split_rule(dim=1) + if layer.use_bias: + state_rules[layer.bias.path] = split_rule(dim=0) + output_rules[layer_path] = {0: gather_rule(axis=-1)} + + elif mlp_type == "down_projection": + state_rules[layer.kernel.path] = split_rule(dim=0) + output_rules[layer_path] = {0: _reduce_sum} + + else: + state_rules[layer.kernel.path] = split_rule(dim=1) + if layer.use_bias: + state_rules[layer.bias.path] = split_rule(dim=0) + output_rules[layer_path] = {0: gather_rule(axis=-1)} + + elif isinstance(layer, layers.EinsumDense): + if "attention_output" in layer.name: # Use name check as heuristic + state_rules[layer.kernel.path] = split_rule(dim=0) + output_rules[layer_path] = {0: _reduce_sum} + else: + state_rules[layer.kernel.path] = split_rule(dim=1) + if hasattr(layer, "bias") and layer.bias is not None: + state_rules[layer.bias.path] = split_rule(dim=0) + output_rules[layer_path] = {0: gather_rule(axis=-1)} + + elif ( + isinstance(layer, (layers.Embedding,)) + or "Embedding" in layer.__class__.__name__ + ): + if hasattr(layer, "weights"): + found_embedding = False + for weight in layer.weights: + if "embedding" in weight.name or "weight" in weight.name: + state_rules[weight.path] = split_rule(dim=1) + found_embedding = True + + if found_embedding: + output_rules[layer_path] = {0: lambda x: x} + + +def get_default_config(model, device_ids): + """Generates a default tensor parallelism configuration for a model. + + This function traverses the model's layer hierarchy and + automatically generates a `LayoutMap`. This map contains: + 1. `state_rules`: How to shard the weights of supported layers + across the specified devices. + 2. `output_rules`: How to synchronize or gather the outputs of + these layers during the forward pass. + + Args: + model: The Keras model to configure. + device_ids: A list of device identifiers to be used + for distribution. + + Returns: + LayoutMap: A configuration object containing `state_rules` and + `output_rules` for tensor parallelism. 
+ """ + device_count = len(device_ids) + state_rules = {} + output_rules = {} + + for layer in model._flatten_layers(recursive=True, include_self=True): + _apply_layer_sharding_rules( + layer, device_count, state_rules, output_rules + ) + + return LayoutMap(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..360d65ee16c4 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,157 @@ +import functools + +import keras +from keras.src import layers +from keras.src import testing +from keras.src.distribution.tensor_parallel.autoconfig import _gather +from keras.src.distribution.tensor_parallel.autoconfig import _reduce_sum +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer, +) +from keras.src.distribution.tensor_parallel.autoconfig import get_default_config +from keras.src.distribution.tensor_parallel.tensor_layout import ( + split_tensor_for_parallelism, +) + + +class AutoConfigTest(testing.TestCase): + def check_rule(self, rule, expected_device_count, expected_dim): + """ + Helper to verify a rule. + The rules are now functools.partial objects, so we verify their + configuration directly. + """ + self.assertIsInstance(rule, functools.partial) + self.assertEqual(rule.func, split_tensor_for_parallelism) + self.assertEqual(rule.keywords["device_count"], expected_device_count) + self.assertEqual(rule.keywords["dim"], expected_dim) + + def test_analyze_dense_layer_directly(self): + """Tests the heuristic for classifying Dense layers.""" + + up_proj_layer = layers.Dense(64, name="up") + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual(analyze_dense_layer(up_proj_layer), "up_projection") + down_proj_layer = layers.Dense(16, name="down") + down_proj_layer.build(input_shape=(None, 64)) + self.assertEqual( + analyze_dense_layer(down_proj_layer), + "down_projection", + ) + generic_layer = layers.Dense(32, name="generic") + generic_layer.build(input_shape=(None, 28)) + self.assertEqual(analyze_dense_layer(generic_layer), "dense") + non_dense_layer = layers.LayerNormalization() + self.assertEqual(analyze_dense_layer(non_dense_layer), "dense") + + def test_simple_mlp_model(self): + """Tests rule generation for a standard MLP block.""" + device_count = 2 + devices = [f"gpu:{i}" for i in range(device_count)] + + model = keras.Sequential( + [ + keras.Input(shape=(32,)), + layers.Dense(128, name="mlp_up"), + layers.Dense(32, name="mlp_down"), + ], + name="mlp_block", + ) + + layout_map = get_default_config(model, devices) + state_rules = layout_map.state_rules + output_rules = layout_map.output_rules + + up_kernel_key = "mlp_block/mlp_up/kernel" + self.assertIn(up_kernel_key, state_rules) + up_kernel_rule = state_rules[up_kernel_key] + self.check_rule(up_kernel_rule, device_count, 1) + + down_kernel_key = "mlp_block/mlp_down/kernel" + self.assertIn(down_kernel_key, state_rules) + down_kernel_rule = state_rules[down_kernel_key] + self.check_rule(down_kernel_rule, device_count, 0) + + self.assertIn("mlp_block/mlp_up", output_rules) + up_output_rule = output_rules["mlp_block/mlp_up"][0] + self.assertIsInstance(up_output_rule, functools.partial) + self.assertEqual(up_output_rule.func, _gather) + self.assertEqual(up_output_rule.keywords["axis"], -1) + + self.assertIn("mlp_block/mlp_down", output_rules) + down_output_rule = output_rules["mlp_block/mlp_down"][0] + 
self.assertEqual(down_output_rule, _reduce_sum) + + def test_model_with_embedding_and_einsumdense(self): + """Tests rule generation for Embedding and EinsumDense layers.""" + device_count = 4 + devices = [f"gpu:{i}" for i in range(device_count)] + + class SimpleTransformer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.embedding = layers.Embedding( + input_dim=1000, output_dim=64, name="embedding" + ) + self.qkv_proj = layers.EinsumDense( + "abc,cde->abde", + output_shape=(None, 3, 128), + bias_axes="de", + name="qkv_proj", + ) + self.attention_output = layers.EinsumDense( + "abde,cde->abc", + output_shape=(None, 64), + bias_axes="c", + name="attention_output", + ) + + def call(self, inputs): + x = self.embedding(inputs) + x = self.qkv_proj(x) + x = self.attention_output(x) + return x + + model = SimpleTransformer(name="transformer") + model(keras.ops.zeros((1, 10))) + + layout_map = get_default_config(model, devices) + state_rules = layout_map.state_rules + + expected_key = "transformer/embedding/embeddings" + self.assertIn(expected_key, state_rules) + emb_rule = state_rules[expected_key] + self.check_rule(emb_rule, device_count, 1) + + qkv_key = "transformer/qkv_proj/kernel" + self.assertIn(qkv_key, state_rules) + qkv_rule = state_rules[qkv_key] + self.check_rule(qkv_rule, device_count, 1) + + attn_out_key = "transformer/attention_output/kernel" + self.assertIn(attn_out_key, state_rules) + attn_out_rule = state_rules[attn_out_key] + self.check_rule(attn_out_rule, device_count, 0) + + def test_nested_model(self): + """Tests that the recursive traversal finds layers in nested models.""" + device_count = 2 + devices = [f"gpu:{i}" for i in range(device_count)] + inner_model = keras.Sequential( + [layers.Dense(64, name="inner_dense")], name="inner_block" + ) + outer_model = keras.Sequential( + [ + keras.Input(shape=(32,)), + layers.Dense(32, name="outer_dense_1"), + inner_model, + ], + name="outer_block", + ) + layout_map = get_default_config(outer_model, devices) + state_rules = layout_map.state_rules + + expected_key = "outer_block/inner_block/inner_dense/kernel" + self.assertIn(expected_key, state_rules) + inner_rule = state_rules[expected_key] + self.check_rule(inner_rule, device_count, 1) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..a83710793250 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,556 @@ +import numpy as np + +from keras.src import ops +from keras.src import optimizers +from keras.src import saving +from keras.src.backend import distribution_lib + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. + + Args: + base_optimizer: The Keras optimizer instance. + device_count: The total number of devices/processes in the distributed + setup. + shard_optimizer_states: If `True`, the optimizer's state variables + will be partitioned across `device_count` devices. Defaults to + `True`. + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. 
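+
+    Example:
+
+    A minimal sketch; in practice this class is constructed internally by
+    `TensorParallelOptimizer` rather than instantiated directly.
+
+    ```python
+    coordinator = CoordinatedOptimizer(
+        optimizers.Adam(),
+        device_count=4,
+        shard_optimizer_states=False,
+    )
+    ```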
+ """ + + def __init__( + self, + base_optimizer, + device_count, + shard_optimizer_states=True, + tensor_parallel_config=None, + ): + self.base_optimizer = base_optimizer + self.device_count = device_count + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variables = None + self._variable_to_slot_name = {} + + def _initialize_sharded_states(self): + """Partitions the optimizer's state variables across shards. + + This method inspects the variables created by the base optimizer and + maps them to model parameters. + """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variable_to_slot_name = {} + + model_vars_by_path = {v.path: v for v in self._variables} + + sorted_model_paths = sorted( + model_vars_by_path.keys(), key=len, reverse=True + ) + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + found_param = None + slot_name = None + + for model_path in sorted_model_paths: + model_var = model_vars_by_path[model_path] + + if model_path in state_var.path: + suffix = state_var.path.split(model_path)[-1] + if suffix.startswith("/"): + slot_name = suffix.strip("/") + found_param = model_var + break + + sanitized_path = model_path.replace("/", "_") + if sanitized_path in state_var.path: + suffix = state_var.path.split(sanitized_path)[-1] + clean_suffix = suffix.lstrip("/_") + if clean_suffix: + slot_name = clean_suffix + found_param = model_var + break + + if found_param is not None and slot_name is not None: + self._state_variable_to_parameter[state_var.path] = found_param + self._variable_to_slot_name[state_var.path] = slot_name + + sharding_dim = 0 + if self.tensor_parallel_config: + rule = self.tensor_parallel_config.state_rules.get( + found_param.path + ) + if rule: + if hasattr(rule, "keywords") and "dim" in rule.keywords: + sharding_dim = rule.keywords["dim"] + elif hasattr(rule, "dim"): + sharding_dim = rule.dim + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + + def _partition_state(self, state_variable, dim): + """Splits a single state variable numpy array into chunks. + + Args: + state_variable: The state variable to split. + dim: The dimension along which to split the variable. + + Returns: + list: A list of numpy arrays representing the split state. + """ + state_array = ops.convert_to_numpy(state_variable) + if ( + state_array.ndim > dim + and state_array.shape[dim] >= self.device_count + ): + return np.array_split(state_array, self.device_count, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.device_count)] + + def apply_gradients(self, gradients_and_vars, shard_models): + """Coordinates gradient synchronization and application. + + Args: + gradients_and_vars: A list containing lists of (gradient, variable) + tuples for each device. + shard_models: A list of model shards corresponding to the devices. + + Raises: + ValueError: If the number of gradient sets does not match the + device count. 
+ """ + if len(gradients_and_vars) != self.device_count: + raise ValueError( + f"Expected {self.device_count} sets of gradients, " + f"but received {len(gradients_and_vars)}." + ) + + synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + + if self.shard_optimizer_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients, shard_models + ): + """Averages gradients across all shards and applies them once. + + This is used when `shard_optimizer_states` is False. + + Args: + synchronized_gradients: The list of synchronized gradients. + shard_models: The list of model shards. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients, shard_models + ): + """Applies gradients to each shard using its local optimizer state. + + Args: + synchronized_gradients: The list of synchronized gradients. + shard_models: The list of model shards. + """ + for shard_idx in range(self.device_count): + local_states = self._get_local_optimizer_states(shard_idx) + shard_optimizer = shard_models[shard_idx].optimizer.base_optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) + + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) + + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx): + """Constructs the state dictionary for a single shard. + + Args: + shard_idx: The index of the current shard. + + Returns: + dict: A dictionary mapping state names to their local values. + """ + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + + def _update_optimizer_internal_state(self, optimizer, local_states): + """Assigns local sharded state values to the optimizer's variables. + + Args: + optimizer: The local optimizer instance for the shard. + local_states: The local state dictionary. 
+ """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + if "iterations" in local_states: + var.assign(local_states["iterations"]) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + var.assign(local_param_state) + + def _update_global_sharded_states(self, optimizer, shard_idx): + """Updates the main sharded_states dictionary after a gradient step. + + Args: + optimizer: The local optimizer instance. + shard_idx: The index of the current shard. + """ + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + + def _synchronize_gradients(self, gradients_and_vars): + """Synchronizes gradients across shards using tensor parallel rules. + + Args: + gradients_and_vars: A list of (gradient, variable) tuples. + + Returns: + list: The synchronized list of gradients and variables. + """ + if not self.tensor_parallel_config: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + + if variable.path not in self.tensor_parallel_config.state_rules: + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.device_count): + if gradients_and_vars[shard_idx][i][0] is not None: + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients): + """Performs a mean all-reduce operation on a list of gradients. + + This method uses the on-device communication primitive from the backend + (e.g., JAX's lax.pmean) when multiple devices are detected. + + Args: + gradients: A list of gradient tensors to reduce. + + Returns: + list: A list containing the reduced gradient repeated for each + device. 
+ """ + if not gradients: + return [] + + if distribution_lib.get_device_count() > 1: + local_grad = gradients[0] + synced_tensor = distribution_lib.all_reduce( + local_grad, op="mean", axis_name="model" + ) + + return [synced_tensor for _ in range(self.device_count)] + + if len(gradients) == 1: + mean_grad = ops.convert_to_tensor(gradients[0]) + else: + stacked_grads = ops.stack( + [ops.convert_to_tensor(g) for g in gradients], axis=0 + ) + mean_grad = ops.mean(stacked_grads, axis=0) + + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self): + """Returns the weights of the base optimizer.""" + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] + + def set_weights(self, weights): + """Sets the weights of the base optimizer.""" + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables): + """Enables and initializes optimizer state sharding. + + Args: + variables: A list of model variables to track. + """ + self.shard_optimizer_states = True + self._variables = variables + self._initialize_sharded_states() + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This class serves as the public Keras-compliant interface (inherits + `optimizers.Optimizer`). It delegates the complex tasks of state + management, gradient synchronization, and sharding to the internal + `CoordinatedOptimizer` instance. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier. + device_count: The total number of devices/processes in the distributed + setup. + tensor_parallel_config: An optional configuration object. Defaults to + `None`. + name: The name of the optimizer. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + base_optimizer, + device_count, + tensor_parallel_config=None, + name=None, + **kwargs, + ): + if isinstance(base_optimizer, str): + base_optimizer_instance = optimizers.get(base_optimizer) + else: + base_optimizer_instance = base_optimizer + + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) + else: + lr_value = float(ops.convert_to_numpy(learning_rate)) + + if name is None: + name = f"TensorParallel_{base_optimizer_instance.name}" + + kwargs.pop("learning_rate", None) + + super().__init__( + learning_rate=lr_value, + name=name, + **kwargs, + ) + + self.base_optimizer = base_optimizer_instance + self.device_count = device_count + self.tensor_parallel_config = tensor_parallel_config + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + device_count, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars, **kwargs): + """Applies gradients to the model variables. + Args: + grads_and_vars: List of (gradient, variable) pairs. + **kwargs: Keyword arguments. Must contain `shard_models` if + `grads_and_vars` is a list of lists (sharded gradients). + """ + is_sharded_grads = ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ) + if is_sharded_grads: + if "shard_models" not in kwargs: + raise ValueError( + "The `shard_models` keyword argument is required when " + "applying sharded gradients (a list of lists)." 
+ ) + shard_models = kwargs.get("shard_models") + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars, **kwargs) + + def update_step(self, gradient, variable, *args, **kwargs): + """Delegates the update step to the base optimizer. + + Args: + gradient: The gradient tensor. + variable: The variable to update. + *args: Additional arguments for the update. + **kwargs: Additional keyword arguments for the update. + """ + if hasattr(self.base_optimizer, "update_step"): + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + + return super().update_step(gradient, variable, *args, **kwargs) + + def build(self, variables): + """Builds the optimizer and initializes sharded states. + + Args: + variables: The list of variables to optimize. + """ + if self.built: + return + + self.base_optimizer.build(variables) + if variables: + iterations = self.base_optimizer.iterations + original_iterations_val = None + if iterations is not None: + original_iterations_val = ops.convert_to_numpy(iterations.value) + + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + if iterations is not None and original_iterations_val is not None: + iterations.assign(original_iterations_val) + + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self): + """Returns the weights of the base optimizer.""" + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights): + """Sets the weights of the base optimizer.""" + self.coordinated_optimizer.set_weights(weights) + + def get_config(self): + config = super().get_config() + base_optimizer_config = saving.serialize_keras_object( + self.base_optimizer + ) + config.update( + { + "base_optimizer": base_optimizer_config, + "device_count": self.device_count, + "tensor_parallel_config": self.tensor_parallel_config, + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object( + base_optimizer_config, custom_objects=custom_objects + ) + return cls(base_optimizer=base_optimizer, **config) + + @property + def variables(self): + """Returns the list of variables from the base optimizer.""" + return self.base_optimizer.variables + + @property + def learning_rate(self): + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate + + @learning_rate.setter + def learning_rate(self, value): + self.base_optimizer.learning_rate = value + + @property + def iterations(self): + """Returns the training iteration count from the base optimizer.""" + return self.base_optimizer.iterations diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..f174fbe4fc39 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,183 @@ +import numpy as np +import pytest + +import keras +from keras import ops +from keras.src import backend +from keras.src import optimizers +from keras.src import testing +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + 
TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is for the JAX backend only.", +) +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, device_count): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(device_count): + multiplier = float(i + 1) + gradients = [ + ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer(base_optimizer, device_count=4) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + super().apply_gradients(grads_and_vars, *args, **kwargs) + + device_count = 4 + model = self._get_simple_model() + optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, device_count) + + coord = CoordinatedOptimizer( + optimizer, + device_count, + shard_optimizer_states=False, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) + self.assertAllClose( + grad_numpy, + np.ones_like(grad_numpy) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer("adam", device_count=4) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + device_count = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer(base_opt, device_count) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, device_count) + + coord_apply_tracker = {"called": False} + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + + base_apply_tracker = {"called": False} + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + 
self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer(optimizers.Adam(), device_count=4) + model = self._get_simple_model() + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) + + def test_serialization(self): + """Tests manual reconstruction via from_config.""" + device_count = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + + optimizer = TensorParallelOptimizer(base_opt, device_count) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.device_count, device_count) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), device_count=2) + optimizer.build(model.trainable_variables) + + state_to_param = ( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + + found_key = None + for key, param in state_to_param.items(): + if param is dense_output_kernel: + found_key = key + break + + self.assertIsNotNone(found_key) + self.assertIs(state_to_param[found_key], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/tensor_layout.py b/keras/src/distribution/tensor_parallel/tensor_layout.py new file mode 100644 index 000000000000..fa5b88e304d7 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout.py @@ -0,0 +1,34 @@ +import collections + +from keras.src import ops + + +def split_tensor_for_parallelism(tensor, index, device_count, dim): + """Calculates a slice of a tensor along a specified dimension for a + given index. + + This utility is used in tensor parallelism API to distribute a + tensor across multiple devices. + + Args: + tensor: The full tensor to be sharded. + index: The index of the device/shard to return (e.g., 0, 1, 2...). + device_count: The total number of parallel devices or splits. + dim: The dimension along which to split the tensor. Supports negative + indexing. + + Returns: + A tensor slice corresponding to the given `index`. 
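+
+    Example:
+
+    An illustrative sketch; the shapes follow the accompanying unit tests
+    (an uneven split of 10 rows across 3 devices).
+
+    ```python
+    tensor = ops.reshape(ops.arange(10, dtype="float32"), (10, 1))
+    shard_0 = split_tensor_for_parallelism(
+        tensor, index=0, device_count=3, dim=0
+    )
+    # shard_0 has shape (4, 1); the remaining shards have shape (3, 1),
+    # mirroring numpy-style array_split semantics.
+    ```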
+ """ + if dim < 0: + split_dim = ops.ndim(tensor) + dim + else: + split_dim = dim + + splits = ops.array_split( + tensor, indices_or_sections=device_count, axis=split_dim + ) + return splits[index] + + +LayoutMap = collections.namedtuple("LayoutMap", ["state_rules", "output_rules"]) diff --git a/keras/src/distribution/tensor_parallel/tensor_layout_test.py b/keras/src/distribution/tensor_parallel/tensor_layout_test.py new file mode 100644 index 000000000000..72b21b4912aa --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_layout_test.py @@ -0,0 +1,163 @@ +from keras.src import ops +from keras.src import testing +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import ( + split_tensor_for_parallelism, +) + + +class LayoutTest(testing.TestCase): + """Test suite for tensor layout actions and mappings.""" + + def test_split_with_even_division(self): + """Tests splitting a tensor that divides evenly among workers.""" + device_count = 4 + dim = 0 + tensor = ops.reshape(ops.arange(16, dtype="float32"), (8, 2)) + + expected_shard_0 = ops.array([[0.0, 1.0], [2.0, 3.0]]) + expected_shard_2 = ops.array([[8.0, 9.0], [10.0, 11.0]]) + + shard_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=dim + ) + shard_2 = split_tensor_for_parallelism( + tensor, index=2, device_count=device_count, dim=dim + ) + + self.assertAllClose(shard_0, expected_shard_0) + self.assertAllClose(shard_2, expected_shard_2) + self.assertEqual(shard_0.shape, (2, 2)) + + def test_split_with_uneven_division(self): + """Tests splitting tensor where remainder is distributed correctly.""" + device_count = 3 + dim = 0 + tensor = ops.reshape(ops.arange(10, dtype="float32"), (10, 1)) + + shard_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=dim + ) + self.assertEqual(shard_0.shape, (4, 1)) + self.assertAllClose(shard_0, ops.array([[0.0], [1.0], [2.0], [3.0]])) + + shard_1 = split_tensor_for_parallelism( + tensor, index=1, device_count=device_count, dim=dim + ) + self.assertEqual(shard_1.shape, (3, 1)) + self.assertAllClose(shard_1, ops.array([[4.0], [5.0], [6.0]])) + + shard_2 = split_tensor_for_parallelism( + tensor, index=2, device_count=device_count, dim=dim + ) + self.assertEqual(shard_2.shape, (3, 1)) + self.assertAllClose(shard_2, ops.array([[7.0], [8.0], [9.0]])) + + def test_split_and_undo_cycle_even_removed(self): + """ + Confirms that the original tensor can be reconstructed. + """ + device_count = 2 + dim = 0 + original_tensor = ops.reshape(ops.arange(12, dtype="float32"), (6, 2)) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + reconstructed_tensor = ops.concatenate(shards, axis=dim) + + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_and_undo_cycle_uneven_removed(self): + """ + Confirms that original tensor can be reconstructed with uneven split. 
+ """ + device_count = 4 + dim = 0 + original_tensor = ops.reshape(ops.arange(22, dtype="float32"), (11, 2)) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + self.assertEqual(shards[0].shape, (3, 2)) + self.assertEqual(shards[1].shape, (3, 2)) + self.assertEqual(shards[2].shape, (3, 2)) + self.assertEqual(shards[3].shape, (2, 2)) + + reconstructed_tensor = ops.concatenate(shards, axis=dim) + self.assertAllClose(original_tensor, reconstructed_tensor) + + def test_split_last_dimension(self): + """Tests splitting on the last dimension.""" + device_count = 3 + dim = 2 + original_tensor = ops.reshape( + ops.arange(30, dtype="float32"), (2, 5, 3) + ) + + shards = [ + split_tensor_for_parallelism( + original_tensor, index=i, device_count=device_count, dim=dim + ) + for i in range(device_count) + ] + + self.assertEqual(shards[0].shape, (2, 5, 1)) + self.assertEqual(shards[1].shape, (2, 5, 1)) + self.assertEqual(shards[2].shape, (2, 5, 1)) + + def test_split_with_sharding_type_hint(self): + """Tests using 'row' and 'column' sharding hints for 2D tensors.""" + device_count = 2 + tensor = ops.reshape(ops.arange(16, dtype="float32"), (4, 4)) + + row_dim = 0 + shard_row_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=row_dim + ) + self.assertAllClose(shard_row_0, tensor[:2, :]) + + col_dim = 1 + shard_col_0 = split_tensor_for_parallelism( + tensor, index=0, device_count=device_count, dim=col_dim + ) + self.assertAllClose(shard_col_0, tensor[:, :2]) + + def test_layout_map_namedtuple_behavior(self): + """Tests basic behavior of the LayoutMap namedtuple.""" + + def rule_kernel(tensor, index): + return split_tensor_for_parallelism( + tensor, index=index, device_count=2, dim=0 + ) + + def rule_output(tensor, index): + return split_tensor_for_parallelism( + tensor, index=index, device_count=2, dim=-1 + ) + + state_rules = {"kernel": rule_kernel} + output_rules = {"output": rule_output} + + layout_map = LayoutMap( + state_rules=state_rules, output_rules=output_rules + ) + + self.assertIs(layout_map.state_rules, state_rules) + self.assertIs(layout_map.output_rules, output_rules) + + self.assertIs(layout_map[0], state_rules) + self.assertIs(layout_map[1], output_rules) + + with self.assertRaises(AttributeError): + layout_map.state_rules = {} + + self.assertTrue(callable(layout_map.state_rules["kernel"]))