diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 1595bf58e410..0841398e5dd4 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1442,4 +1442,70 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FlexAttention operation
+
+// NOTE: This op is manually defined because flex_attention exists in
+// PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet
+// registered in PyTorch's JIT operator registry. The update_torch_ods.sh script
+// validates against the JIT registry, so it cannot auto-generate this op.
+// Once PyTorch adds flex_attention to the JIT registry, this can be moved to
+// the auto-generated section.
+//===----------------------------------------------------------------------===//
+def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Computes the flex_attention operation (1-1 with torch._higher_order_ops.flex_attention)";
+  let description = [{
+    FlexAttention operation with flexible block-sparse attention patterns.
+
+    Args:
+      query: Query tensor [B, H, M, K]
+      key: Key tensor [B, H, N, K]
+      value: Value tensor [B, H, N, Ev]
+      scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
+      return_lse, return_max_scores: Bools selecting the optional log-sum-exp and max-scores outputs
+
+    Attributes:
+      score_mod_fn: Optional function symbol reference for score modification
+      mask_mod_fn: Optional function symbol reference for mask modification
+
+    # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
+
+    Returns:
+      output: Result tensor [B, H, M, Ev]
+      logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
+      max_scores: Optional max-scores tensor [B, H, M] (if return_max_scores=True)
+  }];
+
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    AnyTorchOptionalFloatType:$scale,
+    Torch_BoolType:$return_lse,
+    Torch_BoolType:$return_max_scores,
+    OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
+    OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
+  );
+
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp,
+    AnyTorchOptionalTensorType:$max_scores
+  );
+
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult HigherOrderFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 3);
+    }
+    void HigherOrderFlexAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 3);
+    }
+  }];
+}
+
 #endif // TORCH_OPS
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 69cc92130725..d79ba099d8ec 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -1905,6 +1905,150 @@ def _import_hop_auto_functionalized(
         for i, value in enumerate(operation.results):
             self.bind_node_value(node, value, i + bind_none)
 
+    def _import_hop_flex_attention(
+        self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
+    ):
+        """Imports the torch._higher_order_ops.flex_attention HOP.
+
+        Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
+            - query, key, value: Attention input tensors
+            - score_mod: Optional submodule/callable for score modification (imported as a function)
+            - block_mask: Optional BlockMask tuple containing the mask_mod function and runtime tensors
+            - scale: Optional float for attention score scaling
+            - kernel_options: Optional dict of performance tuning options; its
+              OUTPUT_LOGSUMEXP and OUTPUT_MAX entries select the optional outputs
+
+        This creates a call to torch.hop_flex_attention with function symbol references
+        for score_mod and mask_mod.
+        """
+        # flex_attention HOP args from PyTorch:
+        # (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
+        (
+            query_arg,
+            key_arg,
+            value_arg,
+            score_mod_arg,
+            block_mask_arg,
+            scale_arg,
+            kernel_options,
+        ) = node.args[:7]
+
+        # Import Q, K, V tensors
+        query = self._import_argument(loc, query_arg, None)
+        key = self._import_argument(loc, key_arg, None)
+        value = self._import_argument(loc, value_arg, None)
+
+        score_mod_ref = None
+        if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node):
+            assert (
+                score_mod_arg.op == "get_attr"
+            ), f"Expected get_attr for score_mod, got {score_mod_arg.op}"
+            root_module = node.graph.owning_module
+            score_mod_module = getattr(root_module, score_mod_arg.target, None)
+            if score_mod_module is not None:
+                score_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                    id(score_mod_module)
+                ]
+                score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
+
+        # Handle block_mask: extract only the mask_mod function reference.
+        # Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.)
+        # that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx).
+        mask_mod_ref = None
+        if block_mask_arg is not None and isinstance(block_mask_arg, tuple):
+            root_module = node.graph.owning_module
+            # The mask_mod function is the last element in the BlockMask tuple
+            mask_mod_arg = block_mask_arg[-1]
+            if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node):
+                assert (
+                    mask_mod_arg.op == "get_attr"
+                ), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}"
+                mask_mod_module = getattr(root_module, mask_mod_arg.target, None)
+                if mask_mod_module is not None:
+                    mask_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                        id(mask_mod_module)
+                    ]
+                    mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
+
+        # Import scale (float or None)
+        if scale_arg is None:
+            scale = Operation.create(
+                "torch.constant.none",
+                results=[self._cc.torch_none_type],
+                loc=loc,
+            ).result
+        elif isinstance(scale_arg, (int, float)):
+            with loc:
+                scale = _make_constant_op(
+                    "torch.constant.float",
+                    FloatAttr.get_f64(float(scale_arg)),
+                    self._cc.torch_float_type,
+                ).result
+        else:
+            scale = self._import_argument(loc, scale_arg, None)
+
+        # Determine result types from node metadata
+        node_val = node.meta.get("val")
+        if isinstance(node_val, (list, tuple)) and len(node_val) >= 2:
+            # flex_attention returns (output, logsumexp)
+            result_types = [self._cc.value_info_to_type(v) for v in node_val]
+            self._multi_result_nodes.add(node)
+        else:
+            # Single output
+            result_types = [self._cc.node_val_to_type(node)]
+
+        # Extract OUTPUT_LOGSUMEXP and OUTPUT_MAX from kernel_options
+        with loc:
+            return_lse = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(
+                    bool(kernel_options.get("OUTPUT_LOGSUMEXP", 0)), 1
+                ),
+                self._cc.torch_bool_type,
+            ).result
+            return_max_scores = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(bool(kernel_options.get("OUTPUT_MAX", 0)), 1),
+                self._cc.torch_bool_type,
+            ).result
+
+        # Build operands for torch.hop_flex_attention.
+        # Op expects exactly 6 operands: query, key, value, scale, return_lse, return_max_scores.
+        # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands.
+        # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands.
+
+        flat_operands = [
+            query,
+            key,
+            value,
+            scale,
+            return_lse,
+            return_max_scores,
+        ]
+
+        # Build attributes with function references
+        # Only include attributes if they're not None (OptionalAttr in TableGen)
+        attributes = {}
+        if score_mod_ref is not None:
+            attributes["score_mod_fn"] = score_mod_ref
+        if mask_mod_ref is not None:
+            attributes["mask_mod_fn"] = mask_mod_ref
+
+        operation = Operation.create(
+            "torch.hop_flex_attention",
+            results=result_types,
+            operands=flat_operands,
+            attributes=attributes if attributes else None,
+            loc=loc,
+        )
+        # Bind results
+        if len(result_types) > 1:
+            self._multi_result_nodes.add(node)
+            for i, value in enumerate(operation.results):
+                self.bind_node_value(node, value, i)
+        else:
+            self.bind_node_value(node, operation.results[0])
+
     def _import_torch_op_overload(
         self,
         loc: Location,
@@ -1932,7 +2076,7 @@ def _import_torch_op_overload(
         # torch dynamo where it emits the Tensor variant of ops even when processing
         # scalar arguments, therefore we retrieve the schema as well so that we
         # consume the correct typing information when subsequently importing the
-        # function arguments and result types
+        # function arguments and result types.
         # i.e. the code below is basically doing `schema = torch.ops.aten.my_op.Scalar._schema`
         op_attrs = mlir_op_name.split(".")
         op_overload = getattr(torch, "ops")
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index a47cbf83a318..169b1094c8b5 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -205,3 +205,73 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to
   %1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
   return %1 : !torch.vtensor<[3,3],f32>
 }
+
+
+//===----------------------------------------------------------------------===//
+// FlexAttention variant tests
+//===----------------------------------------------------------------------===//
+
+func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
+  %5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %5 : !torch.vtensor<[],f32>
+}
+
+func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
+  %0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
+  return %0 : !torch.vtensor<[],i1>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention
+func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask
+func.func @torch.hop_flex_attention_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {score_mod_fn = @sdpa_score0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore
+func.func @torch.hop_flex_attention_noscore (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: {mask_mod_fn = @sdpa_mask0}
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
+
+// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask
+func.func @torch.hop_flex_attention_noscore_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
+  %float1.0 = torch.constant.float 1.000000e+00
+  %false_0 = torch.constant.bool false
+  // CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
+  // CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
+  // CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
+  %output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+  return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
+}
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index e490a8d3636c..bb424f0489b0 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -251,6 +251,88 @@ def body(i, x):
     print(m)
 
 
+@run
+# CHECK-LABEL: test_flex_attention
+# Check that helper functions are emitted first
+# CHECK: func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32>
+# CHECK: torch.aten.tanh
+# CHECK: func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1>
+# CHECK: torch.aten.new_ones
+# Then check the main function
+# CHECK: func.func @test_flex_attention(%arg0: !torch.vtensor<[4,8,1024,64],f32>, %arg1: !torch.vtensor<[4,8,1024,64],f32>, %arg2: !torch.vtensor<[4,8,1024,64],f32>)
+# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>
+# Validate flex_attention op with 3 results and 6 operands:
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool false
+# CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false
+# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.hop_flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
+# CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool
+# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32>
+# CHECK: return %[[OUTPUT]]
+def test_flex_attention():
+    from torch._higher_order_ops.flex_attention import (
+        flex_attention as flex_attention_hop,
+    )
+    from torch.nn.attention.flex_attention import (
+        BlockMask,
+        _LARGE_SPARSE_BLOCK_SIZE,
+        create_block_mask,
+        flex_attention,
+    )
+    from torch import Tensor
+
+    def _create_empty_block_mask(query: Tensor, key: Tensor):
+        # Default block mask for flex attention.
+        device = query.device
+        return BlockMask.from_kv_blocks(
+            kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
+            kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
+            BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
+            seq_lengths=(1, 1),
+        ).as_tuple()
+
+    def relative_position_bias(
+        score: Tensor,
+        batch: Tensor,
+        head: Tensor,
+        token_q: Tensor,
+        token_kv: Tensor,
+    ) -> Tensor:
+        # Simple score mod function.
+        return torch.tanh(score)
+
+    class FlexAttention(torch.nn.Module):
+        def __init__(self, block_mask):
+            super().__init__()
+            self.block_mask = block_mask
+
+        def forward(self, q, k, v):
+            output, lse, max_scores = flex_attention_hop(
+                q,
+                k,
+                v,
+                score_mod=relative_position_bias,
+                block_mask=self.block_mask,
+                scale=1.0,
+                kernel_options={},
+            )
+            return output
+
+    # Export -> import to Torch-MLIR
+    B, Hq, Hkv, L, S, E, Ev = 4, 8, 8, 1024, 1024, 64, 64
+    q = torch.ones(B, Hq, L, E)
+    k = torch.ones(B, Hkv, S, E)
+    v = torch.ones(B, Hkv, S, Ev)
+    m = fx.export_and_import(
+        FlexAttention(_create_empty_block_mask(q, k)),
+        q,
+        k,
+        v,
+        func_name="test_flex_attention",
+    )
+    print(m)
+
+
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(