2222from collections import defaultdict
2323from dataclasses import dataclass
2424from pathlib import Path
25- from typing import Any , Dict , List , Optional , Union
25+ from typing import Any , Dict , List , Optional , Tuple , Union
2626
2727import numpy
2828import yaml
@@ -913,16 +913,20 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
913913 analyze
914914 :return: instance of cls
915915 """
916+ path = None
916917 if isinstance (onnx_file_path , ModelProto ):
917918 model_onnx = onnx_file_path
918919 model_name = ""
919920 else :
920- model_onnx = load_model (onnx_file_path )
921+ # initially do not load the external data, if present
922+ # as not required for node analysis
923+ model_onnx = load_model (onnx_file_path , load_external_data = False )
921924 model_name = str (onnx_file_path )
925+ path = onnx_file_path
922926
923- model_graph = ONNXGraph ( model_onnx )
924-
925- node_analyses = cls .analyze_nodes (model_graph )
927+ # returns the node analysis and the model graph after loading the model with
928+ # external data
929+ node_analyses , model_graph = cls .analyze_nodes (model_onnx , path = path )
926930
927931 layer_counts , op_counts = get_layer_and_op_counts (model_graph )
928932 layer_counts .update (op_counts )
@@ -1361,12 +1365,19 @@ def pretty_print_summary(self):
13611365 print (f"{ footer_key } : { footer_value } " )
13621366
13631367 @staticmethod
1364- def analyze_nodes (model_graph : ONNXGraph ) -> List [NodeAnalysis ]:
1368+ def analyze_nodes (
1369+ model : ModelProto , path : Optional [str ] = None
1370+ ) -> Tuple [List [NodeAnalysis ], ONNXGraph ]:
13651371 """
13661372 :param: model that contains the nodes to be analyzed
1367- :return: list of node analyses from model graph
1373+ :return: list of node analyses from model graph and ONNXGraph of loaded model
13681374 """
1369- node_shapes , node_dtypes = extract_node_shapes_and_dtypes (model_graph .model )
1375+ node_shapes , node_dtypes = extract_node_shapes_and_dtypes (model , path )
1376+
1377+ if path :
1378+ model = load_model (path )
1379+
1380+ model_graph = ONNXGraph (model )
13701381
13711382 nodes = []
13721383 for node in model_graph .nodes :
@@ -1378,7 +1389,7 @@ def analyze_nodes(model_graph: ONNXGraph) -> List[NodeAnalysis]:
13781389 )
13791390 nodes .append (node_analysis )
13801391
1381- return nodes
1392+ return nodes , model_graph
13821393
13831394
13841395def _get_param_count_summary (analysis : ModelAnalysis ) -> CountSummary :
0 commit comments