Commit b0e8585

QOL changes:
1. Better documentation for AtenFlexAttentionOp.
2. Function references added as attributes to aten.flex_attention.
3. Updates to _import_hop_flex_attention reflecting the latest changes to module import.
4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent a56433a commit b0e8585
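
For context (illustrative, not part of this commit): aten.flex_attention and the _import_hop_flex_attention handler below correspond to PyTorch's flex_attention higher-order op. The score_mod and mask_mod callables shown here are the kind of functions that get imported as private funcs and referenced by the new score_mod_fn / mask_mod_fn attributes. Minimal sketch assuming a recent PyTorch (>= 2.5) with flex_attention available; relative_bias and causal_mask are made-up names:

    import torch
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask

    # score_mod receives one attention score plus (batch, head, q_idx, kv_idx)
    # indices and returns the modified score.
    def relative_bias(score, b, h, q_idx, kv_idx):
        return score + (kv_idx - q_idx)

    # mask_mod returns True where attention is allowed.
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    B, H, S, D = 2, 4, 128, 64
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
    block_mask = create_block_mask(causal_mask, B, H, S, S, device="cpu")

    out = flex_attention(q, k, v, score_mod=relative_bias, block_mask=block_mask)
    # Under torch.compile / torch.export this call is captured as the
    # torch._higher_order_ops.flex_attention HOP handled by the importer below.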

2 files changed: +84 -133 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 23 additions & 25 deletions
@@ -16194,62 +16194,60 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
   let hasFolder = 1;
 }
 
-
 def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::flex_attention : (Tensor, Tensor, Tensor, Any?, Any?, float?, bool, Any?, bool) -> (Tensor, Tensor)`";
+  let summary = "Generated op for `aten::flex_attention`";
   let description = [{
-    Flexible attention operator that supports custom score modification and masking.
-
+    FlexAttention operation with flexible block-sparse attention patterns.
+
     Args:
-      query: Query tensor [B, H, M, E]
-      key: Key tensor [B, H, N, E]
+      query: Query tensor [B, H, M, K]
+      key: Key tensor [B, H, N, K]
       value: Value tensor [B, H, N, Ev]
-      score_mod: Optional callable to modify attention scores (represented as None or opaque type)
-      block_mask: Optional BlockMask tuple for sparse attention patterns
       scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
-      enable_gqa: bool for grouped query attention support
-      kernel_options: Optional dict of kernel configuration options
-      return_lse: bool to return log-sum-exp values
+      return_lse: Bool to return log-sum-exp values
 
-    Returns:
-      - If return_lse=False: Just the output tensor [B, H, M, Ev]
-      - If return_lse=True: Tuple of (output [B, H, M, Ev], logsumexp [B, H, M])
+    Attributes:
+      score_mod_fn: Optional function symbol reference for score modification
+      mask_mod_fn: Optional function symbol reference for mask modification
+
+    # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
 
-    Note: score_mod and block_mask are higher-order/complex types in PyTorch.
-    For MLIR representation, score_mod is represented as None (identity) or an opaque type,
-    and block_mask is represented as None or a tuple/list of tensors containing the block indices.
+    Returns:
+      output: Result tensor [B, H, M, Ev]
+      logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
   }];
+
   let arguments = (ins
     AnyTorchTensorType:$query,
     AnyTorchTensorType:$key,
     AnyTorchTensorType:$value,
-    AnyType:$score_mod,
-    AnyType:$block_mask,
     AnyTorchOptionalFloatType:$scale,
-    Torch_BoolType:$enable_gqa,
-    AnyType:$kernel_options,
-    Torch_BoolType:$return_lse
+    Torch_BoolType:$enable_gqa,
+    Torch_BoolType:$return_lse,
+    OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
+    OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
   );
+
   let results = (outs
     AnyTorchTensorType:$output,
     AnyTorchOptionalTensorType:$logsumexp
   );
+
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
     ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 9, 2);
+      return parseDefaultTorchOp(parser, result, 6, 2);
     }
     void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 9, 2);
+      printDefaultTorchOp(printer, *this, 6, 2);
     }
   }];
 }
 
-
 def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
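
Not part of the diff: a small numeric sketch of the Returns section above, under the assumed semantics of plain scaled-dot-product attention (no score_mod or mask applied), to make the two results concrete:

    import torch

    # Shapes follow the op description: query [B, H, M, E], key/value [B, H, N, Ev].
    B, H, M, N, E = 1, 2, 4, 6, 8
    q = torch.randn(B, H, M, E)
    k = torch.randn(B, H, N, E)
    v = torch.randn(B, H, N, E)

    scale = E ** -0.5                            # default scale = 1/sqrt(head_dim)
    scores = (q @ k.transpose(-1, -2)) * scale   # [B, H, M, N]
    output = torch.softmax(scores, dim=-1) @ v   # [B, H, M, Ev]  -> $output
    logsumexp = torch.logsumexp(scores, dim=-1)  # [B, H, M]      -> $logsumexp (return_lse=True)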

python/torch_mlir/extras/fx_importer.py

Lines changed: 61 additions & 108 deletions
@@ -1922,14 +1922,19 @@ def _import_hop_flex_attention(
     ):
         """Imports the torch._higher_order_ops.flex_attention HOP.
 
-        Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
-        The score_mod is a submodule/callable that has been imported as a private function.
-        The block_mask is a tuple: (kv_num_blocks, kv_indices, ..., mask_mod)
-
-        This creates a call to aten.flex_attention with function symbol references.
+        Args format: (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, ...)
+        - query, key, value: Attention input tensors
+        - score_mod: Optional submodule/callable for score modification (imported as function)
+        - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors
+        - scale: Optional float for attention score scaling
+        - enable_gqa: Boolean for grouped query attention support (TODO: NYI)
+        - kernel_options: Dict of performance tuning options (TODO: NYI)
+
+        This creates a call to aten.flex_attention with function symbol references for
+        score_mod and mask_mod.
         """
         # flex_attention HOP args from PyTorch:
-        # (query, key, value, score_mod, block_mask, scale, kernel_options, return_lse_tuple, ...)
+        # (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...)
         if len(node.args) < 6:
             raise ValueError(
                 f"flex_attention expects at least 6 arguments, got {len(node.args)}"
@@ -1938,68 +1943,51 @@ def _import_hop_flex_attention(
         query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = (
             node.args[:6]
         )
-        kernel_options = node.args[6] if len(node.args) > 6 else {}
+
+        # TODO: Add support for enable_gqa (grouped query attention)
+        # This is a boolean flag that enables GQA optimization
+        enable_gqa = node.args[6] if len(node.args) > 6 else False
+
+        # TODO: Add support for kernel_options (performance tuning parameters)
+        # This is a dict containing options like block sizes, num_warps, etc.
+        kernel_options = node.args[7] if len(node.args) > 7 else {}
 
         # Import Q, K, V tensors
         query = self._import_argument(loc, query_arg, None)
         key = self._import_argument(loc, key_arg, None)
         value = self._import_argument(loc, value_arg, None)
 
-        # Handle score_mod: extract function reference from submodule
         score_mod_ref = None
         if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node):
-            # score_mod is a GraphModule reference from get_attr
+            assert (
+                score_mod_arg.op == "get_attr"
+            ), f"Expected get_attr for score_mod, got {score_mod_arg.op}"
             root_module = node.graph.owning_module
-            if hasattr(score_mod_arg, "target"):
-                score_mod_name = score_mod_arg.target
-                score_mod_module = getattr(root_module, score_mod_name, None)
-                if score_mod_module is not None:
-                    score_mod_func_name = score_mod_name
-                    score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
-
-        # Handle block_mask: extract mask_mod function and tensor components
-        # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
-        #                           kv_block_size, q_block_size, ..., mask_mod)
-        mask_mod_ref = None
-        block_mask_tensors = []
-        kv_block_size = None
-        q_block_size = None
+            score_mod_module = getattr(root_module, score_mod_arg.target, None)
+            if score_mod_module is not None:
+                score_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                    id(score_mod_module)
+                ]
+                score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
 
+        # Handle block_mask: extract only mask_mod function reference
+        # Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.)
+        # that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx).
+        mask_mod_ref = None
         if block_mask_arg is not None and isinstance(block_mask_arg, tuple):
-            # Parse the block_mask tuple structure
-            # First two entries: kv_num_blocks (int), kv_indices (tensor)
-            # Next two: q_num_blocks (tensor), q_indices (tensor)
-            # Then: scalar dimensions and the mask_mod function at the end
             root_module = node.graph.owning_module
-
-            for i, component in enumerate(block_mask_arg):
-                if isinstance(component, torch_fx.Node):
-                    # Check if it's a tensor or a submodule reference
-                    if component.op == "get_attr" and hasattr(
-                        root_module, component.target
-                    ):
-                        obj = getattr(root_module, component.target)
-                        # Check if it's a GraphModule (mask_mod) or a tensor
-                        if isinstance(obj, GraphModule):
-                            # This is the mask_mod function
-                            mask_mod_func_name = component.target
-                            mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
-                        else:
-                            # It's a tensor (block indices)
-                            block_mask_tensors.append(
-                                self._import_argument(loc, component, None)
-                            )
-                    else:
-                        # Regular tensor argument
-                        block_mask_tensors.append(
-                            self._import_argument(loc, component, None)
-                        )
-                elif isinstance(component, int):
-                    # Scalar dimensions (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
-                    if kv_block_size is None:
-                        kv_block_size = component
-                    elif q_block_size is None:
-                        q_block_size = component
+            # The mask_mod function is the last element in the BlockMask tuple
+            mask_mod_arg = block_mask_arg[-1]
+            if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node):
+                assert (
+                    mask_mod_arg.op == "get_attr"
+                ), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}"
+                mask_mod_module = getattr(root_module, mask_mod_arg.target, None)
+                if mask_mod_module is not None:
+                    mask_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                        id(mask_mod_module)
+                    ]
+                    mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
 
         # Import scale (float or None)
         if scale_arg is None:
@@ -2018,17 +2006,6 @@ def _import_hop_flex_attention(
         else:
             scale = self._import_argument(loc, scale_arg, None)
 
-        # Get enable_gqa from kernel_options if present
-        enable_gqa = False
-        if isinstance(kernel_options, dict) and "enable_gqa" in kernel_options:
-            enable_gqa = kernel_options["enable_gqa"]
-        with loc:
-            enable_gqa_value = _make_constant_op(
-                "torch.constant.bool",
-                self._cc.integer_attr(1 if enable_gqa else 0, 1),
-                self._cc.torch_bool_type,
-            ).result
-
         # Determine result types from node metadata
         node_val = node.meta.get("val")
         if isinstance(node_val, (list, tuple)) and len(node_val) >= 2:
@@ -2039,6 +2016,13 @@ def _import_hop_flex_attention(
             # Single output
             result_types = [self._cc.node_val_to_type(node)]
 
+        with loc:
+            enable_gqa_value = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(1 if enable_gqa else 0, 1),
+                self._cc.torch_bool_type,
+            ).result
+
         with loc:
             return_lse = _make_constant_op(
                 "torch.constant.bool",
@@ -2059,58 +2043,27 @@ def _import_hop_flex_attention(
                 self._cc.torch_bool_type,
             ).result
 
-        # Build operands for aten.flex_attention
-        # Note: score_mod and block_mask function references go as ATTRIBUTES, not operands
-
-        # Handle block_mask: wrap tensors in a list construct if present
-        if block_mask_tensors:
-            # Wrap block_mask tensors in torch.prim.ListConstruct
-            block_mask_list = Operation.create(
-                "torch.prim.ListConstruct",
-                results=[IrType.parse("!torch.list<vtensor>", context=self._c)],
-                operands=block_mask_tensors,
-                loc=loc,
-            ).result
-        else:
-            # No block mask, use None
-            block_mask_list = Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result
+        # Build operands for aten.flex_attention.
+        # Op expects exactly 6 operands: query, key, value, scale, enable_gqa, return_lse.
+        # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands.
+        # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands.
 
         flat_operands = [
             query,
             key,
             value,
-            # score_mod placeholder (None)
-            Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result,
-            # block_mask as single list operand
-            block_mask_list,
             scale,
             enable_gqa_value,
-            # Kernel options as None
-            Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result,
-            # return_lse
             return_lse,
         ]
 
         # Build attributes with function references
-        attributes = {
-            "score_mod_fn": score_mod_ref,
-            "mask_mod_fn": mask_mod_ref,
-            "kv_block_size": self._cc.integer_attr(kv_block_size, 64),
-            "q_block_size": self._cc.integer_attr(q_block_size, 64),
-        }
-        attributes = {k: v for k, v in attributes.items() if v is not None}
+        # Only include attributes if they're not None (OptionalAttr in TableGen)
+        attributes = {}
+        if score_mod_ref is not None:
+            attributes["score_mod_fn"] = score_mod_ref
+        if mask_mod_ref is not None:
+            attributes["mask_mod_fn"] = mask_mod_ref
 
         operation = Operation.create(
             "torch.aten.flex_attention",

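The lookups through self.fx_importer._graph_module_to_func_name above assume the importer records, when it imports each score_mod / mask_mod GraphModule as a private func, which symbol name that module received. A minimal sketch of that bookkeeping pattern (hypothetical names, not the torch-mlir implementation):

    from torch.fx import GraphModule

    class HopSubmoduleRegistry:
        """Maps an imported GraphModule, by object identity, to its func symbol name."""

        def __init__(self) -> None:
            self._graph_module_to_func_name: dict[int, str] = {}

        def record(self, gm: GraphModule, func_name: str) -> None:
            # Keyed by id(): the same GraphModule object later appears behind the
            # HOP's get_attr argument, so an identity lookup is sufficient.
            self._graph_module_to_func_name[id(gm)] = func_name

        def symbol_for(self, gm: GraphModule) -> str:
            return self._graph_module_to_func_name[id(gm)]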
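On dropping the BlockMask tensors from the operand list: a BlockMask is derived data computed from mask_mod, so recording only the mask_mod symbol keeps enough information to rebuild the mask later. Rough sketch assuming PyTorch's create_block_mask API; sliding_window is an illustrative mask_mod:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def sliding_window(b, h, q_idx, kv_idx):
        return (q_idx - kv_idx).abs() <= 64

    # The block-sparse tensors inside BlockMask (kv_num_blocks, kv_indices, ...)
    # are computed from mask_mod at construction time.
    bm = create_block_mask(sliding_window, None, None, 256, 256, device="cpu")

    # The dense mask can be re-materialized by evaluating mask_mod on index grids,
    # which is what makes the mask_mod_fn attribute sufficient downstream.
    q_idx = torch.arange(256).view(-1, 1)
    kv_idx = torch.arange(256).view(1, -1)
    dense = sliding_window(None, None, q_idx, kv_idx)  # [256, 256] boolean mask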