Skip to content

Commit 208695c

Browse files
Arm backend: Add docstrings for operators/node_visitor.py (#15693)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 9c7cb61 commit 208695c

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

backends/arm/operators/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Import and register Arm TOSA operator visitors.
6+
7+
Importing this package loads all visitor modules so their classes can be
8+
registered via decorators and discovered at runtime.
9+
10+
"""
511

612

713
from . import ( # noqa

backends/arm/operators/node_visitor.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
"""Provide utilities to register and apply TOSA node visitors.
7+
8+
Use this module to construct and serialize TOSA operators from FX nodes.
9+
- Define the NodeVisitor base class and registry
10+
- Register concrete visitors per TOSA specification
11+
12+
"""
613

714
import json
815
from typing import Any, Dict, List, Optional
@@ -18,8 +25,13 @@
1825

1926

2027
class NodeVisitor:
21-
"""
22-
Node Visitor pattern for lowering edge IR to TOSA
28+
"""Provide a visitor pattern to lower edge IR to TOSA.
29+
30+
Attributes:
31+
_exported_program (torch.export.ExportedProgram): Source program being lowered.
32+
tosa_spec (TosaSpecification): Active TOSA specification for lowering.
33+
debug_hook (Optional[DebugHook]): Optional hook for debug metadata.
34+
2335
"""
2436

2537
# Add the currently supported node_visitor specs as default.
@@ -51,6 +63,23 @@ def _serialize_operator(
5163
outputs: List[str],
5264
attributes: Optional[Any] = None,
5365
) -> None:
66+
"""Serialize a TOSA operator into the graph.
67+
68+
When a ``DebugHook`` is active, attach location metadata (in JSON) to
69+
the operator for traceability.
70+
71+
Args:
72+
node (torch.fx.Node): Source FX node being lowered.
73+
tosa_graph: Target TOSA serializer/graph object.
74+
tosa_op: TOSA operator enum value to emit.
75+
inputs (List[str]): Names of input tensors.
76+
outputs (List[str]): Names of output tensors.
77+
attributes (Optional[Any]): Optional TOSA attribute object.
78+
79+
Returns:
80+
None: Mutates ``tosa_graph`` in place.
81+
82+
"""
5483
op_location = ts.TosaOpLocation()
5584
if self.debug_hook:
5685
debug_info = self.debug_hook.add(
@@ -77,6 +106,21 @@ def define_node(
77106
inputs: List[TosaArg],
78107
output: TosaArg,
79108
) -> None:
109+
"""Define a TOSA operator node.
110+
111+
Args:
112+
node (torch.fx.Node): FX node being lowered.
113+
tosa_graph (serializer.tosa_serializer.TosaSerializer): Target TOSA graph.
114+
inputs (List[TosaArg]): Input tensor arguments.
115+
output (TosaArg): Output tensor descriptor.
116+
117+
Returns:
118+
None: Mutates ``tosa_graph`` in place.
119+
120+
Raises:
121+
ValueError: If input count or dtypes are invalid.
122+
123+
"""
80124
raise NotImplementedError("NodeVisitor must be extended.")
81125

82126

@@ -88,12 +132,14 @@ def define_node(
88132

89133

90134
def register_node_visitor(visitor):
135+
"""Register a concrete ``NodeVisitor`` class for its TOSA specs."""
91136
for tosa_spec in visitor.tosa_specs:
92137
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
93138
return visitor
94139

95140

96141
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
142+
"""Return a mapping from target names to visitor instances for a spec."""
97143
node_visitors = {}
98144
tosa_spec = None
99145
for arg in args:

0 commit comments

Comments
 (0)