22import pymc as pm
33import pytensor .tensor as pt
44import pytest
5+ from pytensor import config , shared
56from pytensor .graph import Constant , FunctionGraph , node_rewriter
67from pytensor .graph .rewriting .basic import in2out
78from pytensor .tensor .exceptions import NotScalarConstantError
1314 ModelObservedRV ,
1415 ModelPotential ,
1516 ModelVar ,
17+ clone_model ,
1618 fgraph_from_model ,
1719 model_deterministic ,
1820 model_free_rv ,
@@ -76,17 +78,22 @@ def test_basic():
7678 )
7779
7880
81+ def same_storage (shared_1 , shared_2 ) -> bool :
82+ """Check if two shared variables have the same storage containers (i.e., they point to the same memory)."""
83+ return shared_1 .container .storage is shared_2 .container .storage
84+
85+
7986@pytest .mark .parametrize ("inline_views" , (False , True ))
8087def test_data (inline_views ):
81- """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
88+ """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.
8289
83- Everything should be preserved across new and old models, except for shared RNGs
90+ All model-related shared variables should be copied to become independent across models.
8491 """
8592 with pm .Model (coords_mutable = {"test_dim" : range (3 )}) as m_old :
8693 x = pm .MutableData ("x" , [0.0 , 1.0 , 2.0 ], dims = ("test_dim" ,))
8794 y = pm .MutableData ("y" , [10.0 , 11.0 , 12.0 ], dims = ("test_dim" ,))
88- b0 = pm .ConstantData ("b0" , np .zeros (3 ))
89- b1 = pm .Normal ("b1" )
95+ b0 = pm .ConstantData ("b0" , np .zeros (( 1 ,) ))
96+ b1 = pm .DiracDelta ("b1" , 1.0 )
9097 mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
9198 obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
9299
@@ -109,22 +116,46 @@ def test_data(inline_views):
109116
110117 m_new = model_from_fgraph (m_fgraph )
111118
112- # ConstantData is preserved
113- assert np .all (m_new ["b0" ].data == m_old ["b0" ].data )
114-
115- # Shared non-rng shared variables are preserved
116- assert m_new ["x" ].container is x .container
117- assert m_new ["y" ].container is y .container
119+ # The rv-data mapping is preserved
118120 assert m_new .rvs_to_values [m_new ["obs" ]] is m_new ["y" ]
119121
120- # Shared rng shared variables are not preserved
121- assert m_new ["b1" ]. owner . inputs [ 0 ]. container is not m_old ["b1" ]. owner . inputs [ 0 ]. container
122+ # ConstantData is still accessible as a model variable
123+ np . testing . assert_array_equal ( m_new ["b0" ], m_old ["b0" ])
122124
123- with m_old :
124- pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
125+ # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
126+ assert not same_storage (m_new ["x" ], x )
127+ assert not same_storage (m_new ["y" ], y )
128+ assert not same_storage (m_new ["b1" ].owner .inputs [0 ], b1 .owner .inputs [0 ])
129+ assert not same_storage (m_new .dim_lengths ["test_dim" ], m_old .dim_lengths ["test_dim" ])
125130
131+ # Updating model shared variables in new model, doesn't affect old one
132+ with m_new :
133+ pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
126134 assert m_new .dim_lengths ["test_dim" ].eval () == 2
127- np .testing .assert_array_almost_equal (pm .draw (m_new ["x" ], random_seed = 63 ), [100.0 , 200.0 ])
135+ assert m_old .dim_lengths ["test_dim" ].eval () == 3
136+ np .testing .assert_allclose (pm .draw (m_new ["mu" ]), [100.0 , 200.0 ])
137+ np .testing .assert_allclose (pm .draw (m_old ["mu" ]), [0.0 , 1.0 , 2.0 ], atol = 1e-6 )
138+
139+
140+ @config .change_flags (floatX = "float64" ) # Avoid downcasting Ops in the graph
141+ def test_shared_variable ():
142+ """Test that user defined shared variables (other than RNGs) aren't copied."""
143+ x = shared (np .array ([1 , 2 , 3.0 ]), name = "x" )
144+ y = shared (np .array ([1 , 2 , 3.0 ]), name = "y" )
145+
146+ with pm .Model () as m_old :
147+ test = pm .Normal ("test" , mu = x , observed = y )
148+
149+ assert test .owner .inputs [3 ] is x
150+ assert m_old .rvs_to_values [test ] is y
151+
152+ m_new = clone_model (m_old )
153+ test_new = m_new ["test" ]
154+ # Shared Variables are cloned but still point to the same memory
155+ assert test_new .owner .inputs [3 ] is not x
156+ assert m_new .rvs_to_values [test_new ] is not y
157+ assert same_storage (test_new .owner .inputs [3 ], x )
158+ assert same_storage (m_new .rvs_to_values [test_new ], y )
128159
129160
130161@pytest .mark .parametrize ("inline_views" , (False , True ))
0 commit comments