Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 221fbda

Browse files
authored
[deepsparse.analyze] Fix v1 functionality to work with llms (#451) (#452)
* fix equivalent changes made to analyze_v2 such that inference session works for llms; update warnings to be debug printouts * typo
1 parent 5296b2d commit 221fbda

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

src/sparsezoo/analyze/analysis.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections import defaultdict
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import Any, Dict, List, Optional, Union
25+
from typing import Any, Dict, List, Optional, Tuple, Union
2626

2727
import numpy
2828
import yaml
@@ -913,16 +913,20 @@ def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
913913
analyze
914914
:return: instance of cls
915915
"""
916+
path = None
916917
if isinstance(onnx_file_path, ModelProto):
917918
model_onnx = onnx_file_path
918919
model_name = ""
919920
else:
920-
model_onnx = load_model(onnx_file_path)
921+
# initially do not load the external data, if present
922+
# as not required for node analysis
923+
model_onnx = load_model(onnx_file_path, load_external_data=False)
921924
model_name = str(onnx_file_path)
925+
path = onnx_file_path
922926

923-
model_graph = ONNXGraph(model_onnx)
924-
925-
node_analyses = cls.analyze_nodes(model_graph)
927+
# returns the node analysis and the model graph after loading the model with
928+
# external data
929+
node_analyses, model_graph = cls.analyze_nodes(model_onnx, path=path)
926930

927931
layer_counts, op_counts = get_layer_and_op_counts(model_graph)
928932
layer_counts.update(op_counts)
@@ -1361,12 +1365,19 @@ def pretty_print_summary(self):
13611365
print(f"{footer_key}: {footer_value}")
13621366

13631367
@staticmethod
1364-
def analyze_nodes(model_graph: ONNXGraph) -> List[NodeAnalysis]:
1368+
def analyze_nodes(
1369+
model: ModelProto, path: Optional[str] = None
1370+
) -> Tuple[List[NodeAnalysis], ONNXGraph]:
13651371
"""
13661372
:param: model that contains the nodes to be analyzed
1367-
:return: list of node analyses from model graph
1373+
:return: list of node analyses from model graph and ONNXGraph of loaded model
13681374
"""
1369-
node_shapes, node_dtypes = extract_node_shapes_and_dtypes(model_graph.model)
1375+
node_shapes, node_dtypes = extract_node_shapes_and_dtypes(model, path)
1376+
1377+
if path:
1378+
model = load_model(path)
1379+
1380+
model_graph = ONNXGraph(model)
13701381

13711382
nodes = []
13721383
for node in model_graph.nodes:
@@ -1378,7 +1389,7 @@ def analyze_nodes(model_graph: ONNXGraph) -> List[NodeAnalysis]:
13781389
)
13791390
nodes.append(node_analysis)
13801391

1381-
return nodes
1392+
return nodes, model_graph
13821393

13831394

13841395
def _get_param_count_summary(analysis: ModelAnalysis) -> CountSummary:

src/sparsezoo/utils/calculate_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_ops_dict(
114114

115115
if node.op_type in ["Gemm", "MatMul", "MatMulInteger", "QLinearMatMul"]:
116116
if input_shapes is None:
117-
_LOGGER.warn(
117+
_LOGGER.debug(
118118
"Invalid shape, skipping "
119119
f"{'four block ' if is_four_block_sparse else ''}ops calculation"
120120
f" for {node.name}"
@@ -146,7 +146,7 @@ def get_ops_dict(
146146

147147
if node.op_type in ["Conv", "ConvInteger", "QLinearConv"]:
148148
if input_shapes is None:
149-
_LOGGER.warn(
149+
_LOGGER.debug(
150150
"Invalid shape, skipping "
151151
f"{'four block ' if is_four_block_sparse else ''}ops calculation"
152152
f" for {node.name}"

0 commit comments

Comments
 (0)