diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index b3d8f5676d5..6fd51497521 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Import and register Arm TOSA operator visitors. + +Importing this package loads all visitor modules so their classes can be +registered via decorators and discovered at runtime. + +""" from . import ( # noqa diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 682c849fe80..c03c27574b8 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -3,6 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide utilities to register and apply TOSA node visitors. + +Use this module to construct and serialize TOSA operators from FX nodes. +- Define the NodeVisitor base class and registry +- Register concrete visitors per TOSA specification + +""" import json from typing import Any, Dict, List, Optional @@ -18,8 +25,13 @@ class NodeVisitor: - """ - Node Visitor pattern for lowering edge IR to TOSA + """Provide a visitor pattern to lower edge IR to TOSA. + + Attributes: + _exported_program (torch.export.ExportedProgram): Source program being lowered. + tosa_spec (TosaSpecification): Active TOSA specification for lowering. + debug_hook (Optional[DebugHook]): Optional hook for debug metadata. + """ # Add the currently supported node_visitor specs as default. @@ -51,6 +63,23 @@ def _serialize_operator( outputs: List[str], attributes: Optional[Any] = None, ) -> None: + """Serialize a TOSA operator into the graph. + + When a ``DebugHook`` is active, attach location metadata (in JSON) to + the operator for traceability. + + Args: + node (torch.fx.Node): Source FX node being lowered. + tosa_graph: Target TOSA serializer/graph object. + tosa_op: TOSA operator enum value to emit. + inputs (List[str]): Names of input tensors. + outputs (List[str]): Names of output tensors. + attributes (Optional[Any]): Optional TOSA attribute object. + + Returns: + None: Mutates ``tosa_graph`` in place. + + """ op_location = ts.TosaOpLocation() if self.debug_hook: debug_info = self.debug_hook.add( @@ -77,6 +106,21 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + """Define a TOSA operator node. + + Args: + node (torch.fx.Node): FX node being lowered. + tosa_graph (serializer.tosa_serializer.TosaSerializer): Target TOSA graph. + inputs (List[TosaArg]): Input tensor arguments. + output (TosaArg): Output tensor descriptor. + + Returns: + None: Mutates ``tosa_graph`` in place. + + Raises: + ValueError: If input count or dtypes are invalid. + + """ raise NotImplementedError("NodeVisitor must be extended.") @@ -88,12 +132,14 @@ def define_node( def register_node_visitor(visitor): + """Register a concrete ``NodeVisitor`` class for its TOSA specs.""" for tosa_spec in visitor.tosa_specs: _node_visitor_dicts[tosa_spec][visitor.target] = visitor return visitor def get_node_visitors(*args) -> Dict[str, NodeVisitor]: + """Return a mapping from target names to visitor instances for a spec.""" node_visitors = {} tosa_spec = None for arg in args: