From 5c12945b110b1d695d144a7f271a22fafd9338f9 Mon Sep 17 00:00:00 2001 From: Raahul Kalyaan Jakka Date: Mon, 1 Dec 2025 11:10:57 -0800 Subject: [PATCH] Added unit test for Heuristic Storage Reservation (#3511) Summary: **Context:** Heuristic storage reservation is a common component shared by all planners; it checks whether the given module, along with its constraints, can be sharded across the topology. **In this diff:** We added a unit test that validates the error raised for storage use in the storage reservation process. If the given module is larger than the provided topology can hold, we need to OOM the process as soon as possible with an appropriate error to notify the PG. Reviewed By: kausv, mserturk Differential Revision: D85892579 --- .../tests/test_storage_reservations.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/tests/test_storage_reservations.py b/torchrec/distributed/planner/tests/test_storage_reservations.py index 4bc2b40f3..128f9f904 100644 --- a/torchrec/distributed/planner/tests/test_storage_reservations.py +++ b/torchrec/distributed/planner/tests/test_storage_reservations.py @@ -19,7 +19,7 @@ _get_module_size, HeuristicalStorageReservation, ) -from torchrec.distributed.planner.types import Topology +from torchrec.distributed.planner.types import PlannerError, PlannerErrorType, Topology from torchrec.distributed.test_utils.test_model import TestTowerInteraction from torchrec.distributed.types import ModuleSharder @@ -36,6 +36,36 @@ def __init__(self, shardable_sparse: nn.Module) -> None: class TestHeuristicalStorageReservation(unittest.TestCase): + + def test_validate_storage_reservations_errors(self) -> None: + tables = [ + EmbeddingBagConfig( + num_embeddings=1_000_000, + embedding_dim=1024, + name="table_0", + feature_names=["feature_0"], + ), + ] + + ebc = EmbeddingBagCollection(tables) + model = TestModel(shardable_sparse=ebc) + + # Reserving 100% of HBM to make sure the heuristic storage reservation fails + 
heuristical_storage_reservation = HeuristicalStorageReservation(percentage=1) + with self.assertRaises(PlannerError) as context: + heuristical_storage_reservation.reserve( + topology=Topology(world_size=1, compute_device="cuda"), + batch_size=1024, + module=model, + sharders=cast( + List[ModuleSharder[nn.Module]], [EmbeddingBagCollectionSharder()] + ), + ) + + self.assertEqual( + context.exception.error_type, PlannerErrorType.INSUFFICIENT_STORAGE + ) + def test_storage_reservations_ebc(self) -> None: tables = [ EmbeddingBagConfig(