@@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto):
276276 with InsertionPoint (self ._b ), Location .name (node .name ):
277277 op_type = node .op_type
278278 # Handle special op types that materialize to non-op IR constructs.
279+ # Handlers return True if the op was handled, else this function
280+ # should process it as a general node.
279281 special_key = f"_handle_node_{ op_type } "
280282 if hasattr (self , special_key ):
281- getattr (self , special_key )(node )
282- return
283+ was_handled = getattr (self , special_key )(node )
284+ if was_handled :
285+ return
283286
284287 # General node import.
285288 input_values = []
@@ -333,16 +336,19 @@ def import_attributes(
333336 )
334337 attrs [f"torch.onnx.{ onnx_attr .name } " ] = handler (onnx_attr , self ._cc )
335338
336- def import_initializer (self , initializer : onnx .TensorProto ) -> Value :
337- with InsertionPoint (self ._b ), Location .name (initializer .name ):
339+ def import_initializer (self , initializer : onnx .TensorProto , extern_name : str = None ) -> Value :
340+ # If an explicitly specified name is given, use that; otherwise, pick
341+ # up the name from the tensor proto itself
342+ iname = extern_name if extern_name else initializer .name
343+ with InsertionPoint (self ._b ), Location .name (iname ):
338344 value_attr = self ._cc .tensor_proto_to_attr (initializer )
339345 vtensor_type = self ._cc .tensor_proto_to_type (initializer )
340346 literal_op = Operation .create (
341347 name = "torch.vtensor.literal" ,
342348 results = [vtensor_type ],
343349 attributes = {"value" : value_attr },
344350 )
345- self ._nv_map [initializer . name ] = literal_op .result
351+ self ._nv_map [iname ] = literal_op .result
346352 return literal_op .result
347353
348354 def _get_immediate_tensor (self , name : str ) -> np .array :
@@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array:
366372 f"Unhandled ONNX TensorProto immediate data: { initializer } "
367373 )
368374
369- def _handle_node_ConstantOfShape (self , node : onnx .NodeProto ):
375+ def _handle_node_Constant (self , node : onnx .NodeProto ) -> bool :
376+ # Special case only for constants specified by value attribute (for now)
377+ value_proto = _get_attr (node , "value" , False )
378+ if not value_proto :
379+ return False
380+
381+ # Produce an initializer for the constant, so that it can be used in
382+ # combination with other ops, such as ConstantOfShape, requiring
383+ # a constant input
384+ assert value_proto .type == onnx .AttributeProto .AttributeType .TENSOR
385+ assert len (node .output ) == 1
386+ const_name = node .output [0 ]
387+ self .import_initializer (value_proto .t , const_name )
388+ self ._gi .initializer_map [const_name ] = value_proto .t
389+ return True
390+
391+ def _handle_node_ConstantOfShape (self , node : onnx .NodeProto ) -> bool :
370392 # This op is special: It has an input of the shape, and in full generality
371393 # could involve eager production of constants of variable size. In
372394 # practice, the DNN profile for ONNX makes this very difficult to do
@@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
394416 attributes = {"value" : value_attr },
395417 )
396418 self ._nv_map [node .output [0 ]] = literal_op .result
419+ return True
397420
398421
399422class ContextCache :
@@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
515538 onnx .TensorProto .DataType .FLOAT : lambda tp , shape : DenseElementsAttr .get_splat (
516539 RankedTensorType .get (shape , F32Type .get ()), FloatAttr .get_f32 (tp .float_data [0 ])
517540 ),
541+ onnx .TensorProto .DataType .INT64 : lambda tp , shape : DenseElementsAttr .get_splat (
542+ RankedTensorType .get (shape , IntegerType .get_signed (64 )), IntegerAttr .get (
543+ IntegerType .get_signed (64 ), int .from_bytes (tp .raw_data , "little" ,
544+ signed = True ) if tp .HasField ("raw_data" ) else tp .int64_data [0 ])
545+ ),
518546 # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
519547}
520548
@@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
605633}
606634
607635
608- def _get_attr (node : onnx .NodeProto , attr_name : str ) -> onnx .AttributeProto :
636+ def _get_attr (node : onnx .NodeProto , attr_name : str , is_required : bool = True ) -> onnx .AttributeProto :
609637 for attr in node .attribute :
610638 if attr .name == attr_name :
611639 return attr
612- else :
640+ if is_required :
613641 raise OnnxImportError (f"Required attribute { attr_name } not found in { node } " )
642+ return None
0 commit comments