59 changes: 49 additions & 10 deletions init2winit/model_lib/dlrm.py
@@ -269,8 +269,24 @@ class DLRMResNet(nn.Module):
activation_function: str = 'relu'
embedding_init_multiplier: Optional[float] = None

def get_embeddings(self, embedding_table_block, indices_block):
"""Get embeddings for a block of indices."""
embedding_table_block = jax.lax.all_gather(
embedding_table_block, 'devices', axis=1, tiled=True)
embeddings = jnp.take(embedding_table_block, indices_block, axis=0)
return embeddings

@nn.compact
def __call__(self, x, train):
shmapped_get_embeddings = jax.experimental.shard_map.shard_map(
self.get_embeddings,
mesh=model_utils.get_default_mesh(),
in_specs=(
P(None, 'devices'),
P('devices',),
),
out_specs=P('devices',),
)
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
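
The hunk above introduces a shard_mapped embedding lookup. A minimal standalone sketch of the same pattern, assuming a 1-D mesh whose axis is named 'devices' (the PR takes its mesh from model_utils.get_default_mesh(), which is not shown here), could look like this; the shapes are illustrative only:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('devices',))

def get_embeddings(table_block, idx_block):
  # Each device holds a (vocab, embed_dim // n_devices) column slice of the
  # table and a shard of the indices; all_gather(..., axis=1, tiled=True)
  # reassembles the full embedding width before the row lookup.
  table = jax.lax.all_gather(table_block, 'devices', axis=1, tiled=True)
  return jnp.take(table, idx_block, axis=0)

lookup = shard_map(
    get_embeddings,
    mesh=mesh,
    in_specs=(P(None, 'devices'), P('devices')),
    out_specs=P('devices'),
)

n = jax.device_count()
table = jnp.ones((32, 4 * n))               # embed_dim must divide over the mesh
idx = jnp.zeros((8 * n,), dtype=jnp.int32)  # so must the (fake) batch size
print(lookup(table, idx).shape)             # (8 * n, 4 * n)

The output spec P('devices') splits the result along the batch axis, which is why the batch size fed into the model must be divisible by the device count (see get_fake_inputs below).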

@@ -295,26 +311,31 @@ def __call__(self, x, train):
)(bot_mlp_input)
bot_mlp_input += activation_fn(x)

bot_mlp_output = bot_mlp_input
batch_size = bot_mlp_output.shape[0]
feature_stack = jnp.reshape(bot_mlp_output,
[batch_size, -1, self.embed_dim])
base_init_fn = jnn.initializers.uniform(scale=1.0)
if self.embedding_init_multiplier is None:
embedding_init_multiplier = 1 / self.vocab_size**0.5
else:
embedding_init_multiplier = self.embedding_init_multiplier
# Embedding table init and lookup for a single unified table.
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
idx_lookup = cat_features % self.vocab_size

def scaled_init(key, shape, dtype=jnp.float_):
return base_init_fn(key, shape, dtype) * embedding_init_multiplier

embedding_table = self.param(
'embedding_table',
scaled_init,
[self.vocab_size, self.embed_dim])
'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
)

embed_features = embedding_table[idx_lookup]
batch_size = bot_mlp_input.shape[0]
embed_features = jnp.reshape(
embed_features, (batch_size, 26 * self.embed_dim))
top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
embed_features = shmapped_get_embeddings(embedding_table, idx_lookup)
feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(
concat_features=feature_stack, keep_diags=self.keep_diags)
top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
axis=-1)
mlp_input_dim = top_mlp_input.shape[1]
mlp_top_dims = self.mlp_top_dims
num_layers_top = len(mlp_top_dims)
@@ -354,6 +375,15 @@ def scaled_init(key, shape, dtype=jnp.float_):
class DLRMResNetModel(base_model.BaseModel):
"""DLRMResNetModel init2winit class."""

def get_sharding_overrides(self, mesh: Any) -> Any:
type_to_sharding = super().get_sharding_overrides(mesh)
overrides = {
ParameterType.EMBEDDING: NamedSharding(mesh, P(None, 'devices')),
}

type_to_sharding.update(overrides)
return type_to_sharding
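
A standalone sketch (not part of this file) of what the EMBEDDING override above does, assuming the same 1-D 'devices' mesh: the vocab axis stays unpartitioned while the embedding width is split across devices, matching the P(None, 'devices') in_spec of the shard_mapped lookup.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('devices',))
table = jnp.ones((1000, 32 * jax.device_count()))
# Place the table so each device owns a column slice of the embedding width.
sharded = jax.device_put(table, NamedSharding(mesh, P(None, 'devices')))
print(sharded.sharding)  # NamedSharding(..., PartitionSpec(None, 'devices'))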

def build_flax_module(self):
"""DLRMResNet for ad click probability prediction."""
return DLRMResNet(
@@ -371,7 +401,16 @@ def get_fake_inputs(self, hps):
"""Helper method solely for purpose of initalizing the model."""
# NOTE(dsuo): hps.input_shape for `criteo_terabyte_input_pipeline` is
# (39,)
fake_batch_size = hps.batch_size

# If fake batch size is not divisible by the number of devices, we use the
# smallest batch size that is divisible by the number of devices.
# This is necessary for shard_mapped get_embeddings() to work.
# It also makes our fake batch closer to the real batch size.
if fake_batch_size % jax.device_count() != 0:
fake_batch_size = jax.device_count()

dummy_inputs = [
jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype)
jnp.zeros((fake_batch_size, *hps.input_shape), dtype=hps.model_dtype)
]
return dummy_inputs
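
The divisibility rule in get_fake_inputs can be read as a tiny helper; a sketch, assuming only that a batch sharded with P('devices') must divide evenly across devices (pick_fake_batch_size is a hypothetical name, not in the PR):

import jax

def pick_fake_batch_size(batch_size: int) -> int:
  # Keep the configured batch size when it already divides evenly; otherwise
  # fall back to one example per device, the smallest size that still shards.
  n = jax.device_count()
  return batch_size if batch_size % n == 0 else n

print(pick_fake_batch_size(7))                       # falls back to jax.device_count()
print(pick_fake_batch_size(8 * jax.device_count()))  # kept as-is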