33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ """Provide utilities to register and apply TOSA node visitors.
7+
8+ Use this module to construct and serialize TOSA operators from FX nodes.
9+ - Define the NodeVisitor base class and registry
10+ - Register concrete visitors per TOSA specification
11+
12+ """
613
714import json
815from typing import Any , Dict , List , Optional
1825
1926
2027class NodeVisitor :
21- """
22- Node Visitor pattern for lowering edge IR to TOSA
28+ """Provide a visitor pattern to lower edge IR to TOSA.
29+
30+ Attributes:
31+ _exported_program (torch.export.ExportedProgram): Source program being lowered.
32+ tosa_spec (TosaSpecification): Active TOSA specification for lowering.
33+ debug_hook (Optional[DebugHook]): Optional hook for debug metadata.
34+
2335 """
2436
2537 # Add the currently supported node_visitor specs as default.
@@ -51,6 +63,23 @@ def _serialize_operator(
5163 outputs : List [str ],
5264 attributes : Optional [Any ] = None ,
5365 ) -> None :
66+ """Serialize a TOSA operator into the graph.
67+
68+ When a ``DebugHook`` is active, attach location metadata (in JSON) to
69+ the operator for traceability.
70+
71+ Args:
72+ node (torch.fx.Node): Source FX node being lowered.
73+ tosa_graph: Target TOSA serializer/graph object.
74+ tosa_op: TOSA operator enum value to emit.
75+ inputs (List[str]): Names of input tensors.
76+ outputs (List[str]): Names of output tensors.
77+ attributes (Optional[Any]): Optional TOSA attribute object.
78+
79+ Returns:
80+ None: Mutates ``tosa_graph`` in place.
81+
82+ """
5483 op_location = ts .TosaOpLocation ()
5584 if self .debug_hook :
5685 debug_info = self .debug_hook .add (
@@ -77,6 +106,21 @@ def define_node(
77106 inputs : List [TosaArg ],
78107 output : TosaArg ,
79108 ) -> None :
109+ """Define a TOSA operator node.
110+
111+ Args:
112+ node (torch.fx.Node): FX node being lowered.
113+ tosa_graph (serializer.tosa_serializer.TosaSerializer): Target TOSA graph.
114+ inputs (List[TosaArg]): Input tensor arguments.
115+ output (TosaArg): Output tensor descriptor.
116+
117+ Returns:
118+ None: Mutates ``tosa_graph`` in place.
119+
120+ Raises:
121+ ValueError: If input count or dtypes are invalid.
122+
123+ """
80124 raise NotImplementedError ("NodeVisitor must be extended." )
81125
82126
@@ -88,12 +132,14 @@ def define_node(
88132
89133
90134def register_node_visitor (visitor ):
135+ """Register a concrete ``NodeVisitor`` class for its TOSA specs."""
91136 for tosa_spec in visitor .tosa_specs :
92137 _node_visitor_dicts [tosa_spec ][visitor .target ] = visitor
93138 return visitor
94139
95140
96141def get_node_visitors (* args ) -> Dict [str , NodeVisitor ]:
142+ """Return a mapping from target names to visitor instances for a spec."""
97143 node_visitors = {}
98144 tosa_spec = None
99145 for arg in args :
0 commit comments