From bbe345f283538fbf296efb0b1ce486d3dd210fe2 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Wed, 15 Apr 2026 19:35:19 -0700 Subject: [PATCH 1/6] max label per anchor --- gigl/distributed/dataset_factory.py | 6 +++ gigl/distributed/dist_ablp_neighborloader.py | 1 + gigl/distributed/dist_dataset.py | 28 +++++++++- gigl/distributed/graph_store/dist_server.py | 6 ++- gigl/distributed/graph_store/storage_utils.py | 19 ++++++- gigl/utils/data_splitters.py | 54 ++++++++++++++++++- .../dist_ablp_neighborloader_test.py | 31 ++++++++++- .../graph_store/remote_dist_dataset_test.py | 29 ++++++++++ tests/unit/utils/data_splitters_test.py | 32 +++++++++++ 9 files changed, 199 insertions(+), 7 deletions(-) diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 1f78f83f3..b43301674 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -49,6 +49,7 @@ DistNodeSplitter, NodeAnchorLinkSplitter, NodeSplitter, + get_max_labels_per_anchor_node_from_runtime_args, select_ssl_positive_label_edges, ) @@ -502,6 +503,8 @@ def build_dataset_from_task_config_uri( - should_load_tensors_in_parallel (bool): Whether TFRecord loading should happen in parallel across entities Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive `splitter` directly. + - max_labels_per_anchor_node (Optional[int]): Cap for how many labels to + materialize per anchor node for ABLP label fetching. If training there are two additional arguments: - num_val (float): Percentage of edges to use for validation, defaults to 0.1. Must in in range [0, 1]. - num_test (float): Percentage of edges to use for testing, defaults to 0.1. Must be in range [0, 1]. @@ -530,6 +533,7 @@ def build_dataset_from_task_config_uri( ) ssl_positive_label_percentage: Optional[float] = None + max_labels_per_anchor_node: Optional[int] = None splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None if is_inference: args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) @@ -576,6 +580,7 @@ def build_dataset_from_task_config_uri( raise ValueError( f"Unsupported task metadata type: {task_metadata_pb_wrapper.task_metadata_type}" ) + max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args(args) assert sample_edge_direction in ( "in", @@ -628,5 +633,6 @@ def build_dataset_from_task_config_uri( splitter=splitter, _ssl_positive_label_percentage=ssl_positive_label_percentage, ) + dataset.max_labels_per_anchor_node = max_labels_per_anchor_node return dataset diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 215a92a51..c7845f701 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -538,6 +538,7 @@ def _setup_for_colocated( node_ids=curr_process_nodes, positive_label_edge_type=positive_label_edge_type, negative_label_edge_type=negative_label_edge_type, + max_labels_per_anchor_node=dataset.max_labels_per_anchor_node, ) positive_labels_by_label_edge_type[positive_label_edge_type] = ( positive_labels diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index d37bbc925..808907be6 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -26,7 +26,11 @@ GraphPartitionData, PartitionOutput, ) -from gigl.utils.data_splitters import NodeAnchorLinkSplitter, NodeSplitter +from gigl.utils.data_splitters import ( + NodeAnchorLinkSplitter, + NodeSplitter, + validate_max_labels_per_anchor_node, +) from gigl.utils.share_memory import share_memory logger = Logger() @@ -80,6 +84,7 @@ def __init__( degree_tensor: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, + max_labels_per_anchor_node: Optional[int] = None, ) -> None: """ Initializes the fields of the DistDataset class. This function is called upon each serialization of the DistDataset instance. @@ -105,6 +110,8 @@ def __init__( edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + max_labels_per_anchor_node (Optional[int]): Optional cap for how many + labels to materialize per anchor node for ABLP label fetching. """ self._rank: int = rank self._world_size: int = world_size @@ -143,6 +150,9 @@ def __init__( self._degree_tensor: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor + self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node( + max_labels_per_anchor_node + ) # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear # naming (i.e. rank, world_size). @@ -329,6 +339,18 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor + @property + def max_labels_per_anchor_node(self) -> Optional[int]: + return self._max_labels_per_anchor_node + + @max_labels_per_anchor_node.setter + def max_labels_per_anchor_node( + self, new_max_labels_per_anchor_node: Optional[int] + ) -> None: + self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node( + new_max_labels_per_anchor_node + ) + @property def train_node_ids( self, @@ -858,6 +880,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + Optional[int], ]: """ Serializes the member variables of the DistDatasetClass @@ -880,6 +903,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[int]: Optional per-anchor label cap for ABLP label fetching """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -908,6 +932,7 @@ def share_ipc( self._node_feature_info, # Additional field unique to DistDataset class self._edge_feature_info, # Additional field unique to DistDataset class self._degree_tensor, # Additional field unique to DistDataset class + self._max_labels_per_anchor_node, # Additional field unique to DistDataset class ) return ipc_handle @@ -1164,6 +1189,7 @@ def _rebuild_distributed_dataset( Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors + Optional[int], # Optional per-anchor label cap for ABLP label fetching ], ): dataset = DistDataset.from_ipc_handle(ipc_handle) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1110e47c4..533b1dfb3 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -408,7 +408,11 @@ def get_ablp_input( request.supervision_edge_type, self.dataset.get_edge_types() ) positive_labels, negative_labels = get_labels_for_anchor_nodes( - self.dataset, anchors, positive_label_edge_type, negative_label_edge_type + self.dataset, + anchors, + positive_label_edge_type, + negative_label_edge_type, + max_labels_per_anchor_node=self.dataset.max_labels_per_anchor_node, ) return anchors, positive_labels, negative_labels diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index 548e8ee7f..e67da070a 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -32,7 +32,11 @@ ) from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper -from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter, DistNodeSplitter +from gigl.utils.data_splitters import ( + DistNodeAnchorLinkSplitter, + DistNodeSplitter, + get_max_labels_per_anchor_node_from_runtime_args, +) logger = Logger() @@ -45,6 +49,7 @@ def build_storage_dataset( splitter: Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]] = None, should_load_tensors_in_parallel: bool = True, ssl_positive_label_percentage: Optional[float] = None, + max_labels_per_anchor_node: Optional[int] = None, ) -> DistDataset: """Build a :class:`DistDataset` for a storage node from a task config. @@ -71,6 +76,10 @@ def build_storage_dataset( self-supervised positive labels. Must be ``None`` when supervised edge labels are already provided. For example, ``0.1`` selects 10 % of edges. + max_labels_per_anchor_node: Optional cap for how many labels to + materialize per anchor node when the storage server serves ABLP + input. If ``None``, this is inferred from the task config's + ``trainer_args``. Returns: A partitioned :class:`DistDataset` ready to be served. @@ -78,12 +87,16 @@ def build_storage_dataset( gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=task_config_uri ) + if max_labels_per_anchor_node is None: + max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args( + dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + ) serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, tfrecord_uri_pattern=tf_record_uri_pattern, ) - return build_dataset( + dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, sample_edge_direction=sample_edge_direction, should_load_tensors_in_parallel=should_load_tensors_in_parallel, @@ -91,6 +104,8 @@ def build_storage_dataset( splitter=splitter, _ssl_positive_label_percentage=ssl_positive_label_percentage, ) + dataset.max_labels_per_anchor_node = max_labels_per_anchor_node + return dataset def _run_storage_server_session( diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 4aa416f2e..43f28b6e4 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -31,11 +31,44 @@ logger = Logger() PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64) +MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG: Final[str] = "max_labels_per_anchor_node" # We need to make the protocols for the node splitter and node anchor linked spliter runtime checkable so that # we can make isinstance() checks on them at runtime. +def validate_max_labels_per_anchor_node( + max_labels_per_anchor_node: Optional[int], +) -> Optional[int]: + """Validate the optional per-anchor label cap.""" + if max_labels_per_anchor_node is None: + return None + if max_labels_per_anchor_node <= 0: + raise ValueError( + "max_labels_per_anchor_node must be a positive integer when provided." + ) + return max_labels_per_anchor_node + + +def get_max_labels_per_anchor_node_from_runtime_args( + runtime_args: Mapping[str, str], +) -> Optional[int]: + """Parse the optional per-anchor label cap from runtime args.""" + raw_max_labels_per_anchor_node = runtime_args.get( + MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG + ) + if raw_max_labels_per_anchor_node is None: + return None + try: + parsed_max_labels_per_anchor_node = int(raw_max_labels_per_anchor_node) + except ValueError as exc: + raise ValueError( + f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value " + f"{raw_max_labels_per_anchor_node!r}. Expected a positive integer." + ) from exc + return validate_max_labels_per_anchor_node(parsed_max_labels_per_anchor_node) + + @runtime_checkable class NodeAnchorLinkSplitter(Protocol): """Protocol that should be satisfied for anything that is used to split on edges. @@ -562,6 +595,7 @@ def get_labels_for_anchor_nodes( node_ids: torch.Tensor, positive_label_edge_type: PyGEdgeType, negative_label_edge_type: Optional[PyGEdgeType] = None, + max_labels_per_anchor_node: Optional[int] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Selects labels for the given node ids based on the provided edge types. @@ -592,6 +626,8 @@ def get_labels_for_anchor_nodes( positive_label_edge_type (PyGEdgeType): The edge type to use for the positive labels. negative_label_edge_type (Optional[PyGEdgeType]): The edge type to use for the negative labels. Defaults to None. If not provided no negative labels will be returned. + max_labels_per_anchor_node (Optional[int]): If provided, caps the number of + positive and negative labels materialized per anchor node. Returns: Tuple of (positive labels, negative_labels?) negative labels may be None depending on if negative_label_edge_type is provided. @@ -612,13 +648,19 @@ def get_labels_for_anchor_nodes( # Labels is NxM, where N is the number of nodes, and M is the max number of labels. positive_labels = _get_padded_labels( - node_ids, positive_node_topo, allow_non_existant_node_ids=False + node_ids, + positive_node_topo, + allow_non_existant_node_ids=False, + max_labels_per_anchor_node=max_labels_per_anchor_node, ) if negative_node_topo is not None: # Labels is NxM, where N is the number of nodes, and M is the max number of labels. negative_labels = _get_padded_labels( - node_ids, negative_node_topo, allow_non_existant_node_ids=True + node_ids, + negative_node_topo, + allow_non_existant_node_ids=True, + max_labels_per_anchor_node=max_labels_per_anchor_node, ) else: negative_labels = None @@ -630,6 +672,7 @@ def _get_padded_labels( anchor_node_ids: torch.Tensor, topo: Topology, allow_non_existant_node_ids: bool = False, + max_labels_per_anchor_node: Optional[int] = None, ) -> torch.Tensor: """Returns the padded labels and the max range of labels. @@ -642,9 +685,14 @@ def _get_padded_labels( topo (Topology): The topology to use for the labels. allow_non_existant_node_ids (bool): If True, will allow anchor node ids that do not exist in the topology. This means that the returned tensor will be padded with `PADDING_NODE` for those anchor node ids. + max_labels_per_anchor_node (Optional[int]): If provided, caps the number of + labels materialized per anchor node. Returns: The shape of the returned tensor is [N, max_number_of_labels]. """ + max_labels_per_anchor_node = validate_max_labels_per_anchor_node( + max_labels_per_anchor_node + ) # indptr is the ROW_INDEX of a CSR matrix. # and indices is the COL_INDEX of a CSR matrix. # See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format) @@ -660,6 +708,8 @@ def _get_padded_labels( ends = indptr[anchor_node_ids + 1] # [N] max_range = int(torch.max(ends - starts).item()) + if max_labels_per_anchor_node is not None: + max_range = min(max_range, max_labels_per_anchor_node) # Sample all labels based on the CSR start/stop indices. # Creates "indices" for us to us, e.g [[0, 1], [2, 3]] diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 4575b7ad4..315b2b590 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -445,6 +445,7 @@ def tearDown(self): 10: torch.tensor([13, 16]), 15: torch.tensor([17]), }, + max_labels_per_anchor_node=None, ), param( "Positive edges", @@ -457,6 +458,28 @@ def tearDown(self): 15: torch.tensor([16]), }, expected_negative_labels=None, + max_labels_per_anchor_node=None, + ), + param( + "Positive and Negative edges with label cap", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + expected_node=torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]), + expected_srcs=torch.tensor([10, 10, 15, 15, 16, 16, 11, 11]), + expected_dsts=torch.tensor([11, 12, 13, 14, 12, 14, 13, 17]), + expected_positive_labels={ + 10: torch.tensor([15]), + 15: torch.tensor([16]), + }, + expected_negative_labels={ + 10: torch.tensor([13]), + 15: torch.tensor([17]), + }, + max_labels_per_anchor_node=1, ), ] ) @@ -469,6 +492,7 @@ def test_ablp_dataloader( expected_dsts, expected_positive_labels, expected_negative_labels, + max_labels_per_anchor_node, ): # Graph looks like https://is.gd/w2oEVp: # Message passing @@ -511,7 +535,12 @@ def test_ablp_dataloader( partitioned_positive_labels=None, partitioned_node_labels=None, ) - dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset = DistDataset( + rank=0, + world_size=1, + edge_dir="out", + max_labels_per_anchor_node=max_labels_per_anchor_node, + ) dataset.build(partition_output=partition_output) mp.spawn( diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index a0f1a3594..8a38ccceb 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -501,6 +501,35 @@ def test_fetch_ablp_input(self, mock_async_request): torch.tensor([[1]]), ) + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_fetch_ablp_input_respects_max_labels_per_anchor_node( + self, mock_async_request + ): + _create_server_with_splits() + self.assertIsNotNone(_test_server) + assert _test_server is not None + _test_server.dataset.max_labels_per_anchor_node = 1 + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + result = remote_dataset.fetch_ablp_input( + split="train", anchor_node_type=USER, supervision_edge_type=USER_TO_STORY + ) + pos_labels, neg_labels = result[0].labels[USER_TO_STORY] + self.assert_tensor_equality( + pos_labels, + torch.tensor([[0], [1], [2]]), + ) + assert neg_labels is not None + self.assert_tensor_equality( + neg_labels, + torch.tensor([[2], [3], [4]]), + ) + @patch( "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, diff --git a/tests/unit/utils/data_splitters_test.py b/tests/unit/utils/data_splitters_test.py index c9c7f6ce4..22424d889 100644 --- a/tests/unit/utils/data_splitters_test.py +++ b/tests/unit/utils/data_splitters_test.py @@ -18,6 +18,7 @@ _fast_hash, _get_padded_labels, get_labels_for_anchor_nodes, + get_max_labels_per_anchor_node_from_runtime_args, select_ssl_positive_label_edges, ) from tests.test_assets.distributed.utils import ( @@ -810,6 +811,37 @@ def test_get_padded_labels(self, _, node_ids, topo, expected): labels = _get_padded_labels(node_ids, topo) assert_close(labels, expected, rtol=0, atol=0) + def test_get_padded_labels_with_max_labels_per_anchor_node(self): + labels = _get_padded_labels( + torch.tensor([0, 1]), + Topology( + edge_index=torch.tensor([[0, 0, 1], [1, 2, 2]], dtype=torch.int64), + layout="CSR", + ), + max_labels_per_anchor_node=1, + ) + assert_close( + labels, + torch.tensor([[1], [2]], dtype=torch.int64), + rtol=0, + atol=0, + ) + + def test_get_max_labels_per_anchor_node_from_runtime_args(self): + self.assertIsNone(get_max_labels_per_anchor_node_from_runtime_args({})) + self.assertEqual( + get_max_labels_per_anchor_node_from_runtime_args( + {"max_labels_per_anchor_node": "3"} + ), + 3, + ) + + def test_get_max_labels_per_anchor_node_from_runtime_args_invalid(self): + with self.assertRaises(ValueError): + get_max_labels_per_anchor_node_from_runtime_args( + {"max_labels_per_anchor_node": "0"} + ) + @parameterized.expand( [ param( From ea4a0d9034dcc31722acc255b0cd4936393e388e Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 27 Apr 2026 15:31:05 -0700 Subject: [PATCH 2/6] attn --- .../graph_transformer/graph_transformer.py | 153 +++++++++++++++- gigl/transforms/graph_transformer.py | 118 ++++++++++++ .../graph_transformer_test.py | 172 ++++++++++++++++++ .../unit/transforms/graph_transformer_test.py | 61 +++++++ 4 files changed, 495 insertions(+), 9 deletions(-) diff --git a/gigl/src/common/models/graph_transformer/graph_transformer.py b/gigl/src/common/models/graph_transformer/graph_transformer.py index 82b188e4f..e27d382fb 100644 --- a/gigl/src/common/models/graph_transformer/graph_transformer.py +++ b/gigl/src/common/models/graph_transformer/graph_transformer.py @@ -239,6 +239,14 @@ class GraphTransformerEncoderLayer(nn.Module): activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". + relation_attention_mode: Optional relation-aware augmentation strategy + for attention scores. ``"none"`` preserves the default shared + self-attention path. ``"edge_type_additive"`` adds a learned + per-edge-type bilinear term for token pairs backed by sampled + directed graph edges. + num_relations: Number of relation channels expected in + ``pairwise_relation_mask`` when + ``relation_attention_mode="edge_type_additive"``. Raises: ValueError: If model_dim is not divisible by num_heads. @@ -252,16 +260,31 @@ def __init__( dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", + num_relations: int = 0, ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if relation_attention_mode == "edge_type_additive" and num_relations <= 0: + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "num_relations > 0." + ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate + self._relation_attention_mode = relation_attention_mode + self._num_relations = num_relations self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) @@ -269,6 +292,11 @@ def __init__( self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) + self._relation_attention_matrices: Optional[nn.Parameter] = None + if relation_attention_mode == "edge_type_additive": + self._relation_attention_matrices = nn.Parameter( + torch.empty(num_relations, num_heads, self._head_dim, self._head_dim) + ) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( @@ -287,6 +315,10 @@ def reset_parameters(self) -> None: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) + if self._relation_attention_matrices is not None: + for relation_matrices in self._relation_attention_matrices: + for head_matrix in relation_matrices: + nn.init.xavier_uniform_(head_matrix) self._ffn_norm.reset_parameters() self._ffn.reset_parameters() @@ -294,6 +326,7 @@ def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. @@ -303,6 +336,9 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. + pairwise_relation_mask: Optional multi-hot relation mask of shape + ``(batch, seq, seq, num_relations)`` that marks which sampled + directed edge types connect each token pair as ``key -> query``. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. @@ -330,14 +366,23 @@ def forward( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, - ) + if self._relation_attention_mode == "none": + attention_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + else: + attention_output = self._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + ) # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( @@ -360,6 +405,57 @@ def forward( return x + def _run_relation_aware_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_mask: Optional[Tensor], + ) -> Tensor: + if pairwise_relation_mask is None: + raise ValueError( + "pairwise_relation_mask is required when " + "relation_attention_mode='edge_type_additive'." + ) + if pairwise_relation_mask.size(-1) != self._num_relations: + raise ValueError( + "pairwise_relation_mask has unexpected relation dimension " + f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}." + ) + if self._relation_attention_matrices is None: + raise ValueError("Relation attention matrices are not initialized.") + + base_attention_scores = torch.matmul(query, key.transpose(-2, -1)) + relation_scores_by_type = torch.einsum( + "bhkd,rhde,bhqe->bhqkr", + key, + self._relation_attention_matrices.to(dtype=query.dtype), + query, + ) + relation_attention_scores = torch.einsum( + "bhqkr,bqkr->bhqk", + relation_scores_by_type, + pairwise_relation_mask.to(dtype=query.dtype), + ) + + attention_scores = ( + base_attention_scores + relation_attention_scores + ) / math.sqrt(self._head_dim) + if attn_bias is not None: + attention_scores = attention_scores + attn_bias + + attention_weights = F.softmax(attention_scores, dim=-1) + attention_weights = torch.nan_to_num(attention_weights, nan=0.0) + if self.training and self._attention_dropout_rate > 0.0: + attention_weights = F.dropout( + attention_weights, + p=self._attention_dropout_rate, + training=True, + ) + + return torch.matmul(attention_weights, value) + class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. @@ -450,6 +546,10 @@ class GraphTransformerEncoder(nn.Module): uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. + relation_attention_mode: Optional relation-aware augmentation for + attention scores. ``"none"`` preserves the current dense transformer + path. ``"edge_type_additive"`` adds a learned per-edge-type + bilinear score term for sampled directed edges in ``"khop"`` mode. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap @@ -501,6 +601,7 @@ def __init__( pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", **kwargs: object, ) -> None: super().__init__() @@ -542,6 +643,20 @@ def __init__( "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if ( + relation_attention_mode == "edge_type_additive" + and sequence_construction_method != "khop" + ): + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "sequence_construction_method='khop'." + ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -573,6 +688,12 @@ def __init__( self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + self._relation_attention_mode = relation_attention_mode + self._relation_attention_edge_types = ( + sorted(edge_type_to_feat_dim_map.keys()) + if relation_attention_mode == "edge_type_additive" + else [] + ) anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None @@ -666,6 +787,8 @@ def __init__( dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, + relation_attention_mode=relation_attention_mode, + num_relations=len(self._relation_attention_edge_types), ) for _ in range(num_layers) ] @@ -803,6 +926,7 @@ def forward( anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, + relation_edge_types=self._relation_attention_edge_types, ) # Free memory after sequences are built @@ -839,6 +963,9 @@ def forward( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, + pairwise_relation_mask=sequence_auxiliary_data.get( + "pairwise_relation_mask" + ), ) embeddings = self._output_projection(embeddings) @@ -1038,6 +1165,7 @@ def _encode_and_readout( sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, ) -> Tensor: """Process sequences through transformer layers and attention readout. @@ -1046,6 +1174,8 @@ def _encode_and_readout( valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. + pairwise_relation_mask: Optional relation mask shaped + ``(batch_size, seq, seq, num_relations)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. @@ -1053,7 +1183,12 @@ def _encode_and_readout( x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: - x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) + x = encoder_layer( + x, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + valid_mask=valid_mask, + ) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 602f95bde..6f117eb97 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -65,12 +65,15 @@ from torch_geometric.typing import NodeType from torch_geometric.utils import to_torch_sparse_tensor +from gigl.src.common.types.graph_data import EdgeType + TokenInputData = dict[str, Tensor] class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] + pairwise_relation_mask: Optional[Tensor] token_input: Optional[TokenInputData] @@ -90,6 +93,7 @@ def heterodata_to_graph_transformer_input( anchor_based_attention_bias_attr_names: Optional[list[str]] = None, anchor_based_input_attr_names: Optional[list[str]] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, + relation_edge_types: Optional[list[EdgeType]] = None, ) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]: """ Transform a HeteroData object to Graph Transformer sequence input. @@ -131,6 +135,10 @@ def heterodata_to_graph_transformer_input( pairwise_attention_bias_attr_names: List of pairwise feature names used as attention bias. These must correspond to sparse graph-level attributes on ``data``. Example: ['pairwise_distance']. + relation_edge_types: Optional ordered edge types used to materialize a + dense per-token-pair relation mask. Each output channel corresponds + to one edge type in this list. Directed edges are placed at + ``[query_pos, key_pos] = [dst_token, src_token]``. Returns: (sequences, valid_mask, attention_bias_data), where: @@ -143,6 +151,8 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"pairwise_relation_mask"`` shaped + ``(batch, seq, seq, num_relations)`` or None ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -312,6 +322,15 @@ def heterodata_to_graph_transformer_input( csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, device=device, ) + pairwise_relation_mask = _lookup_pairwise_relation_masks( + data=data, + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + ) anchor_bias_features = _compose_anchor_feature_tensor( anchor_relative_feature_sequences=anchor_relative_feature_sequences, @@ -332,6 +351,7 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "pairwise_relation_mask": pairwise_relation_mask, "token_input": token_input_features, }, ) @@ -875,6 +895,104 @@ def _lookup_pairwise_relative_features( return features +def _lookup_pairwise_relation_masks( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + relation_edge_types: Optional[list[EdgeType]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, +) -> Optional[Tensor]: + """Build a dense per-token-pair multi-hot relation mask. + + For each ordered token pair ``(query_pos, key_pos)``, this returns a + multi-hot vector indicating which directed sampled graph edges connect the + underlying nodes as ``key -> query``. The relation channel order is defined + by ``relation_edge_types``. + """ + if not relation_edge_types: + return None + + relation_adjacency_matrices = _build_relation_adjacency_matrices( + data=data, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + ) + if not relation_adjacency_matrices: + return None + + return _lookup_pairwise_relative_features( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + csr_matrices=relation_adjacency_matrices, + device=device, + ) + + +def _build_relation_adjacency_matrices( + data: HeteroData, + relation_edge_types: list[EdgeType], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, +) -> list[Tensor]: + """Create one binary CSR adjacency matrix per requested edge type.""" + adjacency_matrices: list[Tensor] = [] + empty_row_ptr = torch.zeros(num_nodes + 1, dtype=torch.int64, device=device) + empty_col_idx = torch.zeros(0, dtype=torch.int64, device=device) + empty_values = torch.zeros(0, dtype=torch.float, device=device) + + for edge_type in relation_edge_types: + if edge_type not in data.edge_types: + adjacency_matrices.append( + torch.sparse_csr_tensor( + empty_row_ptr, + empty_col_idx, + empty_values, + size=(num_nodes, num_nodes), + ) + ) + continue + + edge_index = data[edge_type].edge_index.to(device) + src_indices = edge_index[0].long() + int(node_type_offsets[edge_type[0]]) + dst_indices = edge_index[1].long() + int(node_type_offsets[edge_type[2]]) + if src_indices.numel() == 0: + adjacency_matrices.append( + torch.sparse_csr_tensor( + empty_row_ptr, + empty_col_idx, + empty_values, + size=(num_nodes, num_nodes), + ) + ) + continue + + coalesced_adjacency = torch.sparse_coo_tensor( + torch.stack([dst_indices, src_indices]), + torch.ones(src_indices.numel(), dtype=torch.float, device=device), + size=(num_nodes, num_nodes), + ).coalesce() + adjacency_matrices.append( + torch.sparse_coo_tensor( + coalesced_adjacency.indices(), + torch.ones( + coalesced_adjacency.indices().size(1), + dtype=torch.float, + device=device, + ), + size=(num_nodes, num_nodes), + ) + .coalesce() + .to_sparse_csr() + ) + + return adjacency_matrices + + def _get_k_hop_neighbors_sparse( anchor_indices: Tensor, edge_index: Tensor, diff --git a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py index c55d33d91..9586a13e7 100644 --- a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py +++ b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py @@ -274,6 +274,22 @@ def _create_user_graph_with_ppr_edges() -> HeteroData: return data +def _create_user_graph_with_relation_edges() -> HeteroData: + data = HeteroData() + + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ) + data["user", "follows", "user"].edge_index = torch.tensor([[0, 1], [1, 2]]) + data["user", "likes", "user"].edge_index = torch.tensor([[0], [1]]) + + return data + + class TestGraphTransformerEncoderPEModes(TestCase): def setUp(self) -> None: self._node_type = NodeType("user") @@ -411,6 +427,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), + "pairwise_relation_mask": None, "token_input": None, }, ) @@ -447,6 +464,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, + "pairwise_relation_mask": None, "token_input": None, }, ) @@ -773,6 +791,160 @@ def test_layer_with_swiglu(self) -> None: self.assertEqual(out.shape, (2, 10, 32)) +class TestGraphTransformerRelationAttention(TestCase): + def setUp(self) -> None: + self._node_type = NodeType("user") + self._follows_edge_type = EdgeType( + self._node_type, Relation("follows"), self._node_type + ) + self._likes_edge_type = EdgeType( + self._node_type, Relation("likes"), self._node_type + ) + self._device = torch.device("cpu") + + def test_relation_aware_layer_matches_baseline_when_matrices_are_zero( + self, + ) -> None: + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + num_relations=2, + ) + relation_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + relation_mask = torch.zeros((2, 4, 4, 2), dtype=torch.float) + relation_mask[0, 1, 0, 0] = 1.0 + relation_mask[1, 2, 1, 1] = 1.0 + + with torch.no_grad(): + assert relation_layer._relation_attention_matrices is not None + relation_layer._relation_attention_matrices.zero_() + base_output = base_layer(x, valid_mask=valid_mask) + relation_output = relation_layer( + x, + pairwise_relation_mask=relation_mask, + valid_mask=valid_mask, + ) + + self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_relation_aware_attention_only_changes_marked_query_positions( + self, + ) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + num_relations=1, + ) + + query = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + value = torch.tensor([[[[3.0, 0.0], [0.0, 5.0]]]]) + empty_relation_mask = torch.zeros((1, 2, 2, 1), dtype=torch.float) + active_relation_mask = empty_relation_mask.clone() + active_relation_mask[0, 1, 0, 0] = 1.0 + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices.zero_() + layer._relation_attention_matrices[0, 0] = torch.eye(2) + base_output = layer._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=None, + pairwise_relation_mask=empty_relation_mask, + ) + relation_output = layer._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=None, + pairwise_relation_mask=active_relation_mask, + ) + + self.assertTrue( + torch.allclose(base_output[:, :, 0], relation_output[:, :, 0], atol=1e-6) + ) + self.assertFalse( + torch.allclose(base_output[:, :, 1], relation_output[:, :, 1], atol=1e-6) + ) + + def test_encoder_forward_supports_relation_aware_attention(self) -> None: + data = _create_user_graph_with_relation_edges() + encoder = GraphTransformerEncoder( + node_type_to_feat_dim_map={self._node_type: 4}, + edge_type_to_feat_dim_map={ + self._likes_edge_type: 0, + self._follows_edge_type: 0, + }, + hid_dim=8, + out_dim=6, + num_layers=1, + num_heads=2, + max_seq_len=4, + hop_distance=2, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + ) + encoder.eval() + + with torch.no_grad(): + embeddings = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertEqual( + encoder._relation_attention_edge_types, + sorted([self._likes_edge_type, self._follows_edge_type]), + ) + self.assertEqual(embeddings.shape, (3, 6)) + self.assertFalse(torch.isnan(embeddings).any()) + + def test_encoder_rejects_relation_attention_in_ppr_mode(self) -> None: + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + with self.assertRaisesRegex( + ValueError, + "relation_attention_mode='edge_type_additive' requires " + "sequence_construction_method='khop'", + ): + GraphTransformerEncoder( + node_type_to_feat_dim_map={self._node_type: 4}, + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + hid_dim=8, + out_dim=6, + num_layers=1, + num_heads=2, + max_seq_len=4, + hop_distance=1, + relation_attention_mode="edge_type_additive", + sequence_construction_method="ppr", + ) + + class TestGraphTransformerEncoderFeedforwardRatio(TestCase): """Tests for GraphTransformerEncoder feedforward_ratio parameter.""" diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 25cda0821..ac3f12a95 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -7,6 +7,7 @@ from absl.testing import absltest from torch_geometric.data import HeteroData +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.transforms.graph_transformer import ( _get_k_hop_neighbors_sparse, heterodata_to_graph_transformer_input, @@ -122,6 +123,23 @@ def create_ppr_sequence_hetero_data() -> HeteroData: return data +def create_relation_mask_hetero_data() -> HeteroData: + """Create a single-node-type graph with overlapping directed relations.""" + data = HeteroData() + + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ) + data["user", "follows", "user"].edge_index = torch.tensor([[0, 1], [1, 2]]) + data["user", "likes", "user"].edge_index = torch.tensor([[0], [1]]) + + return data + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -252,6 +270,7 @@ def test_basic_transform(self): self.assertIsInstance(attention_bias_data, dict) self.assertIn("anchor_bias", attention_bias_data) self.assertIn("pairwise_bias", attention_bias_data) + self.assertIn("pairwise_relation_mask", attention_bias_data) def test_attention_mask_validity(self): """Test that attention mask correctly identifies valid positions.""" @@ -491,6 +510,48 @@ def test_ppr_sequence_can_return_token_input_and_attention_bias_features(self): torch.equal(valid_mask[1], torch.tensor([True, True, True, False])) ) + def test_relation_mask_outputs_follow_requested_order_and_direction(self): + data = create_relation_mask_hetero_data() + user_node_type = NodeType("user") + follows_edge_type = EdgeType( + user_node_type, Relation("follows"), user_node_type + ) + likes_edge_type = EdgeType(user_node_type, Relation("likes"), user_node_type) + + sequences, valid_mask, attention_bias_data = ( + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + relation_edge_types=[likes_edge_type, follows_edge_type], + ) + ) + + relation_mask = attention_bias_data["pairwise_relation_mask"] + assert relation_mask is not None + + expected_sequences = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ] + ) + self.assertTrue(torch.allclose(sequences, expected_sequences)) + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + self.assertEqual(relation_mask.shape, (1, 4, 4, 2)) + self.assertTrue(torch.equal(relation_mask[0, 1, 0], torch.tensor([1.0, 1.0]))) + self.assertTrue(torch.equal(relation_mask[0, 2, 1], torch.tensor([0.0, 1.0]))) + self.assertTrue(torch.equal(relation_mask[0, 0, 1], torch.zeros(2))) + self.assertTrue(torch.all(relation_mask[0, 3] == 0)) + class TestPyTorchTransformerIntegration(TestCase): """Tests for integration with PyTorch TransformerEncoderLayer.""" From 6b60820bc2eaac5159d98ded4b46411df516523f Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Mon, 27 Apr 2026 15:52:57 -0700 Subject: [PATCH 3/6] optim --- .../graph_transformer/graph_transformer.py | 99 ++++++++++++++----- gigl/transforms/graph_transformer.py | 29 +++++- .../graph_transformer_test.py | 10 +- .../unit/transforms/graph_transformer_test.py | 13 ++- 4 files changed, 111 insertions(+), 40 deletions(-) diff --git a/gigl/src/common/models/graph_transformer/graph_transformer.py b/gigl/src/common/models/graph_transformer/graph_transformer.py index e27d382fb..ae6e3075b 100644 --- a/gigl/src/common/models/graph_transformer/graph_transformer.py +++ b/gigl/src/common/models/graph_transformer/graph_transformer.py @@ -336,7 +336,7 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. - pairwise_relation_mask: Optional multi-hot relation mask of shape + pairwise_relation_mask: Optional boolean multi-hot relation mask of shape ``(batch, seq, seq, num_relations)`` that marks which sampled directed edge types connect each token pair as ``key -> query``. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used @@ -413,6 +413,33 @@ def _run_relation_aware_attention( attn_bias: Optional[Tensor], pairwise_relation_mask: Optional[Tensor], ) -> Tensor: + relation_attention_bias = self._build_relation_attention_bias( + query=query, + key=key, + pairwise_relation_mask=pairwise_relation_mask, + ) + if relation_attention_bias is not None: + attn_bias = ( + relation_attention_bias + if attn_bias is None + else attn_bias + relation_attention_bias + ) + + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _build_relation_attention_bias( + self, + query: Tensor, + key: Tensor, + pairwise_relation_mask: Optional[Tensor], + ) -> Optional[Tensor]: if pairwise_relation_mask is None: raise ValueError( "pairwise_relation_mask is required when " @@ -425,36 +452,54 @@ def _run_relation_aware_attention( ) if self._relation_attention_matrices is None: raise ValueError("Relation attention matrices are not initialized.") + if pairwise_relation_mask.size(1) != query.size(2) or pairwise_relation_mask.size( + 2 + ) != key.size(2): + raise ValueError( + "pairwise_relation_mask must align with the query/key sequence " + "dimensions." + ) - base_attention_scores = torch.matmul(query, key.transpose(-2, -1)) - relation_scores_by_type = torch.einsum( - "bhkd,rhde,bhqe->bhqkr", - key, - self._relation_attention_matrices.to(dtype=query.dtype), - query, - ) - relation_attention_scores = torch.einsum( - "bhqkr,bqkr->bhqk", - relation_scores_by_type, - pairwise_relation_mask.to(dtype=query.dtype), + relation_mask = pairwise_relation_mask.to( + device=query.device, + dtype=torch.bool, ) + active_relation_positions = relation_mask.nonzero(as_tuple=False) + if active_relation_positions.numel() == 0: + return None - attention_scores = ( - base_attention_scores + relation_attention_scores - ) / math.sqrt(self._head_dim) - if attn_bias is not None: - attention_scores = attention_scores + attn_bias - - attention_weights = F.softmax(attention_scores, dim=-1) - attention_weights = torch.nan_to_num(attention_weights, nan=0.0) - if self.training and self._attention_dropout_rate > 0.0: - attention_weights = F.dropout( - attention_weights, - p=self._attention_dropout_rate, - training=True, + relation_attention_bias = query.new_zeros( + (query.size(0), query.size(2), key.size(2), self._num_heads) + ) + query_by_position = query.transpose(1, 2) + key_by_position = key.transpose(1, 2) + relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) + active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) + + for relation_idx_tensor in active_relation_ids: + relation_idx = int(relation_idx_tensor.item()) + relation_positions = active_relation_positions[ + active_relation_positions[:, 3] == relation_idx + ] + batch_indices, query_indices, key_indices = relation_positions[ + :, :3 + ].unbind(dim=1) + # Only materialize bilinear scores for token pairs backed by this relation. + selected_query = query_by_position[batch_indices, query_indices] + transformed_query = torch.einsum( + "nhe,hde->nhd", + selected_query, + relation_matrices[relation_idx], + ) + selected_key = key_by_position[batch_indices, key_indices] + relation_scores = (selected_key * transformed_query).sum(dim=-1) + relation_attention_bias.index_put_( + (batch_indices, query_indices, key_indices), + relation_scores / math.sqrt(self._head_dim), + accumulate=True, ) - return torch.matmul(attention_weights, value) + return relation_attention_bias.permute(0, 3, 1, 2) class GraphTransformerEncoder(nn.Module): @@ -1174,7 +1219,7 @@ def _encode_and_readout( valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. - pairwise_relation_mask: Optional relation mask shaped + pairwise_relation_mask: Optional boolean relation mask shaped ``(batch_size, seq, seq, num_relations)``. Returns: diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 6f117eb97..304ca0795 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -904,7 +904,7 @@ def _lookup_pairwise_relation_masks( num_nodes: int, device: torch.device, ) -> Optional[Tensor]: - """Build a dense per-token-pair multi-hot relation mask. + """Build a dense per-token-pair boolean multi-hot relation mask. For each ordered token pair ``(query_pos, key_pos)``, this returns a multi-hot vector indicating which directed sampled graph edges connect the @@ -924,12 +924,31 @@ def _lookup_pairwise_relation_masks( if not relation_adjacency_matrices: return None - return _lookup_pairwise_relative_features( - node_index_sequences=node_index_sequences, - valid_mask=valid_mask, - csr_matrices=relation_adjacency_matrices, + batch_size, max_seq_len = node_index_sequences.shape + num_relations = len(relation_adjacency_matrices) + relation_mask = torch.zeros( + (batch_size, max_seq_len, max_seq_len, num_relations), + dtype=torch.bool, device=device, ) + pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) + if not pair_valid_mask.any(): + return relation_mask + + row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) + col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) + valid_row_indices = row_indices[pair_valid_mask] + valid_col_indices = col_indices[pair_valid_mask] + + for relation_idx, adjacency_matrix in enumerate(relation_adjacency_matrices): + relation_values = _lookup_csr_values( + csr_matrix=adjacency_matrix, + row_indices=valid_row_indices, + col_indices=valid_col_indices, + ) + relation_mask[..., relation_idx][pair_valid_mask] = relation_values.ne(0.0) + + return relation_mask def _build_relation_adjacency_matrices( diff --git a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py index 9586a13e7..f2b1c023c 100644 --- a/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py +++ b/tests/unit/src/common/models/graph_transformer/graph_transformer_test.py @@ -827,9 +827,9 @@ def test_relation_aware_layer_matches_baseline_when_matrices_are_zero( x = torch.randn(2, 4, 8) valid_mask = torch.ones((2, 4), dtype=torch.bool) - relation_mask = torch.zeros((2, 4, 4, 2), dtype=torch.float) - relation_mask[0, 1, 0, 0] = 1.0 - relation_mask[1, 2, 1, 1] = 1.0 + relation_mask = torch.zeros((2, 4, 4, 2), dtype=torch.bool) + relation_mask[0, 1, 0, 0] = True + relation_mask[1, 2, 1, 1] = True with torch.no_grad(): assert relation_layer._relation_attention_matrices is not None @@ -859,9 +859,9 @@ def test_relation_aware_attention_only_changes_marked_query_positions( query = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]]) key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) value = torch.tensor([[[[3.0, 0.0], [0.0, 5.0]]]]) - empty_relation_mask = torch.zeros((1, 2, 2, 1), dtype=torch.float) + empty_relation_mask = torch.zeros((1, 2, 2, 1), dtype=torch.bool) active_relation_mask = empty_relation_mask.clone() - active_relation_mask[0, 1, 0, 0] = 1.0 + active_relation_mask[0, 1, 0, 0] = True with torch.no_grad(): assert layer._relation_attention_matrices is not None diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index ac3f12a95..ee7e7ca78 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -547,9 +547,16 @@ def test_relation_mask_outputs_follow_requested_order_and_direction(self): torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) ) self.assertEqual(relation_mask.shape, (1, 4, 4, 2)) - self.assertTrue(torch.equal(relation_mask[0, 1, 0], torch.tensor([1.0, 1.0]))) - self.assertTrue(torch.equal(relation_mask[0, 2, 1], torch.tensor([0.0, 1.0]))) - self.assertTrue(torch.equal(relation_mask[0, 0, 1], torch.zeros(2))) + self.assertEqual(relation_mask.dtype, torch.bool) + self.assertTrue( + torch.equal(relation_mask[0, 1, 0], torch.tensor([True, True])) + ) + self.assertTrue( + torch.equal(relation_mask[0, 2, 1], torch.tensor([False, True])) + ) + self.assertTrue( + torch.equal(relation_mask[0, 0, 1], torch.zeros(2, dtype=torch.bool)) + ) self.assertTrue(torch.all(relation_mask[0, 3] == 0)) From b1e508f0d0cc7f05b3ae15decde199d82f34751d Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 5 May 2026 15:03:46 -0700 Subject: [PATCH 4/6] fix conflict --- gigl/distributed/dist_dataset.py | 16 - gigl/distributed/graph_store/storage_utils.py | 13 +- .../graph_transformer/graph_transformer.py | 1215 ----------------- gigl/utils/data_splitters.py | 22 - .../unit/transforms/graph_transformer_test.py | 4 +- tests/unit/utils/data_splitters_test.py | 9 - 6 files changed, 2 insertions(+), 1277 deletions(-) diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index cc0070fa0..b40f2969a 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -29,10 +29,6 @@ from gigl.utils.data_splitters import ( NodeAnchorLinkSplitter, NodeSplitter, -<<<<<<< HEAD - validate_max_labels_per_anchor_node, -======= ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 ) from gigl.utils.share_memory import share_memory @@ -153,13 +149,7 @@ def __init__( self._degree_tensor: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor -<<<<<<< HEAD - self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node( - max_labels_per_anchor_node - ) -======= self._max_labels_per_anchor_node = max_labels_per_anchor_node ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear # naming (i.e. rank, world_size). @@ -354,13 +344,7 @@ def max_labels_per_anchor_node(self) -> Optional[int]: def max_labels_per_anchor_node( self, new_max_labels_per_anchor_node: Optional[int] ) -> None: -<<<<<<< HEAD - self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node( - new_max_labels_per_anchor_node - ) -======= self._max_labels_per_anchor_node = new_max_labels_per_anchor_node ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 @property def train_node_ids( diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index e2237d801..3158fc0fe 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -35,10 +35,6 @@ from gigl.utils.data_splitters import ( DistNodeAnchorLinkSplitter, DistNodeSplitter, -<<<<<<< HEAD - get_max_labels_per_anchor_node_from_runtime_args, -======= ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 ) logger = Logger() @@ -81,12 +77,7 @@ def build_storage_dataset( ``0.1`` selects 10 % of edges. max_labels_per_anchor_node: Optional cap for how many labels to materialize per anchor node when the storage server serves ABLP -<<<<<<< HEAD - input. If ``None``, this is inferred from the task config's - ``trainer_args``. -======= input. ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 Returns: A partitioned :class:`DistDataset` ready to be served. @@ -103,10 +94,8 @@ def build_storage_dataset( graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, tfrecord_uri_pattern=tf_record_uri_pattern, ) -<<<<<<< HEAD -======= + # TODO: Pipe in max_labels_per_anchor_node to build_dataset. ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, sample_edge_direction=sample_edge_direction, diff --git a/gigl/src/common/models/graph_transformer/graph_transformer.py b/gigl/src/common/models/graph_transformer/graph_transformer.py index b64772865..87d8cbda1 100644 --- a/gigl/src/common/models/graph_transformer/graph_transformer.py +++ b/gigl/src/common/models/graph_transformer/graph_transformer.py @@ -37,1222 +37,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -<<<<<<< HEAD - -def _build_sinusoidal_sequence_position_table( - max_seq_len: int, - hid_dim: int, -) -> Tensor: - """Build a standard sinusoidal absolute position table.""" - positions = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, hid_dim, 2, dtype=torch.float) * (-math.log(10000.0) / hid_dim) - ) - - position_table = torch.zeros(max_seq_len, hid_dim, dtype=torch.float) - position_table[:, 0::2] = torch.sin(positions * div_term) - if hid_dim > 1: - position_table[:, 1::2] = torch.cos( - positions * div_term[: position_table[:, 1::2].shape[1]] - ) - return position_table - - -# Supported activation functions for FeedForwardNetwork -_ACTIVATION_FNS = { - "gelu": nn.GELU, - "relu": nn.ReLU, - "silu": nn.SiLU, # Also known as Swish - "tanh": nn.Tanh, -} - -# XGLU activations use a gating mechanism: activation(xW) * xV -# where W and V are separate linear projections -_XGLU_BASE_ACTIVATIONS = { - "geglu": F.gelu, - "swiglu": F.silu, - "reglu": F.relu, -} - - -class FeedForwardNetwork(nn.Module): - """Two-layer feed-forward network with configurable activation. - - Supports standard activations (GELU, ReLU, SiLU) and XGLU family - (SwiGLU, GeGLU, ReGLU) which use a gating mechanism. - - Note: This module does NOT include LayerNorm. Normalization should be - applied externally (e.g., pre-norm in the transformer layer). - - Adapted from RelGT's FeedForwardNetwork. - - Args: - model_dim: Model (input and output) dimension of the FFN. - feedforward_dim: Inner dimension of the two-layer MLP. - dropout_rate: Dropout probability applied after each linear layer. - activation: Activation function name. Supported values: - - Standard: "gelu" (default), "relu", "silu", "tanh" - - XGLU family: "geglu", "swiglu", "reglu" - XGLU activations use gating: activation(xW) * xV, which requires - projecting to 2x feedforward_dim internally. - """ - - def __init__( - self, - model_dim: int, - feedforward_dim: int, - dropout_rate: float = 0.1, - activation: str = "gelu", - ) -> None: - super().__init__() - self._activation_name = activation.lower() - - # Validate activation - if ( - self._activation_name not in _ACTIVATION_FNS - and self._activation_name not in _XGLU_BASE_ACTIVATIONS - ): - supported = sorted( - set(_ACTIVATION_FNS.keys()) | set(_XGLU_BASE_ACTIVATIONS.keys()) - ) - raise ValueError( - f"Unsupported activation '{activation}'. Supported: {supported}" - ) - - self._is_xglu = self._activation_name in _XGLU_BASE_ACTIVATIONS - - # Type declarations for optional attributes - self._xglu_base_activation: Optional[Callable[..., Tensor]] = None - self._linear_in: Optional[nn.Linear] = None - self._dropout_in: Optional[nn.Dropout] = None - self._linear_out: Optional[nn.Linear] = None - self._dropout_out: Optional[nn.Dropout] = None - self._ffn: Optional[nn.Sequential] = None - - if self._is_xglu: - # XGLU: project to 2x feedforward_dim, split, apply gating - self._xglu_base_activation = cast( - Callable[..., Tensor], _XGLU_BASE_ACTIVATIONS[self._activation_name] - ) - self._linear_in = nn.Linear(model_dim, feedforward_dim * 2) - self._dropout_in = nn.Dropout(dropout_rate) - self._linear_out = nn.Linear(feedforward_dim, model_dim) - self._dropout_out = nn.Dropout(dropout_rate) - else: - # Standard activation - activation_fn = _ACTIVATION_FNS[self._activation_name] - self._ffn = nn.Sequential( - nn.Linear(model_dim, feedforward_dim), - activation_fn(), - nn.Dropout(dropout_rate), - nn.Linear(feedforward_dim, model_dim), - nn.Dropout(dropout_rate), - ) - - def reset_parameters(self) -> None: - """Reinitialize all learnable parameters.""" - if self._is_xglu: - assert self._linear_in is not None - assert self._linear_out is not None - nn.init.xavier_uniform_(self._linear_in.weight) - nn.init.zeros_(self._linear_in.bias) - nn.init.xavier_uniform_(self._linear_out.weight) - nn.init.zeros_(self._linear_out.bias) - else: - # Use xavier + zero bias for consistency with XGLU path and - # GraphTransformerEncoderLayer (standard Transformer practice) - assert self._ffn is not None - for layer in self._ffn: - if isinstance(layer, nn.Linear): - nn.init.xavier_uniform_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x: Input tensor of shape ``(batch, seq, model_dim)``. - - Returns: - Output tensor of shape ``(batch, seq, model_dim)``. - """ - if self._is_xglu: - # XGLU gating: activation(x @ W1) * (x @ W2) - # where W1 and W2 are the two halves of linear_in - assert self._xglu_base_activation is not None - assert self._linear_in is not None - assert self._dropout_in is not None - assert self._linear_out is not None - assert self._dropout_out is not None - x_proj = self._linear_in(x) # (batch, seq, feedforward_dim * 2) - x_gate, x_value = x_proj.chunk( - 2, dim=-1 - ) # Each: (batch, seq, feedforward_dim) - x = self._xglu_base_activation(x_gate) * x_value - x = self._dropout_in(x) - x = self._linear_out(x) - x = self._dropout_out(x) - else: - assert self._ffn is not None - x = self._ffn(x) - - return x - - -class GraphTransformerEncoderLayer(nn.Module): - """Pre-norm transformer encoder layer with multi-head self-attention. - - Uses ``F.scaled_dot_product_attention`` which automatically selects the - most efficient attention implementation (flash, memory-efficient, or - math-based) based on input properties and hardware. - - Adapted from RelGT's EncoderLayer. - - Args: - model_dim: Model dimension (d_model). - num_heads: Number of attention heads. Must evenly divide model_dim. - feedforward_dim: Inner dimension of the feed-forward network. - dropout_rate: Dropout probability for feed-forward layers. - attention_dropout_rate: Dropout probability for attention weights. - activation: Activation function for the feed-forward network. - Supported values: "gelu" (default), "relu", "silu", "tanh", - "geglu", "swiglu", "reglu". - relation_attention_mode: Optional relation-aware augmentation strategy - for attention scores. ``"none"`` preserves the default shared - self-attention path. ``"edge_type_additive"`` adds a learned - per-edge-type bilinear term for token pairs backed by sampled - directed graph edges. - num_relations: Number of relation channels expected in - ``pairwise_relation_mask`` when - ``relation_attention_mode="edge_type_additive"``. - - Raises: - ValueError: If model_dim is not divisible by num_heads. - """ - - def __init__( - self, - model_dim: int, - num_heads: int, - feedforward_dim: int, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - activation: str = "gelu", - relation_attention_mode: Literal["none", "edge_type_additive"] = "none", - num_relations: int = 0, - ) -> None: - super().__init__() - if model_dim % num_heads != 0: - raise ValueError( - f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" - ) - if relation_attention_mode not in {"none", "edge_type_additive"}: - raise ValueError( - "relation_attention_mode must be one of " - "{'none', 'edge_type_additive'}, " - f"got '{relation_attention_mode}'" - ) - if relation_attention_mode == "edge_type_additive" and num_relations <= 0: - raise ValueError( - "relation_attention_mode='edge_type_additive' requires " - "num_relations > 0." - ) - - self._num_heads = num_heads - self._head_dim = model_dim // num_heads - self._attention_dropout_rate = attention_dropout_rate - self._relation_attention_mode = relation_attention_mode - self._num_relations = num_relations - - self._attention_norm = nn.LayerNorm(model_dim) - self._query_projection = nn.Linear(model_dim, model_dim) - self._key_projection = nn.Linear(model_dim, model_dim) - self._value_projection = nn.Linear(model_dim, model_dim) - self._output_projection = nn.Linear(model_dim, model_dim) - self._dropout = nn.Dropout(dropout_rate) - self._relation_attention_matrices: Optional[nn.Parameter] = None - if relation_attention_mode == "edge_type_additive": - self._relation_attention_matrices = nn.Parameter( - torch.empty(num_relations, num_heads, self._head_dim, self._head_dim) - ) - - self._ffn_norm = nn.LayerNorm(model_dim) - self._ffn = FeedForwardNetwork( - model_dim, feedforward_dim, dropout_rate, activation=activation - ) - - def reset_parameters(self) -> None: - """Reinitialize all learnable parameters.""" - self._attention_norm.reset_parameters() - for projection in [ - self._query_projection, - self._key_projection, - self._value_projection, - self._output_projection, - ]: - nn.init.xavier_uniform_(projection.weight) - if projection.bias is not None: - nn.init.zeros_(projection.bias) - if self._relation_attention_matrices is not None: - for relation_matrices in self._relation_attention_matrices: - for head_matrix in relation_matrices: - nn.init.xavier_uniform_(head_matrix) - self._ffn_norm.reset_parameters() - self._ffn.reset_parameters() - - def forward( - self, - x: Tensor, - attn_bias: Optional[Tensor] = None, - pairwise_relation_mask: Optional[Tensor] = None, - valid_mask: Optional[Tensor] = None, - ) -> Tensor: - """Forward pass. - - Args: - x: Input tensor of shape ``(batch, seq, model_dim)``. - attn_bias: Optional attention bias of shape - ``(batch, num_heads, seq, seq)`` or broadcastable. - Added as an additive mask to attention scores. - pairwise_relation_mask: Optional boolean multi-hot relation mask of shape - ``(batch, seq, seq, num_relations)`` that marks which sampled - directed edge types connect each token pair as ``key -> query``. - valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used - to zero out padded token states after each residual block. - - Returns: - Output tensor of shape ``(batch, seq, model_dim)``. - """ - batch_size, seq_len, model_dim = x.shape - - # Self-attention block (pre-norm) - residual = x - x_norm = self._attention_norm(x) - - query = self._query_projection(x_norm) - key = self._key_projection(x_norm) - value = self._value_projection(x_norm) - - # Reshape to (batch, num_heads, seq, head_dim) - query = query.view( - batch_size, seq_len, self._num_heads, self._head_dim - ).transpose(1, 2) - key = key.view(batch_size, seq_len, self._num_heads, self._head_dim).transpose( - 1, 2 - ) - value = value.view( - batch_size, seq_len, self._num_heads, self._head_dim - ).transpose(1, 2) - - if self._relation_attention_mode == "none": - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, - ) - else: - attention_output = self._run_relation_aware_attention( - query=query, - key=key, - value=value, - attn_bias=attn_bias, - pairwise_relation_mask=pairwise_relation_mask, - ) - - # Reshape back to (batch, seq, model_dim) - attention_output = attention_output.transpose(1, 2).reshape( - batch_size, seq_len, model_dim - ) - attention_output = self._output_projection(attention_output) - attention_output = self._dropout(attention_output) - - x = residual + attention_output - if valid_mask is not None: - x = x * valid_mask.unsqueeze(-1).to(x.dtype) - - # Feed-forward block (pre-norm) - residual = x - x_norm = self._ffn_norm(x) - ffn_output = self._ffn(x_norm) - x = residual + ffn_output - if valid_mask is not None: - x = x * valid_mask.unsqueeze(-1).to(x.dtype) - - return x - - def _run_relation_aware_attention( - self, - query: Tensor, - key: Tensor, - value: Tensor, - attn_bias: Optional[Tensor], - pairwise_relation_mask: Optional[Tensor], - ) -> Tensor: - relation_attention_bias = self._build_relation_attention_bias( - query=query, - key=key, - pairwise_relation_mask=pairwise_relation_mask, - ) - if relation_attention_bias is not None: - attn_bias = ( - relation_attention_bias - if attn_bias is None - else attn_bias + relation_attention_bias - ) - - return F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, - ) - - def _build_relation_attention_bias( - self, - query: Tensor, - key: Tensor, - pairwise_relation_mask: Optional[Tensor], - ) -> Optional[Tensor]: - if pairwise_relation_mask is None: - raise ValueError( - "pairwise_relation_mask is required when " - "relation_attention_mode='edge_type_additive'." - ) - if pairwise_relation_mask.size(-1) != self._num_relations: - raise ValueError( - "pairwise_relation_mask has unexpected relation dimension " - f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}." - ) - if self._relation_attention_matrices is None: - raise ValueError("Relation attention matrices are not initialized.") - if pairwise_relation_mask.size(1) != query.size(2) or pairwise_relation_mask.size( - 2 - ) != key.size(2): - raise ValueError( - "pairwise_relation_mask must align with the query/key sequence " - "dimensions." - ) - - relation_mask = pairwise_relation_mask.to( - device=query.device, - dtype=torch.bool, - ) - active_relation_positions = relation_mask.nonzero(as_tuple=False) - if active_relation_positions.numel() == 0: - return None - - relation_attention_bias = query.new_zeros( - (query.size(0), query.size(2), key.size(2), self._num_heads) - ) - query_by_position = query.transpose(1, 2) - key_by_position = key.transpose(1, 2) - relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) - active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) - - for relation_idx_tensor in active_relation_ids: - relation_idx = int(relation_idx_tensor.item()) - relation_positions = active_relation_positions[ - active_relation_positions[:, 3] == relation_idx - ] - batch_indices, query_indices, key_indices = relation_positions[ - :, :3 - ].unbind(dim=1) - # Only materialize bilinear scores for token pairs backed by this relation. - selected_query = query_by_position[batch_indices, query_indices] - transformed_query = torch.einsum( - "nhe,hde->nhd", - selected_query, - relation_matrices[relation_idx], - ) - selected_key = key_by_position[batch_indices, key_indices] - relation_scores = (selected_key * transformed_query).sum(dim=-1) - relation_attention_bias.index_put_( - (batch_indices, query_indices, key_indices), - relation_scores / math.sqrt(self._head_dim), - accumulate=True, - ) - - return relation_attention_bias.permute(0, 3, 1, 2) - - -class GraphTransformerEncoder(nn.Module): - """Graph Transformer encoder for heterogeneous graphs. - - Converts heterogeneous graph data into fixed-length sequences via - ``heterodata_to_graph_transformer_input``, processes through pre-norm - transformer encoder layers, and produces per-node embeddings via - attention-weighted neighbor readout (from RelGT's LocalModule). - - Conforms to the same forward interface as ``HGT`` and ``SimpleHGN``, - making it a drop-in encoder for ``LinkPredictionGNN``. - - Args: - node_type_to_feat_dim_map: Dictionary mapping node types to their - input feature dimensions. - edge_type_to_feat_dim_map: Dictionary mapping edge types to their - feature dimensions. Accepted for interface conformance with - ``HGT``/``SimpleHGN``; edge features are not used by the - graph transformer. - hid_dim: Hidden dimension for transformer layers. All node types - are projected to this dimension before processing. - out_dim: Output embedding dimension. - num_layers: Number of transformer encoder layers. - num_heads: Number of attention heads per layer. Must evenly divide - ``hid_dim``. - max_seq_len: Maximum sequence length for the graph-to-sequence - transform. Neighborhoods are truncated to this length. - hop_distance: Number of hops for neighborhood extraction in the - graph-to-sequence transform when using ``"khop"`` sequence construction. - sequence_construction_method: Sequence builder used to create tokens for - each anchor. ``"khop"`` expands the sampled graph by hop distance, - while ``"ppr"`` consumes outgoing ``"ppr"`` edges sorted by weight. - sequence_positional_encoding_type: Optional sequence-level positional - encoding applied after sequence construction. Supported values are - ``None`` and ``"sinusoidal"``. Lower-cost future extensions could - add learned absolute position embeddings here, while attention-level - options like RoPE or ALiBi would require changes inside the - attention block. - dropout_rate: Dropout probability for feed-forward layers. - attention_dropout_rate: Dropout probability for attention weights. - should_l2_normalize_embedding_layer_output: Whether to L2 normalize - output embeddings. - pe_attr_names: List of node-level positional encoding attribute names. - In ``"concat"`` mode these are concatenated to sequence features. - In ``"add"`` mode they are projected to ``hid_dim`` and added to - node features before sequence construction. - anchor_based_attention_bias_attr_names: List of anchor-relative feature - names used as additive attention bias for sequence keys. Sparse - graph-level attributes are looked up from ``data`` and the reserved - name ``"ppr_weight"`` resolves to PPR edge weights in PPR mode. - Example: ``['hop_distance', 'ppr_weight']`` where ``hop_distance`` - is a sparse matrix attribute on ``data`` and ``ppr_weight`` is - extracted from PPR edge weights. - anchor_based_input_attr_names: List of anchor-relative attribute names - used as token-aligned input features. Sparse graph-level attributes - are looked up from ``data`` and ``"ppr_weight"`` resolves to PPR - edge weights in PPR mode. These are projected to ``hid_dim`` and - added to the sequence tokens after sequence construction. - Example: ``['hop_distance', 'ppr_weight']`` for continuous features, - or ``['hop_distance']`` when ``hop_distance`` will be embedded via - ``anchor_based_input_embedding_dict``. - anchor_based_input_embedding_dict: Optional ModuleDict mapping a subset - of ``anchor_based_input_attr_names`` to per-attribute embedding - layers. These attributes are treated as discrete indices and their - embedded contributions are added to the sequence tokens. Padding is - masked out using the sequence valid mask. - Example: ``nn.ModuleDict({'hop_distance': nn.Embedding(10, hid_dim)})`` - to embed hop distances 0-9 into ``hid_dim``-dimensional vectors. - The embedding output dimension must match ``hid_dim``. - pairwise_attention_bias_attr_names: List of pairwise feature names used - as additive attention bias. These must correspond to sparse - graph-level attributes on ``data``. - feature_embedding_layer_dict: Optional ModuleDict mapping node types to - feature embedding layers. If provided, these are applied to node - features before node projection. (default: None) - pe_integration_mode: How to fuse positional encodings into the model - input. ``"concat"`` preserves the current behavior by concatenating - node-level PE to token features. ``"add"`` uses node-level additive - PE before sequence construction and attention bias for relative - encodings. - activation: Activation function for the feed-forward network in each - transformer layer. Supported values: - - Standard: "gelu" (default), "relu", "silu", "tanh" - - XGLU family: "geglu", "swiglu", "reglu" - XGLU activations use gating: activation(xW) * xV. - feedforward_ratio: Ratio of feedforward dimension to hidden dimension - (feedforward_dim = hid_dim * feedforward_ratio). If None (default), - uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, - following the convention that XGLU's gating doubles the effective - parameters, so a smaller ratio maintains similar parameter count. - relation_attention_mode: Optional relation-aware augmentation for - attention scores. ``"none"`` preserves the current dense transformer - path. ``"edge_type_additive"`` adds a learned per-edge-type - bilinear score term for sampled directed edges in ``"khop"`` mode. - - Notes: - This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap - it with ``DistributedDataParallel``, run one representative no-grad - forward first, passing ``anchor_node_ids``/``anchor_node_type`` for the - graph-transformer path, or load a checkpoint before DDP so all ranks see - initialized weights. - - TODO: Pairwise relative bias is currently materialized densely for the selected - sequence. That is fine for moderate ``max_seq_len``, but a chunked or - sparse LPFormer-style path is still future work for larger sequences. - - Example: - >>> from gigl.src.common.models.graph_transformer.graph_transformer import ( - ... GraphTransformerEncoder, - ... ) - >>> encoder = GraphTransformerEncoder( - ... node_type_to_feat_dim_map={NodeType("user"): 64, NodeType("item"): 32}, - ... edge_type_to_feat_dim_map={}, - ... hid_dim=128, - ... out_dim=64, - ... num_layers=2, - ... num_heads=4, - ... ) - >>> embeddings = encoder(data, anchor_node_type=NodeType("user"), device=device) - """ - - def __init__( - self, - node_type_to_feat_dim_map: dict[NodeType, int], - edge_type_to_feat_dim_map: dict[EdgeType, int], - hid_dim: int, - out_dim: int = 128, - num_layers: int = 2, - num_heads: int = 2, - max_seq_len: int = 128, - hop_distance: int = 2, - sequence_construction_method: Literal["khop", "ppr"] = "khop", - sequence_positional_encoding_type: Optional[str] = None, - dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.0, - should_l2_normalize_embedding_layer_output: bool = False, - pe_attr_names: Optional[list[str]] = None, - anchor_based_attention_bias_attr_names: Optional[list[str]] = None, - anchor_based_input_attr_names: Optional[list[str]] = None, - anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None, - pairwise_attention_bias_attr_names: Optional[list[str]] = None, - feature_embedding_layer_dict: Optional[nn.ModuleDict] = None, - pe_integration_mode: Literal["concat", "add"] = "concat", - activation: str = "gelu", - feedforward_ratio: Optional[float] = None, - relation_attention_mode: Literal["none", "edge_type_additive"] = "none", - **kwargs: object, - ) -> None: - super().__init__() - del kwargs - - if pe_integration_mode not in {"concat", "add"}: - raise ValueError( - "pe_integration_mode must be one of {'concat', 'add'}, " - f"got '{pe_integration_mode}'" - ) - - self._hid_dim = hid_dim - self._out_dim = out_dim - self._max_seq_len = max_seq_len - self._hop_distance = hop_distance - if sequence_construction_method not in {"khop", "ppr"}: - raise ValueError( - "sequence_construction_method must be one of {'khop', 'ppr'}, " - f"got '{sequence_construction_method}'" - ) - if sequence_positional_encoding_type is not None: - sequence_positional_encoding_type = ( - sequence_positional_encoding_type.lower() - ) - if sequence_positional_encoding_type == "none": - sequence_positional_encoding_type = None - if sequence_positional_encoding_type not in {None, "sinusoidal"}: - raise ValueError( - "sequence_positional_encoding_type must be one of " - "{None, 'sinusoidal'}, " - f"got '{sequence_positional_encoding_type}'" - ) - if ( - sequence_construction_method == "khop" - and sequence_positional_encoding_type is not None - ): - raise ValueError( - "sequence_positional_encoding_type requires " - "sequence_construction_method='ppr' because khop sequences do not " - "enforce a stable token order." - ) - if relation_attention_mode not in {"none", "edge_type_additive"}: - raise ValueError( - "relation_attention_mode must be one of " - "{'none', 'edge_type_additive'}, " - f"got '{relation_attention_mode}'" - ) - if ( - relation_attention_mode == "edge_type_additive" - and sequence_construction_method != "khop" - ): - raise ValueError( - "relation_attention_mode='edge_type_additive' requires " - "sequence_construction_method='khop'." - ) - anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] - anchor_input_attr_names = anchor_based_input_attr_names or [] - pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] - if PPR_WEIGHT_FEATURE_NAME in pairwise_bias_attr_names: - raise ValueError( - f"'{PPR_WEIGHT_FEATURE_NAME}' is an anchor-relative feature and " - "cannot be used as pairwise attention bias." - ) - if ( - PPR_WEIGHT_FEATURE_NAME in anchor_bias_attr_names + anchor_input_attr_names - and sequence_construction_method != "ppr" - ): - raise ValueError( - "The reserved anchor-relative feature 'ppr_weight' requires " - "sequence_construction_method='ppr'." - ) - self._sequence_construction_method = sequence_construction_method - self._sequence_positional_encoding_type = sequence_positional_encoding_type - self._should_l2_normalize_embedding_layer_output = ( - should_l2_normalize_embedding_layer_output - ) - self._pe_attr_names = pe_attr_names - self._anchor_based_attention_bias_attr_names = ( - anchor_based_attention_bias_attr_names - ) - self._anchor_based_input_attr_names = anchor_based_input_attr_names - self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict - self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names - self._feature_embedding_layer_dict = feature_embedding_layer_dict - self._pe_integration_mode = pe_integration_mode - self._num_heads = num_heads - self._relation_attention_mode = relation_attention_mode - self._relation_attention_edge_types = ( - sorted(edge_type_to_feat_dim_map.keys()) - if relation_attention_mode == "edge_type_additive" - else [] - ) - anchor_input_embedding_attr_names = ( - set(anchor_based_input_embedding_dict.keys()) - if anchor_based_input_embedding_dict is not None - else set() - ) - invalid_anchor_input_embedding_attr_names = ( - anchor_input_embedding_attr_names - set(anchor_input_attr_names) - ) - if invalid_anchor_input_embedding_attr_names: - raise ValueError( - "anchor_based_input_embedding_dict keys must be a subset of " - "anchor_based_input_attr_names, got unexpected keys " - f"{sorted(invalid_anchor_input_embedding_attr_names)}." - ) - self._continuous_anchor_input_attr_names = [ - attr_name - for attr_name in anchor_input_attr_names - if attr_name not in anchor_input_embedding_attr_names - ] - if self._sequence_positional_encoding_type == "sinusoidal": - self.register_buffer( - "_sequence_positional_encoding_table", - _build_sinusoidal_sequence_position_table( - max_seq_len=max_seq_len, - hid_dim=hid_dim, - ), - persistent=False, - ) - else: - self.register_buffer( - "_sequence_positional_encoding_table", - None, - persistent=False, - ) - - # Per-node-type input projection to hid_dim (like HGT's lin_dict) - self._node_projection_dict = nn.ModuleDict( - { - str(node_type): nn.Linear(feat_dim, hid_dim) - for node_type, feat_dim in node_type_to_feat_dim_map.items() - } - ) - - # PE fusion layers for node-level positional encodings. - # In "concat" mode: projects [node_features || PE] → hid_dim - # In "add" mode: projects PE → hid_dim, then adds to node features - self._concat_pe_fusion_projection: Optional[nn.Module] = None - has_node_level_pe = bool(pe_attr_names) - if pe_integration_mode == "concat" and has_node_level_pe: - self._concat_pe_fusion_projection = nn.LazyLinear(hid_dim) - - self._pe_projection: Optional[nn.Module] = None - if pe_integration_mode == "add" and has_node_level_pe: - self._pe_projection = nn.LazyLinear(hid_dim, bias=False) - - self._token_input_projection: Optional[nn.Module] = None - if self._continuous_anchor_input_attr_names: - self._token_input_projection = nn.LazyLinear(hid_dim, bias=False) - - self._anchor_pe_attention_bias_projection: Optional[nn.Linear] = None - num_anchor_bias_attrs = len(self._anchor_based_attention_bias_attr_names or []) - if num_anchor_bias_attrs > 0: - self._anchor_pe_attention_bias_projection = nn.Linear( - num_anchor_bias_attrs, - num_heads, - bias=False, - ) - - self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None - if self._pairwise_attention_bias_attr_names: - self._pairwise_pe_attention_bias_projection = nn.Linear( - len(self._pairwise_attention_bias_attr_names), - num_heads, - bias=False, - ) - - # Transformer encoder layers - # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU - # XGLU's gating mechanism doubles effective parameters, so smaller ratio - # maintains similar parameter count to standard activations with ratio 4. - is_xglu = activation.lower() in _XGLU_BASE_ACTIVATIONS - if feedforward_ratio is None: - feedforward_ratio = 8.0 / 3.0 if is_xglu else 4.0 - feedforward_dim = int(hid_dim * feedforward_ratio) - self._encoder_layers = nn.ModuleList( - [ - GraphTransformerEncoderLayer( - model_dim=hid_dim, - num_heads=num_heads, - feedforward_dim=feedforward_dim, - dropout_rate=dropout_rate, - attention_dropout_rate=attention_dropout_rate, - activation=activation, - relation_attention_mode=relation_attention_mode, - num_relations=len(self._relation_attention_edge_types), - ) - for _ in range(num_layers) - ] - ) - - self._final_norm = nn.LayerNorm(hid_dim) - - # Readout attention: projects concatenated (anchor, neighbor) to score - self._readout_attention = nn.Linear(2 * hid_dim, 1) - - # Output projection: hid_dim -> out_dim - self._output_projection = nn.Linear(hid_dim, out_dim) - - def forward( - self, - data: torch_geometric.data.hetero_data.HeteroData, - anchor_node_type: Optional[NodeType] = None, - anchor_node_ids: Optional[Tensor] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - """Run the forward pass of the Graph Transformer encoder. - - Args: - data: Input HeteroData object with node features (``x_dict``) - and edge indices (``edge_index_dict``). - anchor_node_type: Node type for which to compute embeddings. - If None, uses the first node type in data. - anchor_node_ids: Optional tensor of local node indices within - anchor_node_type to use as anchors. If None, uses the first - batch_size nodes (seed nodes from neighbor sampling). - device: Torch device for output tensors. If None, inferred from data. - - Returns: - Embeddings tensor of shape ``(num_anchor_nodes, out_dim)``. - """ - # Infer device from data if not provided - if device is None: - device = next(iter(data.x_dict.values())).device - - # Use first node type if not specified - if anchor_node_type is None: - anchor_node_type = list(data.node_types)[0] - - # 0. Apply feature embedding if provided (without modifying original data) - # 1. Project all node features to hid_dim - # Build a new x_dict with processed features to avoid in-place modifications - projected_x_dict: dict[NodeType, torch.Tensor] = {} - for node_type, x in data.x_dict.items(): - x_processed = x.to(device) - feature_embedding_layer = None - if ( - self._feature_embedding_layer_dict is not None - and node_type in self._feature_embedding_layer_dict - ): - feature_embedding_layer = self._feature_embedding_layer_dict[node_type] - # Apply feature embedding if available for this node type - if feature_embedding_layer is not None: - x_processed = feature_embedding_layer(x_processed) - # Project to hid_dim - x_projected = self._node_projection_dict[str(node_type)](x_processed) - node_pe_parts = [] - if self._pe_attr_names: - node_pe_parts.append( - _get_node_type_positional_encodings( - data=data, - node_type=node_type, - pe_attr_names=self._pe_attr_names, - device=device, - ) - ) - if node_pe_parts: - node_pe = torch.cat(node_pe_parts, dim=-1) - if self._pe_integration_mode == "add": - if self._pe_projection is None: - raise ValueError("PE projection layer is not initialized.") - x_projected = x_projected + self._pe_projection(node_pe) - else: - if self._concat_pe_fusion_projection is None: - raise ValueError( - "Concat PE fusion projection layer is not initialized." - ) - x_projected = self._concat_pe_fusion_projection( - torch.cat([x_projected, node_pe], dim=-1) - ) - projected_x_dict[node_type] = x_projected - - # Create a new HeteroData with projected features (avoiding in-place modification) - projected_data = torch_geometric.data.HeteroData() - for node_type in data.node_types: - projected_data[node_type].x = projected_x_dict[node_type] - # Copy batch_size if it exists - if hasattr(data[node_type], "batch_size"): - projected_data[node_type].batch_size = data[node_type].batch_size - for edge_type in data.edge_types: - projected_data[edge_type].edge_index = data[edge_type].edge_index - if hasattr(data[edge_type], "edge_attr"): - projected_data[edge_type].edge_attr = data[edge_type].edge_attr - # Copy relative-encoding attributes (e.g., hop_distance stored as sparse matrix) - relative_pe_attr_names = { - attr_name - for attr_name in (self._anchor_based_attention_bias_attr_names or []) - if attr_name != PPR_WEIGHT_FEATURE_NAME - } - relative_pe_attr_names.update(self._anchor_based_input_attr_names or []) - relative_pe_attr_names.update(self._pairwise_attention_bias_attr_names or []) - relative_pe_attr_names.discard(PPR_WEIGHT_FEATURE_NAME) - if relative_pe_attr_names: - for attr_name in sorted(relative_pe_attr_names): - if hasattr(data, attr_name): - setattr(projected_data, attr_name, getattr(data, attr_name)) - - # 2. Build sequences and run transformer - # If anchor_node_ids provided, use those; otherwise use first batch_size nodes - if anchor_node_ids is not None: - num_anchor_nodes = anchor_node_ids.size(0) - else: - num_anchor_nodes = getattr( - projected_data[anchor_node_type], - "batch_size", - projected_data[anchor_node_type].num_nodes, - ) - - ( - sequences, - valid_mask, - sequence_auxiliary_data, - ) = heterodata_to_graph_transformer_input( - data=projected_data, - batch_size=num_anchor_nodes, - max_seq_len=self._max_seq_len, - anchor_node_type=anchor_node_type, - anchor_node_ids=anchor_node_ids, - hop_distance=self._hop_distance, - sequence_construction_method=self._sequence_construction_method, - anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, - anchor_based_input_attr_names=self._anchor_based_input_attr_names, - pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, - relation_edge_types=self._relation_attention_edge_types, - ) - - # Free memory after sequences are built - del projected_data - - if sequences.size(-1) != self._hid_dim: - raise ValueError( - f"Expected sequence dim {self._hid_dim} after node projection, " - f"got {sequences.size(-1)}." - ) - - token_input_features = sequence_auxiliary_data["token_input"] - if token_input_features is not None: - sequences = sequences + self._build_token_input_contribution( - token_input_features=token_input_features, - sequences=sequences, - valid_mask=valid_mask, - ) - - sequence_positional_encoding = self._get_sequence_positional_encoding( - valid_mask=valid_mask, - sequences=sequences, - ) - if sequence_positional_encoding is not None: - sequences = sequences + sequence_positional_encoding - - attn_bias = self._build_attention_bias( - valid_mask=valid_mask, - sequences=sequences, - attention_bias_data=sequence_auxiliary_data, - ) - - embeddings = self._encode_and_readout( - sequences=sequences, - valid_mask=valid_mask, - attn_bias=attn_bias, - pairwise_relation_mask=sequence_auxiliary_data.get( - "pairwise_relation_mask" - ), - ) - embeddings = self._output_projection(embeddings) - - if self._should_l2_normalize_embedding_layer_output: - embeddings = F.normalize(embeddings, p=2, dim=-1) - - return embeddings - - def _get_sequence_positional_encoding( - self, - valid_mask: Tensor, - sequences: Tensor, - ) -> Optional[Tensor]: - if self._sequence_positional_encoding_type is None: - return None - if self._sequence_positional_encoding_type != "sinusoidal": - raise ValueError( - "Unsupported sequence_positional_encoding_type " - f"'{self._sequence_positional_encoding_type}'." - ) - if self._sequence_positional_encoding_table is None: - raise ValueError("Sequence positional encoding table is not initialized.") - position_table = cast(Tensor, self._sequence_positional_encoding_table) - - seq_len = sequences.size(1) - if seq_len > position_table.size(0): - raise ValueError( - f"Sequence length {seq_len} exceeds configured max_seq_len " - f"{position_table.size(0)}." - ) - - position_encoding = position_table[:seq_len] - position_encoding = position_encoding.to( - device=sequences.device, - dtype=sequences.dtype, - ) - position_encoding = position_encoding.unsqueeze(0).expand( - sequences.size(0), -1, -1 - ) - return position_encoding * valid_mask.unsqueeze(-1).to(sequences.dtype) - - def _build_token_input_contribution( - self, - token_input_features: TokenInputData, - sequences: Tensor, - valid_mask: Tensor, - ) -> Tensor: - token_contribution = torch.zeros_like(sequences) - valid_token_mask = valid_mask.unsqueeze(-1).to(sequences.dtype) - - if self._anchor_based_input_embedding_dict is not None: - for ( - attr_name, - embedding_layer, - ) in self._anchor_based_input_embedding_dict.items(): - if attr_name not in token_input_features: - raise ValueError( - f"Token-input feature '{attr_name}' is missing from the " - "sequence auxiliary data." - ) - indices = token_input_features[attr_name] - if indices.size(-1) != 1: - raise ValueError( - f"Embedded token-input feature '{attr_name}' must have " - f"shape (batch, seq, 1), got {indices.shape}." - ) - embedded_attr = embedding_layer(indices.squeeze(-1).long()) - if embedded_attr.shape != sequences.shape: - raise ValueError( - f"Embedded token-input feature '{attr_name}' must produce " - f"shape {sequences.shape}, got {embedded_attr.shape}." - ) - token_contribution = token_contribution + ( - embedded_attr.to(sequences.dtype) * valid_token_mask - ) - - if self._continuous_anchor_input_attr_names: - if self._token_input_projection is None: - raise ValueError("Token-input projection is not initialized.") - continuous_feature_parts: list[Tensor] = [] - for attr_name in self._continuous_anchor_input_attr_names: - if attr_name not in token_input_features: - raise ValueError( - f"Token-input feature '{attr_name}' is missing from the " - "sequence auxiliary data." - ) - continuous_feature_parts.append(token_input_features[attr_name]) - token_contribution = token_contribution + ( - self._token_input_projection( - torch.cat(continuous_feature_parts, dim=-1).to(sequences.dtype) - ) - * valid_token_mask - ) - - return token_contribution - - def _build_attention_bias( - self, - valid_mask: Tensor, - sequences: Tensor, - attention_bias_data: SequenceAuxiliaryData, - ) -> Tensor: - """Build additive attention bias from padding mask and learned relative PE projections. - - This function constructs a combined attention bias tensor that is added to - attention scores before softmax. The bias has three components: - - 1. **Padding mask bias**: Sets padded positions to -inf so they receive zero - attention weight after softmax. Shape: (batch, 1, 1, seq) broadcasts to - (batch, num_heads, seq, seq) for key masking. - - 2. **Anchor-relative bias** (optional): For each sequence position, looks up - the PE value relative to the anchor (e.g., hop distance from anchor). - Input shape: (batch, seq, num_anchor_attrs) - After projection: (batch, num_heads, 1, seq) - same bias for all query positions. - - 3. **Pairwise bias** (optional): For each (query, key) pair, looks up the PE - value between those two nodes (e.g., random walk structural encoding). - Input shape: (batch, seq, seq, num_pairwise_attrs) - After projection: (batch, num_heads, seq, seq) - unique bias per query-key pair. - - Args: - valid_mask: Boolean mask of shape (batch_size, seq_len) indicating - valid (non-padding) positions. - sequences: Input sequences of shape (batch_size, seq_len, hid_dim), - used only to infer dtype and device. - attention_bias_data: Dictionary containing optional PE tensors: - - "anchor_bias": (batch, seq, num_anchor_attrs) or None - - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None - - Returns: - Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) - or broadcastable shape. Added to attention scores before softmax. - - Example: - # With batch_size=2, seq_len=4, num_heads=8 - # valid_mask = [[T, T, T, F], [T, T, F, F]] - # - # Output attn_bias shape: (2, 8, 4, 4) - # - Positions where valid_mask is False get -inf - # - Anchor bias adds per-key bias (same for all queries) - # - Pairwise bias adds unique bias for each (query, key) pair - """ - batch_size, seq_len = valid_mask.shape - dtype = sequences.dtype - device = sequences.device - negative_inf = torch.finfo(dtype).min - - # Step 1: Initialize with padding mask bias - # Shape: (batch, 1, 1, seq) - broadcasts to mask invalid keys for all queries/heads - attn_bias = torch.zeros( - (batch_size, 1, 1, seq_len), - dtype=dtype, - device=device, - ) - attn_bias = attn_bias.masked_fill( - ~valid_mask.unsqueeze(1).unsqueeze(2), # (batch, 1, 1, seq) - negative_inf, - ) - - # Step 2: Add anchor-relative bias (optional) - # Projects (batch, seq, num_attrs) → (batch, seq, num_heads) - # Then reshapes to (batch, num_heads, 1, seq) for key-side bias - anchor_bias_features = attention_bias_data.get("anchor_bias") - if anchor_bias_features is not None: - if self._anchor_pe_attention_bias_projection is None: - raise ValueError("Anchor attention-bias projection is not initialized.") - anchor_bias = self._anchor_pe_attention_bias_projection( - anchor_bias_features.to(dtype) - ) # (batch, seq, num_heads) - anchor_bias = anchor_bias.permute(0, 2, 1).unsqueeze( - 2 - ) # (batch, num_heads, 1, seq) - attn_bias = attn_bias + anchor_bias - - # Step 3: Add pairwise bias (optional) - # Projects (batch, seq, seq, num_attrs) → (batch, seq, seq, num_heads) - # Then reshapes to (batch, num_heads, seq, seq) - pairwise_bias_features = attention_bias_data.get("pairwise_bias") - if pairwise_bias_features is not None: - if self._pairwise_pe_attention_bias_projection is None: - raise ValueError( - "Pairwise attention-bias projection is not initialized." - ) - pairwise_bias = self._pairwise_pe_attention_bias_projection( - pairwise_bias_features.to(dtype) - ) # (batch, seq, seq, num_heads) - pairwise_bias = pairwise_bias.permute( - 0, 3, 1, 2 - ) # (batch, num_heads, seq, seq) - attn_bias = attn_bias + pairwise_bias - - return attn_bias - - def _encode_and_readout( - self, - sequences: Tensor, - valid_mask: Tensor, - attn_bias: Optional[Tensor] = None, - pairwise_relation_mask: Optional[Tensor] = None, - ) -> Tensor: - """Process sequences through transformer layers and attention readout. - - Args: - sequences: Input tensor of shape ``(batch_size, max_seq_len, hid_dim)``. - valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. - attn_bias: Optional additive attention bias broadcastable to - ``(batch_size, num_heads, seq, seq)``. - pairwise_relation_mask: Optional boolean relation mask shaped - ``(batch_size, seq, seq, num_relations)``. - - Returns: - Output embeddings of shape ``(batch_size, hid_dim)``. - """ - x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) - - for encoder_layer in self._encoder_layers: - x = encoder_layer( - x, - attn_bias=attn_bias, - pairwise_relation_mask=pairwise_relation_mask, - valid_mask=valid_mask, - ) - - x = self._final_norm(x) - x = x * valid_mask.unsqueeze(-1).to(x.dtype) - - # Readout: anchor (position 0) + attention-weighted neighbor aggregation - anchor = x[:, 0, :].unsqueeze(1) # (batch, 1, hid_dim) - neighbors = x[:, 1:, :] # (batch, seq-1, hid_dim) - neighbor_valid_mask = valid_mask[:, 1:] - seq_minus_one = neighbors.size(1) - - if seq_minus_one == 0: - return anchor.squeeze(1) - - # Expand anchor to match neighbor dimension for concatenation - anchor_expanded = anchor.expand(-1, seq_minus_one, -1) - - # Compute attention scores over neighbors - readout_scores = self._readout_attention( - torch.cat([anchor_expanded, neighbors], dim=-1) - ) # (batch, seq-1, 1) - readout_scores = readout_scores.masked_fill( - ~neighbor_valid_mask.unsqueeze(-1), - torch.finfo(readout_scores.dtype).min, - ) - readout_weights = F.softmax(readout_scores, dim=1) # (batch, seq-1, 1) - readout_weights = torch.nan_to_num(readout_weights, nan=0.0) - readout_weights = readout_weights * neighbor_valid_mask.unsqueeze(-1).to( - readout_weights.dtype - ) - - neighbor_aggregation = (neighbors * readout_weights).sum( - dim=1, keepdim=True - ) # (batch, 1, hid_dim) - - output = (anchor + neighbor_aggregation).squeeze(1) # (batch, hid_dim) - - return output -======= class GraphTransformerEncoder(_GraphTransformerEncoder): def __init__(self, *args: Any, **kwargs: Any) -> None: logger.warning(_DEPRECATION_MSG) super().__init__(*args, **kwargs) ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 3a0a57f94..9c929a9da 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -39,17 +39,6 @@ def validate_max_labels_per_anchor_node( max_labels_per_anchor_node: Optional[int], -<<<<<<< HEAD -) -> Optional[int]: - """Validate the optional per-anchor label cap.""" - if max_labels_per_anchor_node is None: - return None - if max_labels_per_anchor_node <= 0: - raise ValueError( - "max_labels_per_anchor_node must be a positive integer when provided." - ) - return max_labels_per_anchor_node -======= ) -> None: """Validate the optional per-anchor label cap. @@ -63,7 +52,6 @@ def validate_max_labels_per_anchor_node( raise ValueError( "max_labels_per_anchor_node must be a positive integer when provided." ) ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 def get_max_labels_per_anchor_node_from_runtime_args( @@ -82,16 +70,12 @@ def get_max_labels_per_anchor_node_from_runtime_args( f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value " f"{raw_max_labels_per_anchor_node!r}. Expected a positive integer." ) from exc -<<<<<<< HEAD - return validate_max_labels_per_anchor_node(parsed_max_labels_per_anchor_node) -======= if parsed_max_labels_per_anchor_node <= 0: raise ValueError( f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value " f"{raw_max_labels_per_anchor_node!r}. Expected a positive integer." ) return parsed_max_labels_per_anchor_node ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 @runtime_checkable @@ -715,13 +699,7 @@ def _get_padded_labels( Returns: The shape of the returned tensor is [N, max_number_of_labels]. """ -<<<<<<< HEAD - max_labels_per_anchor_node = validate_max_labels_per_anchor_node( - max_labels_per_anchor_node - ) -======= validate_max_labels_per_anchor_node(max_labels_per_anchor_node) ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 # indptr is the ROW_INDEX of a CSR matrix. # and indices is the COL_INDEX of a CSR matrix. # See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format) diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index ee7e7ca78..34121565d 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -548,9 +548,7 @@ def test_relation_mask_outputs_follow_requested_order_and_direction(self): ) self.assertEqual(relation_mask.shape, (1, 4, 4, 2)) self.assertEqual(relation_mask.dtype, torch.bool) - self.assertTrue( - torch.equal(relation_mask[0, 1, 0], torch.tensor([True, True])) - ) + self.assertTrue(torch.equal(relation_mask[0, 1, 0], torch.tensor([True, True]))) self.assertTrue( torch.equal(relation_mask[0, 2, 1], torch.tensor([False, True])) ) diff --git a/tests/unit/utils/data_splitters_test.py b/tests/unit/utils/data_splitters_test.py index f1450cb51..0d76d0d07 100644 --- a/tests/unit/utils/data_splitters_test.py +++ b/tests/unit/utils/data_splitters_test.py @@ -820,16 +820,7 @@ def test_get_padded_labels_with_max_labels_per_anchor_node(self): ), max_labels_per_anchor_node=1, ) -<<<<<<< HEAD - assert_close( - labels, - torch.tensor([[1], [2]], dtype=torch.int64), - rtol=0, - atol=0, - ) -======= self.assert_tensor_equality(labels, torch.tensor([[1], [2]], dtype=torch.int64)) ->>>>>>> 62d33243162de9daca9be67b4c0d1f73e7319230 def test_get_max_labels_per_anchor_node_from_runtime_args(self): self.assertIsNone(get_max_labels_per_anchor_node_from_runtime_args({})) From 14b4a4c7bcde9ce3ac1b7b955da1d1628297ee20 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 5 May 2026 15:04:39 -0700 Subject: [PATCH 5/6] clean up --- gigl/distributed/graph_store/storage_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index 3158fc0fe..af6dafa23 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -85,16 +85,11 @@ def build_storage_dataset( gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=task_config_uri ) - if max_labels_per_anchor_node is None: - max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args( - dict(gbml_config_pb_wrapper.trainer_config.trainer_args) - ) serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, tfrecord_uri_pattern=tf_record_uri_pattern, ) - # TODO: Pipe in max_labels_per_anchor_node to build_dataset. dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, From 808e47f460ec6230425f2c3b114767757014cd22 Mon Sep 17 00:00:00 2001 From: Yozen Liu Date: Tue, 5 May 2026 15:23:02 -0700 Subject: [PATCH 6/6] updates --- gigl/nn/graph_transformer.py | 202 ++++++++++++++++++++++-- tests/unit/nn/graph_transformer_test.py | 32 ++++ 2 files changed, 225 insertions(+), 9 deletions(-) diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f9d8b345a..eea2631c3 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -239,6 +239,14 @@ class GraphTransformerEncoderLayer(nn.Module): activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". + relation_attention_mode: Optional relation-aware augmentation strategy + for attention scores. ``"none"`` preserves the default shared + self-attention path. ``"edge_type_additive"`` adds a learned + per-edge-type bilinear term for token pairs backed by sampled + directed graph edges. + num_relations: Number of relation channels expected in + ``pairwise_relation_mask`` when + ``relation_attention_mode="edge_type_additive"``. Raises: ValueError: If model_dim is not divisible by num_heads. @@ -252,16 +260,31 @@ def __init__( dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", + num_relations: int = 0, ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if relation_attention_mode == "edge_type_additive" and num_relations <= 0: + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "num_relations > 0." + ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate + self._relation_attention_mode = relation_attention_mode + self._num_relations = num_relations self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) @@ -269,6 +292,11 @@ def __init__( self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) + self._relation_attention_matrices: Optional[nn.Parameter] = None + if relation_attention_mode == "edge_type_additive": + self._relation_attention_matrices = nn.Parameter( + torch.empty(num_relations, num_heads, self._head_dim, self._head_dim) + ) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( @@ -287,6 +315,10 @@ def reset_parameters(self) -> None: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) + if self._relation_attention_matrices is not None: + for relation_matrices in self._relation_attention_matrices: + for head_matrix in relation_matrices: + nn.init.xavier_uniform_(head_matrix) self._ffn_norm.reset_parameters() self._ffn.reset_parameters() @@ -294,6 +326,7 @@ def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. @@ -303,6 +336,9 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. + pairwise_relation_mask: Optional boolean multi-hot relation mask of shape + ``(batch, seq, seq, num_relations)`` that marks which sampled + directed edge types connect each token pair as ``key -> query``. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. @@ -330,14 +366,23 @@ def forward( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, - ) + if self._relation_attention_mode == "none": + attention_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + else: + attention_output = self._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + ) # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( @@ -360,6 +405,102 @@ def forward( return x + def _run_relation_aware_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_mask: Optional[Tensor], + ) -> Tensor: + relation_attention_bias = self._build_relation_attention_bias( + query, + key, + pairwise_relation_mask=pairwise_relation_mask, + ) + if relation_attention_bias is not None: + attn_bias = ( + relation_attention_bias + if attn_bias is None + else attn_bias + relation_attention_bias + ) + + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _build_relation_attention_bias( + self, + query: Tensor, + key: Tensor, + pairwise_relation_mask: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_mask is None: + raise ValueError( + "pairwise_relation_mask is required when " + "relation_attention_mode='edge_type_additive'." + ) + if pairwise_relation_mask.size(-1) != self._num_relations: + raise ValueError( + "pairwise_relation_mask has unexpected relation dimension " + f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}." + ) + if self._relation_attention_matrices is None: + raise ValueError("Relation attention matrices are not initialized.") + if pairwise_relation_mask.size(1) != query.size( + 2 + ) or pairwise_relation_mask.size(2) != key.size(2): + raise ValueError( + "pairwise_relation_mask must align with the query/key sequence " + "dimensions." + ) + + relation_mask = pairwise_relation_mask.to( + device=query.device, + dtype=torch.bool, + ) + active_relation_positions = relation_mask.nonzero(as_tuple=False) + if active_relation_positions.numel() == 0: + return None + + relation_attention_bias = query.new_zeros( + (query.size(0), query.size(2), key.size(2), self._num_heads) + ) + query_by_position = query.transpose(1, 2) + key_by_position = key.transpose(1, 2) + relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) + active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) + + for relation_idx_tensor in active_relation_ids: + relation_idx = int(relation_idx_tensor.item()) + relation_positions = active_relation_positions[ + active_relation_positions[:, 3] == relation_idx + ] + batch_indices, query_indices, key_indices = relation_positions[ + :, :3 + ].unbind(dim=1) + # Only materialize bilinear scores for token pairs backed by this relation. + selected_query = query_by_position[batch_indices, query_indices] + transformed_query = torch.einsum( + "nhe,hde->nhd", + selected_query, + relation_matrices[relation_idx], + ) + selected_key = key_by_position[batch_indices, key_indices] + relation_scores = (selected_key * transformed_query).sum(dim=-1) + relation_attention_bias.index_put_( + (batch_indices, query_indices, key_indices), + relation_scores / math.sqrt(self._head_dim), + accumulate=True, + ) + + return relation_attention_bias.permute(0, 3, 1, 2) + class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. @@ -450,6 +591,10 @@ class GraphTransformerEncoder(nn.Module): uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. + relation_attention_mode: Optional relation-aware augmentation for + attention scores. ``"none"`` preserves the current dense transformer + path. ``"edge_type_additive"`` adds a learned per-edge-type + bilinear score term for sampled directed edges in ``"khop"`` mode. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap @@ -499,6 +644,7 @@ def __init__( pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", **kwargs: object, ) -> None: super().__init__() @@ -540,6 +686,20 @@ def __init__( "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if ( + relation_attention_mode == "edge_type_additive" + and sequence_construction_method != "khop" + ): + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "sequence_construction_method='khop'." + ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -571,6 +731,12 @@ def __init__( self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + self._relation_attention_mode = relation_attention_mode + self._relation_attention_edge_types = ( + sorted(edge_type_to_feat_dim_map.keys()) + if relation_attention_mode == "edge_type_additive" + else [] + ) anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None @@ -664,6 +830,8 @@ def __init__( dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, + relation_attention_mode=relation_attention_mode, + num_relations=len(self._relation_attention_edge_types), ) for _ in range(num_layers) ] @@ -801,6 +969,11 @@ def forward( anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, + relation_edge_types=( + self._relation_attention_edge_types + if self._relation_attention_mode == "edge_type_additive" + else None + ), ) # Free memory after sequences are built @@ -837,6 +1010,9 @@ def forward( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, + pairwise_relation_mask=sequence_auxiliary_data.get( + "pairwise_relation_mask" + ), ) embeddings = self._output_projection(embeddings) @@ -1036,6 +1212,7 @@ def _encode_and_readout( sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, ) -> Tensor: """Process sequences through transformer layers and attention readout. @@ -1044,6 +1221,8 @@ def _encode_and_readout( valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. + pairwise_relation_mask: Optional boolean relation mask shaped + ``(batch_size, seq, seq, num_relations)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. @@ -1051,7 +1230,12 @@ def _encode_and_readout( x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: - x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) + x = encoder_layer( + x, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + valid_mask=valid_mask, + ) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index 8ae42ec53..9582c18e4 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -1,6 +1,7 @@ """Tests for GraphTransformerEncoder.""" from typing import cast +from unittest.mock import patch import torch import torch.nn as nn @@ -140,6 +141,37 @@ def test_forward_with_l2_normalization(self) -> None: norms = torch.norm(embeddings, p=2, dim=-1) self.assertTrue(torch.allclose(norms, torch.ones_like(norms), atol=1e-5)) + def test_forward_does_not_request_relation_masks_when_relation_attention_disabled( + self, + ) -> None: + data = _create_simple_hetero_data() + encoder = self._create_encoder() + fake_sequences = torch.zeros((3, 2, self._hid_dim)) + fake_valid_mask = torch.ones((3, 2), dtype=torch.bool) + fake_auxiliary_data = { + "anchor_bias": None, + "pairwise_bias": None, + "pairwise_relation_mask": None, + "token_input": None, + } + + with patch( + "gigl.nn.graph_transformer.heterodata_to_graph_transformer_input", + return_value=( + fake_sequences, + fake_valid_mask, + fake_auxiliary_data, + ), + ) as mock_transform: + embeddings = encoder( + data=data, + anchor_node_type=self._user_node_type, + device=self._device, + ) + + self.assertIsNone(mock_transform.call_args.kwargs["relation_edge_types"]) + self.assertEqual(embeddings.shape, (3, self._out_dim)) + def test_forward_defaults_to_first_node_type(self) -> None: """Test that omitted anchor node type defaults to the first node type.""" data = _create_simple_hetero_data()