24 commits
c8c711c
Modified fx_importer to support hop_while_loop
keshavvinayak01 Oct 22, 2025
b250583
Addressed Comments | Simplified unique child_func_name creation
keshavvinayak01 Oct 23, 2025
db1e7e9
Addressed comments
keshavvinayak01 Oct 24, 2025
d9646c6
Formatting
keshavvinayak01 Oct 24, 2025
cc03291
Added children module imports to import_frozen_program flow
keshavvinayak01 Oct 24, 2025
6a70e1c
Formatting and reordered CHECKs
keshavvinayak01 Oct 24, 2025
85e3acd
Changes done to TorchToScf:
keshavvinayak01 Oct 24, 2025
e1ff87d
Added Control flow test
keshavvinayak01 Oct 27, 2025
558c7db
Cannot FX trace HOP
keshavvinayak01 Oct 28, 2025
39d5b24
Added flex_attention hop function
keshavvinayak01 Oct 28, 2025
dfdca75
Formatting
keshavvinayak01 Oct 28, 2025
6178d07
Fixed merge newline removals
keshavvinayak01 Oct 28, 2025
52f1fbc
Added AtenFluxAttentionOp
keshavvinayak01 Oct 29, 2025
a56433a
Added changes for correct functional references
keshavvinayak01 Oct 30, 2025
b0e8585
QOL changes:
keshavvinayak01 Nov 4, 2025
c34efab
Merge branch 'main' into keshavvinayak01/torch-aten-flex_attention
keshavvinayak01 Nov 4, 2025
4470978
Update fx_importer.py to remove deprecated note
keshavvinayak01 Nov 4, 2025
719fe5a
Clarify enable_gqa support in fx_importer.py
keshavvinayak01 Nov 4, 2025
5e024f6
Fix formatting in GeneratedTorchOps.td
keshavvinayak01 Nov 4, 2025
c78d699
return_lse is part of the kernel options
keshavvinayak01 Nov 6, 2025
da23ec9
Moved op definition to TorchOps.td
keshavvinayak01 Nov 7, 2025
af59413
Formatting TorchOps
keshavvinayak01 Nov 7, 2025
0103163
Added lit-test; Docs for FlexAttention
keshavvinayak01 Nov 7, 2025
48f12bc
Formatting
keshavvinayak01 Nov 7, 2025
64 changes: 64 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1442,4 +1442,68 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// FlexAttention operation

// NOTE: This op is manually defined because `aten::flex_attention` exists in
// PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet
// registered in PyTorch's JIT operator registry. The update_torch_ods.sh script
// validates against the JIT registry, so it cannot auto-generate this op.
// Once PyTorch adds flex_attention to the JIT registry, this can be moved to
// the auto-generated section.
//===----------------------------------------------------------------------===//
def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::flex_attention`";
let description = [{
FlexAttention operation with flexible block-sparse attention patterns.

Args:
query: Query tensor [B, H, M, K]
key: Key tensor [B, H, N, K]
value: Value tensor [B, H, N, Ev]
scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
enable_gqa: Bool to enable grouped query attention (key/value heads broadcast across query heads)
return_lse: Bool to return log-sum-exp values

Attributes:
score_mod_fn: Optional function symbol reference for score modification
mask_mod_fn: Optional function symbol reference for mask modification

# TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)

Returns:
output: Result tensor [B, H, M, Ev]
logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
}];

let arguments = (ins
AnyTorchTensorType:$query,
AnyTorchTensorType:$key,
AnyTorchTensorType:$value,
AnyTorchOptionalFloatType:$scale,
Torch_BoolType:$enable_gqa,
Torch_BoolType:$return_lse,
OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
);

let results = (outs
AnyTorchTensorType:$output,
AnyTorchOptionalTensorType:$logsumexp
);

let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}

#endif // TORCH_OPS
156 changes: 156 additions & 0 deletions python/torch_mlir/extras/fx_importer.py
@@ -1905,6 +1905,162 @@ def _import_hop_auto_functionalized(
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i + bind_none)

def _import_hop_flex_attention(
self, loc: Location, node: torch_fx.Node, hop: HigherOrderOperator
):
"""Imports the torch._higher_order_ops.flex_attention HOP.

Args format: (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, ...)
- query, key, value: Attention input tensors
- score_mod: Optional submodule/callable for score modification (imported as function)
- block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors
- scale: Optional float for attention score scaling
- enable_gqa: Boolean for grouped query attention support
- kernel_options: Dict of performance tuning options:
- return_lse: Boolean for whether to return the log-sum-exp tensor

This creates a call to aten.flex_attention with function symbol references for
score_mod and mask_mod. The return_lse flag is extracted from kernel_options.
"""
# flex_attention HOP args from PyTorch:
# (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...)
if len(node.args) < 6:
raise ValueError(
f"flex_attention expects at least 6 arguments, got {len(node.args)}"
)

query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = (
node.args[:6]
)

# This is a boolean flag that enables GQA optimization
enable_gqa = node.args[6] if len(node.args) > 6 else False

# TODO: Add support for kernel_options (performance tuning parameters)
# This is a dict containing options like block sizes, num_warps, etc.
kernel_options = node.args[7] if len(node.args) > 7 else {}

# Import Q, K, V tensors
query = self._import_argument(loc, query_arg, None)
key = self._import_argument(loc, key_arg, None)
value = self._import_argument(loc, value_arg, None)

score_mod_ref = None
if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node):
assert (
score_mod_arg.op == "get_attr"
), f"Expected get_attr for score_mod, got {score_mod_arg.op}"
root_module = node.graph.owning_module
score_mod_module = getattr(root_module, score_mod_arg.target, None)
if score_mod_module is not None:
score_mod_func_name = self.fx_importer._graph_module_to_func_name[
id(score_mod_module)
]
score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)

# Handle block_mask: extract only mask_mod function reference
# Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.)
# that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx).
mask_mod_ref = None
if block_mask_arg is not None and isinstance(block_mask_arg, tuple):
root_module = node.graph.owning_module
# The mask_mod function is the last element in the BlockMask tuple
mask_mod_arg = block_mask_arg[-1]
if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node):
assert (
mask_mod_arg.op == "get_attr"
), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}"
mask_mod_module = getattr(root_module, mask_mod_arg.target, None)
if mask_mod_module is not None:
mask_mod_func_name = self.fx_importer._graph_module_to_func_name[
id(mask_mod_module)
]
mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)

# Import scale (float or None)
if scale_arg is None:
scale = Operation.create(
"torch.constant.none",
results=[self._cc.torch_none_type],
loc=loc,
).result
elif isinstance(scale_arg, (int, float)):
with loc:
scale = _make_constant_op(
"torch.constant.float",
FloatAttr.get_f64(float(scale_arg)),
self._cc.torch_float_type,
).result
else:
scale = self._import_argument(loc, scale_arg, None)

# Determine result types from node metadata
node_val = node.meta.get("val")
if isinstance(node_val, (list, tuple)) and len(node_val) >= 2:
# flex_attention returns (output, logsumexp)
result_types = [self._cc.value_info_to_type(v) for v in node_val]
self._multi_result_nodes.add(node)
else:
# Single output
result_types = [self._cc.node_val_to_type(node)]

with loc:
enable_gqa_value = _make_constant_op(
"torch.constant.bool",
self._cc.integer_attr(1 if enable_gqa else 0, 1),
self._cc.torch_bool_type,
).result

# Extract return_lse from kernel_options
return_lse_value = False
if isinstance(kernel_options, dict):
return_lse_value = kernel_options.get("return_lse", False)

with loc:
return_lse = _make_constant_op(
"torch.constant.bool",
self._cc.integer_attr(1 if return_lse_value else 0, 1),
self._cc.torch_bool_type,
).result

# Build operands for aten.flex_attention.
# Op expects exactly 6 operands: query, key, value, scale, enable_gqa, return_lse.
# Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands.
# Note: block_mask tensors are handled by mask_mod_fn, not passed as operands.

flat_operands = [
query,
key,
value,
scale,
enable_gqa_value,
return_lse,
]

# Build attributes with function references
# Only include attributes if they're not None (OptionalAttr in TableGen)
attributes = {}
if score_mod_ref is not None:
attributes["score_mod_fn"] = score_mod_ref
if mask_mod_ref is not None:
attributes["mask_mod_fn"] = mask_mod_ref

operation = Operation.create(
"torch.aten.flex_attention",
results=result_types,
operands=flat_operands,
attributes=attributes if attributes else None,
loc=loc,
)

# Bind results
if len(result_types) > 1:
self._multi_result_nodes.add(node)
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i)
else:
self.bind_node_value(node, operation.results[0])

def _import_torch_op_overload(
self,
loc: Location,
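For context, a minimal sketch of the kind of user code whose export would reach `_import_hop_flex_attention`. This is an illustrative assumption, not code from this PR: the model, callables, shapes, and constants are invented, and it presumes a PyTorch build in which `flex_attention` is captured as a higher-order op by `torch.export`.

# Illustrative sketch only (not from this PR): a flex_attention call that, once
# exported, should reach _import_hop_flex_attention as a higher-order op.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def score_mod(score, b, h, q_idx, kv_idx):
    # Example score modification: a small relative-position bias.
    return score + 0.01 * (q_idx - kv_idx)

def causal_mask(b, h, q_idx, kv_idx):
    # Attend only to positions at or before the query index.
    return q_idx >= kv_idx

# BlockMask built ahead of time; it bundles causal_mask with runtime block
# tensors, but the importer above keeps only the mask_mod symbol reference.
BLOCK_MASK = create_block_mask(causal_mask, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")

class FlexModel(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=score_mod, block_mask=BLOCK_MASK)

q = k = v = torch.randn(2, 4, 128, 16)
ep = torch.export.export(FlexModel(), (q, k, v))
# Importing `ep` through the FX importer should emit torch.aten.flex_attention
# with score_mod_fn / mask_mod_fn symbol attributes, as in the lit test below.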
34 changes: 34 additions & 0 deletions test/Dialect/Torch/ops.mlir
@@ -205,3 +205,37 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to
%1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
return %1 : !torch.vtensor<[3,3],f32>
}

// CHECK-LABEL: func.func @torch.aten.flex_attention
func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>) {
%float1.0 = torch.constant.float 1.000000e+00
%false_0 = torch.constant.bool false
// CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
// CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool
// CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
return %output, %logsumexp : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
}

func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32>
%float1.000000e-01 = torch.constant.float 1.000000e-01
%1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
%float1.000000e-02 = torch.constant.float 1.000000e-02
%2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
%int1_0 = torch.constant.int 1
%3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
%int1_1 = torch.constant.int 1
%4 = torch.aten.add.Tensor %3, %1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
%5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %5 : !torch.vtensor<[],f32>
}

func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
%0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}
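For reference, a hedged sketch of Python callables that would roughly correspond to @sdpa_score0 and @sdpa_mask0 above; they are reconstructed from the MLIR for illustration only and are not taken from any PyTorch source.

import torch

def sdpa_score0(score, b, h, q_idx, kv_idx):
    # Mirrors @sdpa_score0: tanh(score + 0.01 * (q_idx - kv_idx) + 0.1 * h)
    return torch.tanh(score + 0.01 * (q_idx - kv_idx) + 0.1 * h)

def sdpa_mask0(b, h, q_idx, kv_idx):
    # Mirrors @sdpa_mask0: causal pattern, keep scores where q_idx >= kv_idx.
    return q_idx >= kv_idx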