diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f9d8b345a..eea2631c3 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -239,6 +239,14 @@ class GraphTransformerEncoderLayer(nn.Module): activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". + relation_attention_mode: Optional relation-aware augmentation strategy + for attention scores. ``"none"`` preserves the default shared + self-attention path. ``"edge_type_additive"`` adds a learned + per-edge-type bilinear term for token pairs backed by sampled + directed graph edges. + num_relations: Number of relation channels expected in + ``pairwise_relation_mask`` when + ``relation_attention_mode="edge_type_additive"``. Raises: ValueError: If model_dim is not divisible by num_heads. @@ -252,16 +260,31 @@ def __init__( dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", + num_relations: int = 0, ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if relation_attention_mode == "edge_type_additive" and num_relations <= 0: + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "num_relations > 0." + ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate + self._relation_attention_mode = relation_attention_mode + self._num_relations = num_relations self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) @@ -269,6 +292,11 @@ def __init__( self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) + self._relation_attention_matrices: Optional[nn.Parameter] = None + if relation_attention_mode == "edge_type_additive": + self._relation_attention_matrices = nn.Parameter( + torch.empty(num_relations, num_heads, self._head_dim, self._head_dim) + ) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( @@ -287,6 +315,10 @@ def reset_parameters(self) -> None: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) + if self._relation_attention_matrices is not None: + for relation_matrices in self._relation_attention_matrices: + for head_matrix in relation_matrices: + nn.init.xavier_uniform_(head_matrix) self._ffn_norm.reset_parameters() self._ffn.reset_parameters() @@ -294,6 +326,7 @@ def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. @@ -303,6 +336,9 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. + pairwise_relation_mask: Optional boolean multi-hot relation mask of shape + ``(batch, seq, seq, num_relations)`` that marks which sampled + directed edge types connect each token pair as ``key -> query``. 
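+                A ``True`` at ``[b, q, k, r]`` marks a sampled directed edge of
+                relation channel ``r`` from token ``k`` into token ``q`` in
+                batch element ``b``. Required when
+                ``relation_attention_mode="edge_type_additive"`` and ignored
+                otherwise.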
valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. @@ -330,14 +366,23 @@ def forward( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, - ) + if self._relation_attention_mode == "none": + attention_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + else: + attention_output = self._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + ) # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( @@ -360,6 +405,102 @@ def forward( return x + def _run_relation_aware_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_mask: Optional[Tensor], + ) -> Tensor: + relation_attention_bias = self._build_relation_attention_bias( + query, + key, + pairwise_relation_mask=pairwise_relation_mask, + ) + if relation_attention_bias is not None: + attn_bias = ( + relation_attention_bias + if attn_bias is None + else attn_bias + relation_attention_bias + ) + + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _build_relation_attention_bias( + self, + query: Tensor, + key: Tensor, + pairwise_relation_mask: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_mask is None: + raise ValueError( + "pairwise_relation_mask is required when " + "relation_attention_mode='edge_type_additive'." + ) + if pairwise_relation_mask.size(-1) != self._num_relations: + raise ValueError( + "pairwise_relation_mask has unexpected relation dimension " + f"{pairwise_relation_mask.size(-1)}; expected {self._num_relations}." + ) + if self._relation_attention_matrices is None: + raise ValueError("Relation attention matrices are not initialized.") + if pairwise_relation_mask.size(1) != query.size( + 2 + ) or pairwise_relation_mask.size(2) != key.size(2): + raise ValueError( + "pairwise_relation_mask must align with the query/key sequence " + "dimensions." + ) + + relation_mask = pairwise_relation_mask.to( + device=query.device, + dtype=torch.bool, + ) + active_relation_positions = relation_mask.nonzero(as_tuple=False) + if active_relation_positions.numel() == 0: + return None + + relation_attention_bias = query.new_zeros( + (query.size(0), query.size(2), key.size(2), self._num_heads) + ) + query_by_position = query.transpose(1, 2) + key_by_position = key.transpose(1, 2) + relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) + active_relation_ids = torch.unique(active_relation_positions[:, 3], sorted=True) + + for relation_idx_tensor in active_relation_ids: + relation_idx = int(relation_idx_tensor.item()) + relation_positions = active_relation_positions[ + active_relation_positions[:, 3] == relation_idx + ] + batch_indices, query_indices, key_indices = relation_positions[ + :, :3 + ].unbind(dim=1) + # Only materialize bilinear scores for token pairs backed by this relation. 
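+            # The added score for each gathered (batch, query, key) triple is
+            # the per-head bilinear form k^T W_r q: the einsum below folds W_r
+            # into the gathered queries, and the elementwise product with the
+            # gathered keys then reduces over the head dimension. Dividing by
+            # sqrt(head_dim) matches the scaling that
+            # F.scaled_dot_product_attention applies to the base scores.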
+ selected_query = query_by_position[batch_indices, query_indices] + transformed_query = torch.einsum( + "nhe,hde->nhd", + selected_query, + relation_matrices[relation_idx], + ) + selected_key = key_by_position[batch_indices, key_indices] + relation_scores = (selected_key * transformed_query).sum(dim=-1) + relation_attention_bias.index_put_( + (batch_indices, query_indices, key_indices), + relation_scores / math.sqrt(self._head_dim), + accumulate=True, + ) + + return relation_attention_bias.permute(0, 3, 1, 2) + class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. @@ -450,6 +591,10 @@ class GraphTransformerEncoder(nn.Module): uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. + relation_attention_mode: Optional relation-aware augmentation for + attention scores. ``"none"`` preserves the current dense transformer + path. ``"edge_type_additive"`` adds a learned per-edge-type + bilinear score term for sampled directed edges in ``"khop"`` mode. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap @@ -499,6 +644,7 @@ def __init__( pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, + relation_attention_mode: Literal["none", "edge_type_additive"] = "none", **kwargs: object, ) -> None: super().__init__() @@ -540,6 +686,20 @@ def __init__( "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) + if relation_attention_mode not in {"none", "edge_type_additive"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_additive'}, " + f"got '{relation_attention_mode}'" + ) + if ( + relation_attention_mode == "edge_type_additive" + and sequence_construction_method != "khop" + ): + raise ValueError( + "relation_attention_mode='edge_type_additive' requires " + "sequence_construction_method='khop'." 
+ ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -571,6 +731,12 @@ def __init__( self._feature_embedding_layer_dict = feature_embedding_layer_dict self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + self._relation_attention_mode = relation_attention_mode + self._relation_attention_edge_types = ( + sorted(edge_type_to_feat_dim_map.keys()) + if relation_attention_mode == "edge_type_additive" + else [] + ) anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None @@ -664,6 +830,8 @@ def __init__( dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, + relation_attention_mode=relation_attention_mode, + num_relations=len(self._relation_attention_edge_types), ) for _ in range(num_layers) ] @@ -801,6 +969,11 @@ def forward( anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, + relation_edge_types=( + self._relation_attention_edge_types + if self._relation_attention_mode == "edge_type_additive" + else None + ), ) # Free memory after sequences are built @@ -837,6 +1010,9 @@ def forward( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, + pairwise_relation_mask=sequence_auxiliary_data.get( + "pairwise_relation_mask" + ), ) embeddings = self._output_projection(embeddings) @@ -1036,6 +1212,7 @@ def _encode_and_readout( sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_mask: Optional[Tensor] = None, ) -> Tensor: """Process sequences through transformer layers and attention readout. @@ -1044,6 +1221,8 @@ def _encode_and_readout( valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. + pairwise_relation_mask: Optional boolean relation mask shaped + ``(batch_size, seq, seq, num_relations)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. 
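+
+        Note:
+            When ``pairwise_relation_mask`` is provided, it is forwarded
+            unchanged to every encoder layer; layers configured with
+            ``relation_attention_mode="none"`` ignore it.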
@@ -1051,7 +1230,12 @@ def _encode_and_readout( x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: - x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) + x = encoder_layer( + x, + attn_bias=attn_bias, + pairwise_relation_mask=pairwise_relation_mask, + valid_mask=valid_mask, + ) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 602f95bde..304ca0795 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -65,12 +65,15 @@ from torch_geometric.typing import NodeType from torch_geometric.utils import to_torch_sparse_tensor +from gigl.src.common.types.graph_data import EdgeType + TokenInputData = dict[str, Tensor] class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] + pairwise_relation_mask: Optional[Tensor] token_input: Optional[TokenInputData] @@ -90,6 +93,7 @@ def heterodata_to_graph_transformer_input( anchor_based_attention_bias_attr_names: Optional[list[str]] = None, anchor_based_input_attr_names: Optional[list[str]] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, + relation_edge_types: Optional[list[EdgeType]] = None, ) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]: """ Transform a HeteroData object to Graph Transformer sequence input. @@ -131,6 +135,10 @@ def heterodata_to_graph_transformer_input( pairwise_attention_bias_attr_names: List of pairwise feature names used as attention bias. These must correspond to sparse graph-level attributes on ``data``. Example: ['pairwise_distance']. + relation_edge_types: Optional ordered edge types used to materialize a + dense per-token-pair relation mask. Each output channel corresponds + to one edge type in this list. Directed edges are placed at + ``[query_pos, key_pos] = [dst_token, src_token]``. 
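+            For example, passing two edge types yields a two-channel mask in
+            which channel ``i`` marks sampled edges of
+            ``relation_edge_types[i]``; see
+            ``_lookup_pairwise_relation_masks`` below.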
Returns: (sequences, valid_mask, attention_bias_data), where: @@ -143,6 +151,8 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"pairwise_relation_mask"`` shaped + ``(batch, seq, seq, num_relations)`` or None ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -312,6 +322,15 @@ def heterodata_to_graph_transformer_input( csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, device=device, ) + pairwise_relation_mask = _lookup_pairwise_relation_masks( + data=data, + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + ) anchor_bias_features = _compose_anchor_feature_tensor( anchor_relative_feature_sequences=anchor_relative_feature_sequences, @@ -332,6 +351,7 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "pairwise_relation_mask": pairwise_relation_mask, "token_input": token_input_features, }, ) @@ -875,6 +895,123 @@ def _lookup_pairwise_relative_features( return features +def _lookup_pairwise_relation_masks( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + relation_edge_types: Optional[list[EdgeType]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, +) -> Optional[Tensor]: + """Build a dense per-token-pair boolean multi-hot relation mask. + + For each ordered token pair ``(query_pos, key_pos)``, this returns a + multi-hot vector indicating which directed sampled graph edges connect the + underlying nodes as ``key -> query``. The relation channel order is defined + by ``relation_edge_types``. 
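+
+    For example, with channel order ``[likes, follows]`` and a sampled
+    ``follows`` edge ``node_1 -> node_2``, the entry at
+    ``[batch, pos(node_2), pos(node_1)]`` is ``[False, True]``. Pairs that
+    involve padded positions are always all-``False``.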
+ """ + if not relation_edge_types: + return None + + relation_adjacency_matrices = _build_relation_adjacency_matrices( + data=data, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + ) + if not relation_adjacency_matrices: + return None + + batch_size, max_seq_len = node_index_sequences.shape + num_relations = len(relation_adjacency_matrices) + relation_mask = torch.zeros( + (batch_size, max_seq_len, max_seq_len, num_relations), + dtype=torch.bool, + device=device, + ) + pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) + if not pair_valid_mask.any(): + return relation_mask + + row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) + col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) + valid_row_indices = row_indices[pair_valid_mask] + valid_col_indices = col_indices[pair_valid_mask] + + for relation_idx, adjacency_matrix in enumerate(relation_adjacency_matrices): + relation_values = _lookup_csr_values( + csr_matrix=adjacency_matrix, + row_indices=valid_row_indices, + col_indices=valid_col_indices, + ) + relation_mask[..., relation_idx][pair_valid_mask] = relation_values.ne(0.0) + + return relation_mask + + +def _build_relation_adjacency_matrices( + data: HeteroData, + relation_edge_types: list[EdgeType], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, +) -> list[Tensor]: + """Create one binary CSR adjacency matrix per requested edge type.""" + adjacency_matrices: list[Tensor] = [] + empty_row_ptr = torch.zeros(num_nodes + 1, dtype=torch.int64, device=device) + empty_col_idx = torch.zeros(0, dtype=torch.int64, device=device) + empty_values = torch.zeros(0, dtype=torch.float, device=device) + + for edge_type in relation_edge_types: + if edge_type not in data.edge_types: + adjacency_matrices.append( + torch.sparse_csr_tensor( + empty_row_ptr, + empty_col_idx, + empty_values, + size=(num_nodes, num_nodes), + ) + ) + continue + + edge_index = data[edge_type].edge_index.to(device) + src_indices = edge_index[0].long() + int(node_type_offsets[edge_type[0]]) + dst_indices = edge_index[1].long() + int(node_type_offsets[edge_type[2]]) + if src_indices.numel() == 0: + adjacency_matrices.append( + torch.sparse_csr_tensor( + empty_row_ptr, + empty_col_idx, + empty_values, + size=(num_nodes, num_nodes), + ) + ) + continue + + coalesced_adjacency = torch.sparse_coo_tensor( + torch.stack([dst_indices, src_indices]), + torch.ones(src_indices.numel(), dtype=torch.float, device=device), + size=(num_nodes, num_nodes), + ).coalesce() + adjacency_matrices.append( + torch.sparse_coo_tensor( + coalesced_adjacency.indices(), + torch.ones( + coalesced_adjacency.indices().size(1), + dtype=torch.float, + device=device, + ), + size=(num_nodes, num_nodes), + ) + .coalesce() + .to_sparse_csr() + ) + + return adjacency_matrices + + def _get_k_hop_neighbors_sparse( anchor_indices: Tensor, edge_index: Tensor, diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index d0fce10c3..9582c18e4 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -1,6 +1,7 @@ """Tests for GraphTransformerEncoder.""" from typing import cast +from unittest.mock import patch import torch import torch.nn as nn @@ -140,6 +141,37 @@ def test_forward_with_l2_normalization(self) -> None: norms = torch.norm(embeddings, p=2, dim=-1) self.assertTrue(torch.allclose(norms, torch.ones_like(norms), 
atol=1e-5)) + def test_forward_does_not_request_relation_masks_when_relation_attention_disabled( + self, + ) -> None: + data = _create_simple_hetero_data() + encoder = self._create_encoder() + fake_sequences = torch.zeros((3, 2, self._hid_dim)) + fake_valid_mask = torch.ones((3, 2), dtype=torch.bool) + fake_auxiliary_data = { + "anchor_bias": None, + "pairwise_bias": None, + "pairwise_relation_mask": None, + "token_input": None, + } + + with patch( + "gigl.nn.graph_transformer.heterodata_to_graph_transformer_input", + return_value=( + fake_sequences, + fake_valid_mask, + fake_auxiliary_data, + ), + ) as mock_transform: + embeddings = encoder( + data=data, + anchor_node_type=self._user_node_type, + device=self._device, + ) + + self.assertIsNone(mock_transform.call_args.kwargs["relation_edge_types"]) + self.assertEqual(embeddings.shape, (3, self._out_dim)) + def test_forward_defaults_to_first_node_type(self) -> None: """Test that omitted anchor node type defaults to the first node type.""" data = _create_simple_hetero_data() @@ -274,6 +306,22 @@ def _create_user_graph_with_ppr_edges() -> HeteroData: return data +def _create_user_graph_with_relation_edges() -> HeteroData: + data = HeteroData() + + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ) + data["user", "follows", "user"].edge_index = torch.tensor([[0, 1], [1, 2]]) + data["user", "likes", "user"].edge_index = torch.tensor([[0], [1]]) + + return data + + class TestGraphTransformerEncoderPEModes(TestCase): def setUp(self) -> None: self._node_type = NodeType("user") @@ -411,6 +459,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), + "pairwise_relation_mask": None, "token_input": None, }, ) @@ -447,6 +496,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, + "pairwise_relation_mask": None, "token_input": None, }, ) @@ -773,6 +823,160 @@ def test_layer_with_swiglu(self) -> None: self.assertEqual(out.shape, (2, 10, 32)) +class TestGraphTransformerRelationAttention(TestCase): + def setUp(self) -> None: + self._node_type = NodeType("user") + self._follows_edge_type = EdgeType( + self._node_type, Relation("follows"), self._node_type + ) + self._likes_edge_type = EdgeType( + self._node_type, Relation("likes"), self._node_type + ) + self._device = torch.device("cpu") + + def test_relation_aware_layer_matches_baseline_when_matrices_are_zero( + self, + ) -> None: + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + num_relations=2, + ) + relation_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + relation_mask = torch.zeros((2, 4, 4, 2), dtype=torch.bool) + relation_mask[0, 1, 0, 0] = True + relation_mask[1, 2, 1, 1] = True + + with torch.no_grad(): + assert relation_layer._relation_attention_matrices is not None + relation_layer._relation_attention_matrices.zero_() + base_output = base_layer(x, valid_mask=valid_mask) + relation_output = relation_layer( + x, + pairwise_relation_mask=relation_mask, + valid_mask=valid_mask, + ) + + 
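+        # With every relation matrix zeroed, the bilinear bias term vanishes,
+        # so the relation-aware path must reproduce the plain SDPA output
+        # (up to the comparison tolerance below).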
self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_relation_aware_attention_only_changes_marked_query_positions( + self, + ) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + num_relations=1, + ) + + query = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + value = torch.tensor([[[[3.0, 0.0], [0.0, 5.0]]]]) + empty_relation_mask = torch.zeros((1, 2, 2, 1), dtype=torch.bool) + active_relation_mask = empty_relation_mask.clone() + active_relation_mask[0, 1, 0, 0] = True + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices.zero_() + layer._relation_attention_matrices[0, 0] = torch.eye(2) + base_output = layer._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=None, + pairwise_relation_mask=empty_relation_mask, + ) + relation_output = layer._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=None, + pairwise_relation_mask=active_relation_mask, + ) + + self.assertTrue( + torch.allclose(base_output[:, :, 0], relation_output[:, :, 0], atol=1e-6) + ) + self.assertFalse( + torch.allclose(base_output[:, :, 1], relation_output[:, :, 1], atol=1e-6) + ) + + def test_encoder_forward_supports_relation_aware_attention(self) -> None: + data = _create_user_graph_with_relation_edges() + encoder = GraphTransformerEncoder( + node_type_to_feat_dim_map={self._node_type: 4}, + edge_type_to_feat_dim_map={ + self._likes_edge_type: 0, + self._follows_edge_type: 0, + }, + hid_dim=8, + out_dim=6, + num_layers=1, + num_heads=2, + max_seq_len=4, + hop_distance=2, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_additive", + ) + encoder.eval() + + with torch.no_grad(): + embeddings = encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertEqual( + encoder._relation_attention_edge_types, + sorted([self._likes_edge_type, self._follows_edge_type]), + ) + self.assertEqual(embeddings.shape, (3, 6)) + self.assertFalse(torch.isnan(embeddings).any()) + + def test_encoder_rejects_relation_attention_in_ppr_mode(self) -> None: + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + with self.assertRaisesRegex( + ValueError, + "relation_attention_mode='edge_type_additive' requires " + "sequence_construction_method='khop'", + ): + GraphTransformerEncoder( + node_type_to_feat_dim_map={self._node_type: 4}, + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + hid_dim=8, + out_dim=6, + num_layers=1, + num_heads=2, + max_seq_len=4, + hop_distance=1, + relation_attention_mode="edge_type_additive", + sequence_construction_method="ppr", + ) + + class TestGraphTransformerEncoderFeedforwardRatio(TestCase): """Tests for GraphTransformerEncoder feedforward_ratio parameter.""" diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 25cda0821..34121565d 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -7,6 +7,7 @@ from absl.testing import absltest from torch_geometric.data import HeteroData +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.transforms.graph_transformer import ( _get_k_hop_neighbors_sparse, 
heterodata_to_graph_transformer_input, @@ -122,6 +123,23 @@ def create_ppr_sequence_hetero_data() -> HeteroData: return data +def create_relation_mask_hetero_data() -> HeteroData: + """Create a single-node-type graph with overlapping directed relations.""" + data = HeteroData() + + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + ] + ) + data["user", "follows", "user"].edge_index = torch.tensor([[0, 1], [1, 2]]) + data["user", "likes", "user"].edge_index = torch.tensor([[0], [1]]) + + return data + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -252,6 +270,7 @@ def test_basic_transform(self): self.assertIsInstance(attention_bias_data, dict) self.assertIn("anchor_bias", attention_bias_data) self.assertIn("pairwise_bias", attention_bias_data) + self.assertIn("pairwise_relation_mask", attention_bias_data) def test_attention_mask_validity(self): """Test that attention mask correctly identifies valid positions.""" @@ -491,6 +510,53 @@ def test_ppr_sequence_can_return_token_input_and_attention_bias_features(self): torch.equal(valid_mask[1], torch.tensor([True, True, True, False])) ) + def test_relation_mask_outputs_follow_requested_order_and_direction(self): + data = create_relation_mask_hetero_data() + user_node_type = NodeType("user") + follows_edge_type = EdgeType( + user_node_type, Relation("follows"), user_node_type + ) + likes_edge_type = EdgeType(user_node_type, Relation("likes"), user_node_type) + + sequences, valid_mask, attention_bias_data = ( + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + relation_edge_types=[likes_edge_type, follows_edge_type], + ) + ) + + relation_mask = attention_bias_data["pairwise_relation_mask"] + assert relation_mask is not None + + expected_sequences = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ] + ) + self.assertTrue(torch.allclose(sequences, expected_sequences)) + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + self.assertEqual(relation_mask.shape, (1, 4, 4, 2)) + self.assertEqual(relation_mask.dtype, torch.bool) + self.assertTrue(torch.equal(relation_mask[0, 1, 0], torch.tensor([True, True]))) + self.assertTrue( + torch.equal(relation_mask[0, 2, 1], torch.tensor([False, True])) + ) + self.assertTrue( + torch.equal(relation_mask[0, 0, 1], torch.zeros(2, dtype=torch.bool)) + ) + self.assertTrue(torch.all(relation_mask[0, 3] == 0)) + class TestPyTorchTransformerIntegration(TestCase): """Tests for integration with PyTorch TransformerEncoderLayer."""
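A minimal end-to-end sketch of the new opt-in flag, mirroring test_encoder_forward_supports_relation_aware_attention above (constructor values are illustrative, not recommendations):

    from gigl.nn.graph_transformer import GraphTransformerEncoder
    from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation

    user = NodeType("user")
    follows = EdgeType(user, Relation("follows"), user)

    encoder = GraphTransformerEncoder(
        node_type_to_feat_dim_map={user: 4},
        edge_type_to_feat_dim_map={follows: 0},
        hid_dim=8,
        out_dim=6,
        num_layers=1,
        num_heads=2,
        max_seq_len=4,
        hop_distance=2,
        relation_attention_mode="edge_type_additive",  # default is "none"
    )
    # encoder(data=..., anchor_node_type=user, device=...) then builds khop
    # sequences, materializes the (batch, seq, seq, num_relations) relation
    # mask, and adds the learned per-edge-type bilinear bias inside attention.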