55unit tests or regression tests.
66"""
77
8- import os
98import pickle
109import sys
11- import tempfile
12- import zipfile
13- from collections import Counter
14- from contextlib import closing
15- from io import BytesIO
16- from pickle import HIGHEST_PROTOCOL
17-
18- import numpy as np
1910
2011import pytensor
2112
2213
23- try :
24- from pickle import DEFAULT_PROTOCOL
25- except ImportError :
26- DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
27-
28- from pytensor .compile .sharedvalue import SharedVariable
29-
30-
3114__docformat__ = "restructuredtext en"
3215__authors__ = "Pascal Lamblin " "PyMC Developers " "PyTensor Developers "
3316__copyright__ = "Copyright 2013, Universite de Montreal"
@@ -49,16 +32,18 @@ class StripPickler(Pickler):
4932
5033 ..code-block:: python
5134
52- fn_args = dict(inputs=inputs,
53- outputs=outputs,
54- updates=updates)
55- dest_pkl = 'my_test.pkl'
56- with open(dest_pkl, 'wb') as f:
35+ fn_args = {
36+ "inputs": inputs,
37+ "outputs": outputs,
38+ "updates": updates,
39+ }
40+ dest_pkl = "my_test.pkl"
41+ with Path(dest_pkl).open("wb") as f:
5742 strip_pickler = StripPickler(f, protocol=-1)
5843 strip_pickler.dump(fn_args)
5944 """
6045
61- def __init__ (self , file , protocol = 0 , extra_tag_to_remove = None ):
46+ def __init__ (self , file , protocol : int = 0 , extra_tag_to_remove : str | None = None ):
6247 # Can't use super as Pickler isn't a new style class
6348 super ().__init__ (file , protocol )
6449 self .tag_to_remove = ["trace" , "test_value" ]
@@ -77,226 +62,3 @@ def save(self, obj):
7762 del obj .__dict__ ["__doc__" ]
7863
7964 return Pickler .save (self , obj )
80-
81-
class PersistentNdarrayID:
    """Persistent-id callable that stores NumPy arrays in a zip archive.

    Instances are meant to be installed as a pickler's ``persistent_id``.
    Each distinct :class:`numpy.ndarray` (distinguished by ``id``) is written
    once, in NPY format, as its own member of the archive, and the returned
    persistent id refers to that member.

    :param zip_file: A zip file handle that the NumPy arrays will be saved to.
    :type zip_file: :class:`zipfile.ZipFile`

    .. note::
        The convention for persistent ids given by this class and its derived
        classes is that the name should take the form `type.name` where `type`
        can be used by the persistent loader to determine how to load the
        object, while `name` is human-readable and as descriptive as possible.

    """

    def __init__(self, zip_file):
        self.zip_file = zip_file
        # Number of automatically generated ``array_<n>`` names handed out.
        self.count = 0
        # Maps ``id(array)`` to the persistent id already issued for it.
        self.seen = {}

    def _resolve_name(self, obj):
        """Return the archive member name for `obj` and advance the counter."""
        index = self.count
        self.count = index + 1
        return f"array_{index}"

    def __call__(self, obj):
        """Return the persistent id for `obj`, or ``None`` if it is not an array."""
        if not isinstance(obj, np.ndarray):
            # ``None`` tells the pickler to serialize the object normally.
            return None

        key = id(obj)
        if key not in self.seen:

            def store(stream):
                np.lib.format.write_array(stream, obj)

            member = self._resolve_name(obj)
            zipadd(store, self.zip_file, member)
            self.seen[key] = f"ndarray.{member}"
        return self.seen[key]
119-
120-
class PersistentSharedVariableID(PersistentNdarrayID):
    """Uses shared variable names when persisting to zip file.

    If a shared variable has a name, this name is used as the name of the
    NPY file inside of the zip file. NumPy arrays that aren't matched to a
    shared variable are persisted as usual (i.e. `array_0`, `array_1`,
    etc.)

    :param zip_file: A zip file handle that the NumPy arrays will be saved to.
    :type zip_file: :class:`zipfile.ZipFile`

    :param allow_unnamed: Allow shared variables without a name to be
        persisted. Defaults to ``True``.
    :type allow_unnamed: bool, optional

    :param allow_duplicates: Allow multiple shared variables to have the same
        name, in which case they will be numbered e.g. `x`, `x_2`, `x_3`, etc.
        Defaults to ``True``.
    :type allow_duplicates: bool, optional

    :raises ValueError:
        If an unnamed shared variable is encountered and `allow_unnamed` is
        ``False``, if two shared variables have the same name and
        `allow_duplicates` is ``False``, or if a shared variable is named
        `pkl`.

    """

    def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True):
        super().__init__(zip_file)
        # How many times each shared-variable name has been used so far.
        self.name_counter = Counter()
        # Maps ``id(array)`` of a shared variable's stored value to its name.
        self.ndarray_names = {}
        self.allow_unnamed = allow_unnamed
        self.allow_duplicates = allow_duplicates

    def _resolve_name(self, obj):
        """Return the archive member name for the array `obj`.

        Arrays registered through a named shared variable get that name
        (suffixed with ``_2``, ``_3``, ... on repeats); all other arrays fall
        back to the parent class's naming.
        """
        if id(obj) not in self.ndarray_names:
            return super()._resolve_name(obj)

        name = self.ndarray_names[id(obj)]
        count = self.name_counter[name]
        self.name_counter[name] += 1
        if count:
            if not self.allow_duplicates:
                raise ValueError(
                    f"multiple shared variables with the name `{name}` found"
                )
            name = f"{name}_{count + 1}"
        return name

    def __call__(self, obj):
        """Record shared-variable names, then defer to the parent callable."""
        if isinstance(obj, SharedVariable):
            if obj.name:
                if obj.name == "pkl":
                    # The exception used to be constructed but never raised,
                    # which silently accepted a name that is documented as
                    # unpicklable.
                    raise ValueError("can't pickle shared variable with name `pkl`")
                # Remember the name under the id of the stored value so that
                # `_resolve_name` can find it when that array is persisted.
                self.ndarray_names[id(obj.container.storage[0])] = obj.name
            elif not self.allow_unnamed:
                raise ValueError(f"unnamed shared variable, {obj}")
        return super().__call__(obj)
175-
176-
class PersistentNdarrayLoad:
    """Load NumPy arrays that were persisted to a zip file when pickling.

    Instances are meant to be installed as an unpickler's
    ``persistent_load``. Loaded arrays are cached per member name, so a
    persistent id that occurs several times yields the same array object.

    :param zip_file: The zip file handle in which the NumPy arrays are saved.
    :type zip_file: :class:`zipfile.ZipFile`

    """

    def __init__(self, zip_file):
        self.zip_file = zip_file
        # Maps archive member name to the array already read from it.
        self.cache = {}

    def __call__(self, persid):
        """Return the array referenced by the persistent id `persid`.

        :param persid: Persistent id of the form ``type.name``.
        :type persid: str

        :raises ValueError: If `persid` contains no ``.`` separator.
        """
        # Split on the first dot only: the member name may itself contain
        # dots (e.g. a shared variable called ``layer.W``), which a plain
        # ``split(".")`` would reject with an unpacking error.
        # The type prefix was used for switching gpu/cpu arrays and is
        # currently ignored; subclasses would be the proper place for it.
        _array_type, name = persid.split(".", 1)

        if name in self.cache:
            return self.cache[name]

        with self.zip_file.open(name) as f:
            array = np.lib.format.read_array(f)
        self.cache[name] = array
        return array
202-
203-
def dump(
    obj,
    file_handler,
    protocol=DEFAULT_PROTOCOL,
    persistent_id=PersistentSharedVariableID,
):
    """Pickle `obj` into a zip archive, storing selected objects externally.

    The pickle stream itself is written to the archive member `pkl`; objects
    claimed by `persistent_id` are written to their own members.

    :param obj: The object to pickle.
    :type obj: object

    :param file_handler: The file handle to save the object to.
    :type file_handler: file

    :param protocol: The pickling protocol to use. Defaults to this module's
        ``DEFAULT_PROTOCOL``.
    :type protocol: int, optional

    :param persistent_id: A callable that is given the open zip file and
        returns the ``persistent_id`` hook for the pickler. For example,
        :class:`PersistentNdarrayID` saves any :class:`numpy.ndarray` to a
        separate NPY file inside of the zip file.
    :type persistent_id: callable

    .. versionadded:: 0.8

    .. note::
        The final file is simply a zipped file containing at least one file,
        `pkl`, which contains the pickled object. It can contain any other
        number of external objects. Note that the zip files are compatible with
        NumPy's :func:`numpy.load` function.

    >>> import pytensor
    >>> foo_1 = pytensor.shared(0, name='foo')
    >>> foo_2 = pytensor.shared(1, name='foo')
    >>> with open('model.zip', 'wb') as f:
    ...     dump((foo_1, foo_2, np.array(2)), f)
    >>> list(np.load('model.zip').keys())
    ['foo', 'foo_2', 'array_0', 'pkl']
    >>> np.load('model.zip')['foo']
    array(0)
    >>> with open('model.zip', 'rb') as f:
    ...     foo_1, foo_2, array = load(f)
    >>> array
    array(2)

    """
    with closing(
        zipfile.ZipFile(file_handler, "w", zipfile.ZIP_DEFLATED, allowZip64=True)
    ) as archive:

        def write_pickle(stream):
            # The persistent-id hook writes into the same archive while the
            # main pickle stream is being produced.
            pickler = pickle.Pickler(stream, protocol=protocol)
            pickler.persistent_id = persistent_id(archive)
            pickler.dump(obj)

        zipadd(write_pickle, archive, "pkl")
262-
263-
def load(f, persistent_load=PersistentNdarrayLoad):
    """Load a file that was dumped to a zip file.

    The pickle stream is read from the archive member `pkl`; objects that
    were stored externally are resolved through `persistent_load`.

    :param f: The file handle to the zip file to load the object from.
    :type f: file

    :param persistent_load: A callable that is given the open zip file and
        returns the ``persistent_load`` hook for the unpickler. This must be
        compatible with the `persistent_id` function used when pickling.
    :type persistent_load: callable, optional

    :returns: The unpickled object.

    .. versionadded:: 0.8
    """
    with closing(zipfile.ZipFile(f, "r")) as zip_file:
        # Read the whole pickle stream up front and close the member handle
        # right away; previously the handle returned by ``open`` was never
        # closed.
        with zip_file.open("pkl") as pkl_member:
            payload = BytesIO(pkl_member.read())

        # NOTE(security): unpickling can execute arbitrary code. Only load
        # archives that come from a trusted source.
        p = pickle.Unpickler(payload)
        p.persistent_load = persistent_load(zip_file)
        return p.load()
281-
282-
def zipadd(func, zip_file, name):
    """Calls a function with a file object, saving it to a zip file.

    `func` writes to a temporary file on disk, which is then added to the
    archive under `name`. The temporary file is removed afterwards, including
    when `func` or the archive write raises.

    :param func: The function to call. It receives a binary file object
        opened for writing.
    :type func: callable

    :param zip_file: The zip file that `func` should write its data to.
    :type zip_file: :class:`zipfile.ZipFile`

    :param name: The name of the file inside of the zipped archive that `func`
        should save its data to.
    :type name: str

    """
    # ``delete=False`` so the file survives being closed and can be read back
    # by name when it is added to the archive.
    temp_file = tempfile.NamedTemporaryFile("wb", delete=False)
    try:
        # Leaving the ``with`` block closes (and thereby flushes) the file
        # before its content is copied into the archive.
        with temp_file:
            func(temp_file)
        zip_file.write(temp_file.name, arcname=name)
    finally:
        # Clean up on every path; previously an exception leaked the file.
        if os.path.isfile(temp_file.name):
            os.remove(temp_file.name)
0 commit comments