
Commit a769114

blaine-rister authored and pytorchmergebot committed
[Inductor] FX backend via Wrapper IR (pytorch#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.
- pytorch#150458
- pytorch#152391
- pytorch#152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends, which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc. It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third-party system than Python or C++ source. Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:
- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
  - Handled the following cases, in addition to what was already in the Memory Planning IR:
    - Comments
    - Triton kernels
    - Extern/fallback kernels
    - Freeing tensors (`del buf0`)
    - MultiOutput
    - Graph outputs
    - ReinterpretView / StorageBox, for both call args and outputs.
- FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases:
  - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
  - Calling wrapped Triton kernels.
  - Calling extern kernels and certain types of fallback kernels.
    - Support both `extern_kernels.*` and `aten.*`.
    - Support multi-output kernels like `torch.topk`.
  - Graphs with multiple inputs/outputs.
  - Training, i.e. calling `Tensor.backward()` in a compiled function.
  - Graph breaks (training).
  - Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct, but probably not rich enough to achieve the goals outlined in the previous sections.
- Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
- Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
- CommBuffers and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out-of-tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
    def compile_graph(self, gm):
        # Add 1 to the graph's outputs
        def compiled_fn(*args):
            return [x + 1 for x in gm.graph.forward(*args)]
        return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```
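One practical note: `kernel_idx` and `constant_args_idx` in the call above are indices into Dynamo's Triton kernel side table rather than ordinary FX arguments. As a rough, hypothetical sketch (not part of this PR), a consumer of the graph could resolve them with the side-table API from `torch._higher_order_ops.triton_kernel_wrap`; `resolve_triton_call` is an illustrative name, and `node` is assumed to be the `triton_kernel_wrapper_mutation` node shown above.

```
# Illustrative sketch (not from this PR): map a triton_kernel_wrapper_mutation
# node back to a runnable Triton kernel via Dynamo's kernel side table.
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

def resolve_triton_call(node):
    kwargs = node.kwargs
    kernel = kernel_side_table.get_kernel(kwargs["kernel_idx"])
    constant_args = kernel_side_table.get_constant_args(kwargs["constant_args_idx"])
    # Tensor/scalar arguments stay in the FX graph under kwargs["kwargs"]; the
    # non-graphable arguments come back from the side table.
    return kernel, {**kwargs["kwargs"], **constant_args}, kwargs["grid"]
```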
Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]`, where `node` is the placeholder defining it. This works out in the generated Python code because the placeholder defines a Python variable with the name `s6`.

```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```
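As a hedged illustration (not part of this PR), a backend consuming this graph could read the size hint off the `s6` placeholder and evaluate the grid expression itself. `grid_x_for` and its arguments are hypothetical names, and `-((-n) // 8)` is simply `ceil(n / 8)` written with Python floor division.

```
import torch.fx

def grid_x_for(gm: torch.fx.GraphModule, sym_name: str = "s6", block: int = 8) -> int:
    # Hypothetical helper: find the placeholder defining the shape symbol and use
    # its size hint, which the description above says lives in node.meta["val"].
    for node in gm.graph.nodes:
        if node.op == "placeholder" and node.target == sym_name:
            size_hint = int(node.meta["val"])   # assumed to be an int-like hint
            return -((-size_hint) // block)     # ceil(size_hint / block)
    raise KeyError(f"no placeholder named {sym_name}")
```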
Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.

```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: pytorch#146942
Approved by: https://github.com/jansel
1 parent fdadda2 commit a769114

File tree

9 files changed: +1251 -49 lines


test/inductor/test_fxir_backend.py

Lines changed: 417 additions & 0 deletions
Large diffs are not rendered by default.

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 22 additions & 8 deletions
```
@@ -1750,25 +1750,39 @@ def check_grid(
         # normalize to tuple
         return tuple(grid)
 
-    def call_HOP(
+    def store_non_graphable_args(
         self,
-        variable: "TraceableTritonKernelWrapper",
-        grids: list["TritonGridTupleType"],
         combined_args: dict[str, Any],
-        tx: None,
-    ) -> None:
-        assert tx is None
-        assert isinstance(variable, TraceableTritonKernelWrapper)
+    ) -> tuple[dict, int]:
+        """
+        Some args cannot be stored in the FX graph.
+        Put them in the side table.
+        """
 
         def is_graphable(val: Any) -> bool:
-            return isinstance(val, fx.node.base_types)
+            return isinstance(val, (fx.node.base_types, fx.Node))
 
         non_graphable_args = {
             k: v for k, v in combined_args.items() if not is_graphable(v)
         }
         graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
 
         constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args)
+
+        return graphable_args, constant_args_idx
+
+    def call_HOP(
+        self,
+        variable: "TraceableTritonKernelWrapper",
+        grids: list["TritonGridTupleType"],
+        combined_args: dict[str, Any],
+        tx: None,
+    ) -> None:
+        assert tx is None
+        assert isinstance(variable, TraceableTritonKernelWrapper)
+
+        graphable_args, constant_args_idx = self.store_non_graphable_args(combined_args)
+
         assert isinstance(variable.kernel_idx, int)
         return triton_kernel_wrapper_mutation(
             kernel_idx=variable.kernel_idx,
```
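To make the graphable/non-graphable split above concrete, here is a standalone illustration (the argument dict is made up, not taken from the PR): values whose types are in `fx.node.base_types`, or which are already `fx.Node`s, can pass directly through the FX graph, while everything else is stashed in the kernel side table and referenced by `constant_args_idx`.

```
# Standalone illustration (made-up args) of the graphable vs. non-graphable split
# performed by store_non_graphable_args above.
import torch
from torch import fx

def is_graphable(val) -> bool:
    # Same predicate as the new code above: FX base types, plus fx.Node.
    return isinstance(val, (fx.node.base_types, fx.Node))

combined_args = {
    "in_ptr0": torch.randn(8),        # tensor -> graphable
    "xnumel": 8,                      # int -> graphable
    "launch_meta": {"num_warps": 4},  # dict -> not graphable, goes to the side table
}
graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)}
non_graphable_args = {k: v for k, v in combined_args.items() if not is_graphable(v)}
```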

torch/_inductor/codegen/common.py

Lines changed: 59 additions & 1 deletion
```
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import atexit
 import contextlib
 import dataclasses
 import enum
@@ -8,8 +9,11 @@
 import logging
 import math
 import operator
+import os
 import re
+import tempfile
 import typing
+from abc import ABC, abstractmethod
 from enum import auto, Enum
 from itertools import chain
 from typing import (
@@ -60,6 +64,8 @@
 if TYPE_CHECKING:
     from collections.abc import Iterator, MutableMapping, Sequence
 
+    from torch.fx import GraphModule
+
     from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
@@ -83,6 +89,38 @@ def data_type_logger(msg: str) -> None:
     schedule_log.debug("Data type propagation: %s", msg)
 
 
+@dataclasses.dataclass
+class FileBackedGraphModule:
+    """
+    Output of FX wrapper codegen. Exposes the same methods as ModuleType, but these
+    map back to a GraphModule instead of Python source.
+    """
+
+    gm: GraphModule
+    compiled_fn: Callable[..., Any]
+
+    def __post_init__(self) -> None:
+        # Write the code to a file for compatibility with debugging utilities.
+        # The file is deleted upon program termination.
+        self.tempfile = tempfile.NamedTemporaryFile(
+            mode="w+", suffix=".py", delete=False
+        )
+        atexit.register(os.remove, self.tempfile.name)
+        with self.tempfile as f:
+            f.write(self.value)
+
+    @property
+    def __file__(self) -> str:
+        return self.tempfile.name
+
+    def call(self, args: list[Any]) -> Any:
+        return self.compiled_fn(*args)
+
+    @property
+    def value(self) -> str:
+        return self.gm.code
+
+
 class WorkspaceZeroMode(enum.Enum):
     UNINITIALIZED = 0
     ZERO_ON_CALL = 1  # kernel may leave workspace dirty
@@ -103,8 +141,22 @@ def from_bool(zero_fill: bool) -> WorkspaceZeroMode:
         return WorkspaceZeroMode.UNINITIALIZED
 
 
+class CodegenSymbol(ABC):
+    """
+    An IR object possibly corresponding to a variable in the wrapper code.
+    """
+
+    @abstractmethod
+    def get_name(self) -> str:
+        pass
+
+    @abstractmethod
+    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
+        pass
+
+
 @ir_dataclass(frozen=True)
-class WorkspaceArg:
+class WorkspaceArg(CodegenSymbol):
     """A temporary buffer used for a single kernel, then discarded.
 
     Not registered as a traditional buffer since there are no users,
@@ -167,6 +219,9 @@ def get_device(self) -> torch.device:
     def get_dtype(self) -> torch.dtype:
         return self.dtype
 
+    def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
+        return self.get_layout().get_example()
+
     def get_layout(self) -> FixedLayout:
         from ..ir import FixedLayout
 
@@ -185,6 +240,9 @@ def layout(self) -> FixedLayout:
     maybe_get_output_spec = get_layout
     maybe_get_layout = get_layout
 
+    def get_offset(self) -> sympy.Expr:
+        return sympy.S.Zero
+
     def get_size(self) -> list[sympy.Expr]:
         return [self.count]
 
```
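For orientation, here is a minimal usage sketch of the new `FileBackedGraphModule` (the traced function is made up; only the fields and methods added above are assumed): it pairs a `GraphModule` with a compiled callable, writes `gm.code` to a temporary `.py` file so debugging utilities have a `__file__` to point at, and dispatches `call(args)` to the callable.

```
# Minimal usage sketch of FileBackedGraphModule as added above; the traced
# function is made up for illustration.
import torch
import torch.fx
from torch._inductor.codegen.common import FileBackedGraphModule

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)   # stand-in for Inductor's FX output
fbgm = FileBackedGraphModule(gm=gm, compiled_fn=gm.forward)

print(fbgm.__file__)              # temporary .py file containing gm.code
out = fbgm.call([torch.ones(4)])  # forwards to compiled_fn(*args)
```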

torch/_inductor/codegen/wrapper.py

Lines changed: 71 additions & 2 deletions
```
@@ -74,6 +74,7 @@
     import triton
 
     from ..graph import GraphLowering
+    from .wrapper_fxir import FxConverter
 
 
 log = logging.getLogger(__name__)
@@ -83,6 +84,7 @@
 
 ReuseKey = tuple[torch.device, torch.dtype, str, bool]
 BufferLike = Union[ir.Buffer, WorkspaceArg]
+FxConversionFunc = Callable[["WrapperLine"], None]
 
 
 def buffer_reuse_key(node: BufferLike) -> ReuseKey:
@@ -349,7 +351,8 @@ def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None:
 
 
 class WrapperLine:
-    pass
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        raise NotImplementedError("FX codegen not yet supported for type {type(self)}")
 
 
 @dataclasses.dataclass
@@ -364,6 +367,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.push_codegened_graph(self.graph)
         code.do_indent()
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_enter_subgraph
+
 
 @dataclasses.dataclass
 class CommentLine(WrapperLine):
@@ -372,6 +378,10 @@ class CommentLine(WrapperLine):
     def codegen(self, code: IndentedBuffer) -> None:
         code.writeline(self.line)
 
+    @staticmethod
+    def codegen_fx(converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comment
+
 
 @dataclasses.dataclass
 class ExitSubgraphLine(WrapperLine):
@@ -384,6 +394,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         self.wrapper.pop_codegened_graph()
         code.do_unindent()
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_exit_subgraph
+
 
 @dataclasses.dataclass
 class EnterDeviceContextManagerLine(WrapperLine):
@@ -419,12 +432,18 @@ def codegen(self, code: IndentedBuffer) -> None:
             code.do_indent()
             code.writeline(V.graph.device_ops.set_device(self.device_idx))
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_enter_device_context_manager
+
 
 class ExitDeviceContextManagerLine(WrapperLine):
     def codegen(self, code: IndentedBuffer) -> None:
         if not V.graph.cpp_wrapper:
             code.do_unindent()
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_exit_device_context_manager
+
 
 @dataclasses.dataclass
 class ExternKernelAllocLine(WrapperLine):
@@ -436,6 +455,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         args = [*node.codegen_args(), *node.codegen_kwargs()]
         self.wrapper._generate_extern_kernel_alloc_helper(self.node, args)
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_extern_kernel_alloc
+
 
 @dataclasses.dataclass
 class ExternKernelOutLine(WrapperLine):
@@ -466,6 +488,9 @@ def codegen(self, code: IndentedBuffer) -> None:
             device,
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_extern_kernel_out
+
 
 @dataclasses.dataclass
 class FreeLine(WrapperLine):
@@ -476,6 +501,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(self.wrapper.make_buffer_free(self.node))
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_free
+
 
 @dataclasses.dataclass
 class KernelCallLine(WrapperLine):
@@ -505,6 +533,9 @@ def codegen(self, code: IndentedBuffer) -> None:
             original_fxnode_name=self.original_fxnode_name,
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_kernel_call
+
 
 @dataclasses.dataclass
 class KernelDefinitionLine(WrapperLine):
@@ -524,6 +555,9 @@ def codegen(self, code: IndentedBuffer) -> None:
             cpp_definition=self.cpp_definition,
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_kernel_definition
+
 
 @dataclasses.dataclass
 class MemoryPlanningLine(WrapperLine):
@@ -580,6 +614,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         line = self.wrapper.make_buffer_allocation(self.node)
         code.writeline(line)
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_allocate
+
 
 @dataclasses.dataclass
 class FreeIfNotReusedLine(MemoryPlanningLine):
@@ -603,6 +640,9 @@ def codegen(self, code: IndentedBuffer) -> None:
         if not self.is_reused:
             code.writeline(self.wrapper.make_buffer_free(self.node))
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_free_if_not_reused
+
 
 @dataclasses.dataclass
 class ReinterpretLine(MemoryPlanningLine):
@@ -620,6 +660,9 @@ def codegen(self, code: IndentedBuffer) -> None:
             self.reused_as.get_name(), self.layout.view
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_reinterpret
+
 
 @dataclasses.dataclass
 class ReuseLine(MemoryPlanningLine):
@@ -641,9 +684,13 @@ def codegen(self, code: IndentedBuffer) -> None:
             self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old)
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_reuse
+
 
 class NullLine(MemoryPlanningLine):
-    pass
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_null
 
 
 @dataclasses.dataclass
@@ -717,13 +764,19 @@ def make_allocation_line(
                 f"Unsupported comm buffer type: {comm_buffer_type}"
             )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comm_buffer_allocate
+
 
 @dataclasses.dataclass
 class CommBufferFreeLine(CommBufferLine):
     def codegen(self, code: IndentedBuffer) -> None:
         line = self.wrapper.make_buffer_free(self.node)
         code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free")
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_comm_buffer_free
+
 
 @dataclasses.dataclass
 class MultiOutputLine(WrapperLine):
@@ -760,6 +813,22 @@ def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def
             f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}"
         )
 
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_multi_output
+
+
+@dataclasses.dataclass
+class SymbolicCallArgLine(WrapperLine):
+    wrapper: PythonWrapperCodegen
+    arg: SymbolicCallArg
+    graph: GraphLowering
+
+    def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph)
+
+    def codegen_fx(self, converter: FxConverter) -> FxConversionFunc:
+        return converter._generate_symbolic_call_arg
+
 
 @dataclasses.dataclass
 class SymbolicCallArgLine(WrapperLine):
```
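Taken together, these `codegen_fx` hooks implement a small double dispatch: each `WrapperLine` hands back the `FxConverter` method that knows how to lower it, and the converter then applies that method to the line. A self-contained toy version of the pattern (stand-in classes, not the real `WrapperLine`/`FxConverter`) might look like this:

```
# Toy illustration of the codegen_fx double-dispatch pattern used above
# (stand-in classes, not the real WrapperLine/FxConverter).
from dataclasses import dataclass

class ToyConverter:
    def _generate_comment(self, line: "ToyCommentLine") -> None:
        print(f"# {line.text}")

    def _generate_free(self, line: "ToyFreeLine") -> None:
        print(f"del {line.buffer_name}")

    def generate(self, lines) -> None:
        for line in lines:
            # Each line picks the converter method that lowers it, mirroring
            # WrapperLine.codegen_fx(converter) in the diff.
            line.codegen_fx(self)(line)

@dataclass
class ToyCommentLine:
    text: str
    def codegen_fx(self, converter: ToyConverter):
        return converter._generate_comment

@dataclass
class ToyFreeLine:
    buffer_name: str
    def codegen_fx(self, converter: ToyConverter):
        return converter._generate_free

ToyConverter().generate([ToyCommentLine("allocate outputs"), ToyFreeLine("buf0")])
```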
