diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 6a2637efa..8acc36698 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -769,7 +769,7 @@ def forward(
                 self,
                 features: KeyedJaggedTensor,
             ) -> Dict[str, torch.Tensor]:
-                return self.regroup([self.ebc(features)])
+                return self.regroup(keyed_tensors=[self.ebc(features)])
 
         class myModel(nn.Module):
             def __init__(self, ebc, regroup):
@@ -813,6 +813,96 @@ def forward(
         for key in eager_out.keys():
             torch.testing.assert_close(deserialized_out[key], eager_out[key])
 
+    def test_key_order_with_ebc_and_regroup_input_kwargs(self) -> None:
+        tb1_config = EmbeddingBagConfig(
+            name="t1",
+            embedding_dim=3,
+            num_embeddings=10,
+            feature_names=["f1"],
+        )
+        tb2_config = EmbeddingBagConfig(
+            name="t2",
+            embedding_dim=4,
+            num_embeddings=10,
+            feature_names=["f2"],
+        )
+        tb3_config = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
+        id_list_features = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3", "f4", "f5"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]),
+        )
+        ebc1 = EmbeddingBagCollection(
+            tables=[tb1_config, tb2_config, tb3_config],
+            is_weighted=False,
+        )
+        ebc2 = EmbeddingBagCollection(
+            tables=[tb1_config, tb3_config, tb2_config],
+            is_weighted=False,
+        )
+        ebc2.load_state_dict(ebc1.state_dict())
+        regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])
+
+        class mySparse(nn.Module):
+            def __init__(self, ebc):
+                super().__init__()
+                self.ebc = ebc
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> KeyedTensor:
+                return self.ebc(features)
+
+        class myModel(nn.Module):
+            def __init__(self, ebc, regroup):
+                super().__init__()
+                self.regroup = regroup
+                self.sparse = mySparse(ebc)
+
+            def forward(
+                self,
+                features: KeyedJaggedTensor,
+            ) -> Dict[str, torch.Tensor]:
+                sparse_out = self.sparse(features)
+                return self.regroup(keyed_tensors=[sparse_out])
+
+        model = myModel(ebc1, regroup)
+        eager_out = model(id_list_features)
+
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (id_list_features,),
+            {},
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(
+            unflatten_ep,
+            JsonSerializer,
+            short_circuit_pytree_ebc_regroup=True,
+            finalize_interpreter_modules=True,
+        )
+
+        # we export the model with ebc1 and unflatten the model,
+        # and then swap in ebc2 (you can think of this as the sharding process
+        # resulting in a shardedEBC), so that we can mimic the key-order change
+        # pyre-fixme[16]: `Module` has no attribute `ebc`.
+        # pyre-fixme[16]: `Tensor` has no attribute `ebc`.
+        deserialized_model.sparse.ebc = ebc2
+
+        deserialized_out = deserialized_model(id_list_features)
+        for key in eager_out.keys():
+            torch.testing.assert_close(deserialized_out[key], eager_out[key])
+
     def test_cast_in_regroup(self) -> None:
         class Model(nn.Module):
             def __init__(self, ebc, fpebc, regroup):
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index b9390a5d5..05bbfe9fb 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -12,7 +12,7 @@
 import logging
 import operator
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import cast, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 
@@ -370,25 +370,33 @@ def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
     # remove tree_unflatten from the in_fqns (in-coming nodes)
     for fqn in in_fqns:
         submodule, node = _get_graph_node(module, fqn)
-        assert len(node.args) == 1
-        getitem_getitem: Node = node.args[0]  # pyre-ignore[9]
+        # kt_regroup node will have either one arg or one kwarg
+        assert len(node.args) == 1 or len(node.kwargs) == 1
+        use_args = len(node.args) == 1
+
+        getitem_getitem = cast(
+            Node, node.args[0] if use_args else list(node.kwargs.values())[0]
+        )
         assert (
             getitem_getitem.op == "call_function"
             and getitem_getitem.target == operator.getitem
         )
-        tree_unflatten_getitem = node.args[0].args[0]  # pyre-ignore[16]
+        tree_unflatten_getitem = cast(Node, getitem_getitem.args[0])
         assert (
             tree_unflatten_getitem.op == "call_function"
            and tree_unflatten_getitem.target == operator.getitem
         )
-        tree_unflatten = tree_unflatten_getitem.args[0]
+        tree_unflatten = cast(Node, tree_unflatten_getitem.args[0])
         assert (
             tree_unflatten.op == "call_function"
             and tree_unflatten.target == torch.utils._pytree.tree_unflatten
         )
         logger.info(f"Removing tree_unflatten from {fqn}")
         input_nodes = tree_unflatten.args[0]
-        node.args = (input_nodes,)
+        if use_args:
+            node.args = (input_nodes,)
+        else:
+            node.kwargs = {list(node.kwargs.keys())[0]: input_nodes}
         # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
         # `eliminate_dead_code`.
         submodule.graph.eliminate_dead_code()
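Note: the utils.py hunk generalizes the tree_unflatten pruning so that the regroup call node's single input may arrive either positionally (via node.args) or by keyword (via node.kwargs, as in `self.regroup(keyed_tensors=[...])`), and the pruned input is written back through whichever path was used. Below is a minimal standalone sketch of that args-vs-kwargs pattern on a toy torch.fx graph; it is illustrative only, not TorchRec code, and the module `M`, the use of `torch.relu`, and the no-op substitution are assumptions made for the example.

    import torch
    import torch.fx as fx

    class M(torch.nn.Module):
        def forward(self, x):
            # keyword-style call, mirroring `self.regroup(keyed_tensors=[...])`
            return torch.relu(input=x)

    gm = fx.symbolic_trace(M())
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.relu:
            # the traced node carries its single input as one arg or one kwarg
            assert len(node.args) == 1 or len(node.kwargs) == 1
            use_args = len(node.args) == 1
            inp = node.args[0] if use_args else next(iter(node.kwargs.values()))
            new_inp = inp  # a real pass would substitute a rewritten input here
            # write the (possibly rewritten) input back through the same path
            if use_args:
                node.args = (new_inp,)
            else:
                node.kwargs = {next(iter(node.kwargs)): new_inp}
    gm.recompile()

The same reasoning explains the new test: `test_key_order_with_ebc_and_regroup_input_kwargs` exercises exactly the kwargs path (`regroup(keyed_tensors=[sparse_out])`) that the previous args-only assertion in utils.py would have rejected.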