|
1 | 1 | import logging |
2 | 2 | import re |
3 | 3 | import traceback as tb |
| 4 | +from collections.abc import Iterable |
4 | 5 | from pathlib import Path |
5 | 6 |
|
| 7 | +import pytensor.misc.pkl_utils |
6 | 8 | from pytensor.compile.function.pfunc import pfunc |
7 | 9 | from pytensor.compile.function.types import orig_function |
| 10 | +from pytensor.compile.mode import Mode |
| 11 | +from pytensor.compile.profiling import ProfileStats |
| 12 | +from pytensor.graph import Variable |
8 | 13 |
|
9 | 14 |
|
10 | 15 | __all__ = ["types", "pfunc"] |
|
15 | 20 |
|
16 | 21 | def function_dump( |
17 | 22 | filename: str | Path, |
18 | | - inputs, |
19 | | - outputs=None, |
20 | | - mode=None, |
21 | | - updates=None, |
22 | | - givens=None, |
23 | | - no_default_updates=False, |
24 | | - accept_inplace=False, |
25 | | - name=None, |
26 | | - rebuild_strict=True, |
27 | | - allow_input_downcast=None, |
28 | | - profile=None, |
29 | | - on_unused_input=None, |
| 23 | + inputs: Iterable[Variable], |
| 24 | + outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None, |
| 25 | + mode: str | Mode | None = None, |
| 26 | + updates: Iterable[tuple[Variable, Variable]] |
| 27 | + | dict[Variable, Variable] |
| 28 | + | None = None, |
| 29 | + givens: Iterable[tuple[Variable, Variable]] |
| 30 | + | dict[Variable, Variable] |
| 31 | + | None = None, |
| 32 | + no_default_updates: bool = False, |
| 33 | + accept_inplace: bool = False, |
| 34 | + name: str | None = None, |
| 35 | + rebuild_strict: bool = True, |
| 36 | + allow_input_downcast: bool | None = None, |
| 37 | + profile: bool | ProfileStats | None = None, |
| 38 | + on_unused_input: str | None = None, |
30 | 39 | extra_tag_to_remove: str | None = None, |
31 | 40 | ): |
32 | 41 | """ |
@@ -60,43 +69,44 @@ def function_dump( |
60 | 69 | `['annotations', 'replacement_of', 'aggregation_scheme', 'roles']` |
61 | 70 |
|
62 | 71 | """ |
63 | | - filename = Path(filename) |
64 | | - d = dict( |
65 | | - inputs=inputs, |
66 | | - outputs=outputs, |
67 | | - mode=mode, |
68 | | - updates=updates, |
69 | | - givens=givens, |
70 | | - no_default_updates=no_default_updates, |
71 | | - accept_inplace=accept_inplace, |
72 | | - name=name, |
73 | | - rebuild_strict=rebuild_strict, |
74 | | - allow_input_downcast=allow_input_downcast, |
75 | | - profile=profile, |
76 | | - on_unused_input=on_unused_input, |
77 | | - ) |
78 | | - with filename.open("wb") as f: |
79 | | - import pytensor.misc.pkl_utils |
80 | | - |
| 72 | + d = { |
| 73 | + "inputs": inputs, |
| 74 | + "outputs": outputs, |
| 75 | + "mode": mode, |
| 76 | + "updates": updates, |
| 77 | + "givens": givens, |
| 78 | + "no_default_updates": no_default_updates, |
| 79 | + "accept_inplace": accept_inplace, |
| 80 | + "name": name, |
| 81 | + "rebuild_strict": rebuild_strict, |
| 82 | + "allow_input_downcast": allow_input_downcast, |
| 83 | + "profile": profile, |
| 84 | + "on_unused_input": on_unused_input, |
| 85 | + } |
| 86 | + with Path(filename).open("wb") as f: |
81 | 87 | pickler = pytensor.misc.pkl_utils.StripPickler( |
82 | 88 | f, protocol=-1, extra_tag_to_remove=extra_tag_to_remove |
83 | 89 | ) |
84 | 90 | pickler.dump(d) |
85 | 91 |
|
86 | 92 |
|
87 | 93 | def function( |
88 | | - inputs, |
89 | | - outputs=None, |
90 | | - mode=None, |
91 | | - updates=None, |
92 | | - givens=None, |
93 | | - no_default_updates=False, |
94 | | - accept_inplace=False, |
95 | | - name=None, |
96 | | - rebuild_strict=True, |
97 | | - allow_input_downcast=None, |
98 | | - profile=None, |
99 | | - on_unused_input=None, |
| 94 | + inputs: Iterable[Variable], |
| 95 | + outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None, |
| 96 | + mode: str | Mode | None = None, |
| 97 | + updates: Iterable[tuple[Variable, Variable]] |
| 98 | + | dict[Variable, Variable] |
| 99 | + | None = None, |
| 100 | + givens: Iterable[tuple[Variable, Variable]] |
| 101 | + | dict[Variable, Variable] |
| 102 | + | None = None, |
| 103 | + no_default_updates: bool = False, |
| 104 | + accept_inplace: bool = False, |
| 105 | + name: str | None = None, |
| 106 | + rebuild_strict: bool = True, |
| 107 | + allow_input_downcast: bool | None = None, |
| 108 | + profile: bool | ProfileStats | None = None, |
| 109 | + on_unused_input: str | None = None, |
100 | 110 | ): |
101 | 111 | """ |
102 | 112 | Return a :class:`callable object <pytensor.compile.function.types.Function>` |
|
0 commit comments