Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit a610202

Browse files
authored
Adds a numpy_array_representer to yaml (#455)
on runtime, to avoid serialization issues
1 parent a8e3379 commit a610202

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

src/sparsezoo/analyze/analysis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt
3131

3232
from sparsezoo import Model
33+
from sparsezoo.analyze.utils.helpers import numpy_array_representer
3334
from sparsezoo.analyze.utils.models import (
3435
DenseSparseOps,
3536
Entry,
@@ -93,6 +94,9 @@
9394
"Gemm",
9495
}
9596

97+
# add numpy array representer to yaml
98+
yaml.add_representer(numpy.ndarray, numpy_array_representer)
99+
96100

97101
class YAMLSerializableBaseModel(BaseModel):
98102
"""

src/sparsezoo/analyze/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
# limitations under the License.
1414

1515
# flake8: noqa
16+
17+
from .helpers import *
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import numpy
17+
import yaml
18+
19+
20+
__all__ = [
21+
"numpy_array_representer",
22+
]
23+
24+
25+
def numpy_array_representer(dumper: yaml.Dumper, data: numpy.ndarray):
26+
"""
27+
A representer for numpy arrays to be used with pyyaml
28+
"""
29+
return dumper.represent_sequence("tag:yaml.org,2002:seq", data.tolist())

0 commit comments

Comments
 (0)