diff --git a/init2winit/model_lib/dlrm.py b/init2winit/model_lib/dlrm.py
index fc59f33d..1e72db2c 100644
--- a/init2winit/model_lib/dlrm.py
+++ b/init2winit/model_lib/dlrm.py
@@ -269,8 +269,24 @@ class DLRMResNet(nn.Module):
   activation_function: str = 'relu'
   embedding_init_multiplier: Optional[float] = None
 
+  def get_embeddings(self, embedding_table_block, indices_block):
+    """Get embeddings for a block of indices."""
+    embedding_table_block = jax.lax.all_gather(
+        embedding_table_block, 'devices', axis=1, tiled=True)
+    embeddings = jnp.take(embedding_table_block, indices_block, axis=0)
+    return embeddings
+
   @nn.compact
   def __call__(self, x, train):
+    shmapped_get_embeddings = jax.experimental.shard_map.shard_map(
+        self.get_embeddings,
+        mesh=model_utils.get_default_mesh(),
+        in_specs=(
+            P(None, 'devices'),
+            P('devices',),
+        ),
+        out_specs=P('devices',),
+    )
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
     cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
 
@@ -295,26 +311,31 @@ def __call__(self, x, train):
       )(bot_mlp_input)
       bot_mlp_input += activation_fn(x)
 
+    bot_mlp_output = bot_mlp_input
+    batch_size = bot_mlp_output.shape[0]
+    feature_stack = jnp.reshape(bot_mlp_output,
+                                [batch_size, -1, self.embed_dim])
     base_init_fn = jnn.initializers.uniform(scale=1.0)
     if self.embedding_init_multiplier is None:
       embedding_init_multiplier = 1 / self.vocab_size**0.5
     else:
       embedding_init_multiplier = self.embedding_init_multiplier
     # Embedding table init and lookup for a single unified table.
-    idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
+    idx_lookup = cat_features % self.vocab_size
+
     def scaled_init(key, shape, dtype=jnp.float_):
       return base_init_fn(key, shape, dtype) * embedding_init_multiplier
 
     embedding_table = self.param(
-        'embedding_table',
-        scaled_init,
-        [self.vocab_size, self.embed_dim])
+        'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
+    )
 
-    embed_features = embedding_table[idx_lookup]
-    batch_size = bot_mlp_input.shape[0]
-    embed_features = jnp.reshape(
-        embed_features, (batch_size, 26 * self.embed_dim))
-    top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
+    embed_features = shmapped_get_embeddings(embedding_table, idx_lookup)
+    feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
+    dot_interact_output = dot_interact(
+        concat_features=feature_stack, keep_diags=self.keep_diags)
+    top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
+                                    axis=-1)
     mlp_input_dim = top_mlp_input.shape[1]
     mlp_top_dims = self.mlp_top_dims
     num_layers_top = len(mlp_top_dims)
@@ -354,6 +375,15 @@ def scaled_init(key, shape, dtype=jnp.float_):
 class DLRMResNetModel(base_model.BaseModel):
   """DLRMResNetModel init2winit class."""
 
+  def get_sharding_overrides(self, mesh: Any) -> Any:
+    type_to_sharding = super().get_sharding_overrides(mesh)
+    overrides = {
+        ParameterType.EMBEDDING: NamedSharding(mesh, P(None, 'devices')),
+    }
+
+    type_to_sharding.update(overrides)
+    return type_to_sharding
+
   def build_flax_module(self):
     """DLRMResNet for ad click probability prediction."""
     return DLRMResNet(
@@ -371,7 +401,16 @@ def get_fake_inputs(self, hps):
     """Helper method solely for purpose of initalizing the model."""
     # NOTE(dsuo): hps.input_shape for `criteo_terabyte_input_pipeline` is
     # (39,)
+    fake_batch_size = hps.batch_size
+
+    # If the fake batch size is not divisible by the number of devices, use
+    # the smallest batch size that is divisible by the number of devices.
+    # This is necessary for the shard_mapped get_embeddings() to work, and it
+    # keeps the fake batch as close to the real batch size as possible.
+    if fake_batch_size % jax.device_count() != 0:
+      fake_batch_size = jax.device_count()
+
     dummy_inputs = [
-        jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
+        jnp.zeros((fake_batch_size, *hps.input_shape), dtype=hps.model_dtype)
     ]
     return dummy_inputs
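Note (not part of the patch): below is a minimal, self-contained sketch of the shard_map embedding lookup introduced above. It assumes a 1-D 'devices' mesh built directly with mesh_utils as a stand-in for model_utils.get_default_mesh(), and toy shapes chosen for illustration; everything else mirrors the patch. The table is sharded along the embedding dimension (P(None, 'devices')), the indices along the batch dimension (P('devices',)); each device all-gathers the full table and looks up only its shard of the indices.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical stand-in for model_utils.get_default_mesh(): a 1-D mesh over
# all local devices, named 'devices' to match the patch.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('devices',))

def get_embeddings(embedding_table_block, indices_block):
  # Each device holds a column slice of the table; gather the full table
  # locally (tiled along the embedding axis), then look up this device's
  # shard of the indices.
  embedding_table_block = jax.lax.all_gather(
      embedding_table_block, 'devices', axis=1, tiled=True)
  return jnp.take(embedding_table_block, indices_block, axis=0)

shmapped_get_embeddings = shard_map(
    get_embeddings,
    mesh=mesh,
    in_specs=(P(None, 'devices'), P('devices',)),
    out_specs=P('devices',),
)

# Toy shapes (assumptions, not from the patch) chosen so both the embedding
# dimension and the batch dimension divide evenly across devices, as the
# specs above require.
n = jax.device_count()
vocab_size, embed_dim, batch = 1000, 4 * n, 8 * n
table = jnp.ones((vocab_size, embed_dim))
idx = jnp.zeros((batch, 26), dtype=jnp.int32)  # 26 categorical features
out = shmapped_get_embeddings(table, idx)      # shape: (batch, 26, embed_dim)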