@@ -44,32 +44,23 @@ def rewrite_graph(
4444 """
4545 from pytensor .compile import optdb
4646
47- return_fgraph = False
4847 if isinstance (graph , FunctionGraph ):
4948 fgraph = graph
50- return_fgraph = True
5149 else :
52- if isinstance (graph , list | tuple ):
53- outputs = graph
54- else :
55- assert isinstance (graph , Variable )
56- outputs = [graph ]
57-
50+ outputs = [graph ] if isinstance (graph , Variable ) else graph
5851 fgraph = FunctionGraph (outputs = outputs , clone = clone )
5952
6053 query_rewrites = optdb .query (RewriteDatabaseQuery (include = include , ** kwargs ))
61- _ = query_rewrites .rewrite (fgraph )
54+ query_rewrites .rewrite (fgraph )
6255
63- if custom_rewrite :
56+ if custom_rewrite is not None :
6457 custom_rewrite .rewrite (fgraph )
6558
66- if return_fgraph :
59+ if isinstance ( graph , FunctionGraph ) :
6760 return fgraph
68- else :
69- if isinstance (graph , list | tuple ):
70- return fgraph .outputs
71- else :
72- return fgraph .outputs [0 ]
61+ if isinstance (graph , Variable ):
62+ return fgraph .outputs [0 ]
63+ return fgraph .outputs
7364
7465
7566def is_same_graph_with_merge (
@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
9081 """
9182 from pytensor .graph .rewriting .basic import MergeOptimizer
9283
93- if givens is None :
94- givens = {}
95- givens = dict (givens )
84+ givens = {} if givens is None else dict (givens )
9685
9786 # Copy variables since the MergeOptimizer will modify them.
98- copied = copy .deepcopy ((var1 , var2 , givens ))
99- vars = copied [0 :2 ]
100- givens = copied [2 ]
87+ * vars , givens = copy .deepcopy ((var1 , var2 , givens ))
10188 # Create FunctionGraph.
10289 inputs = list (graph_inputs (vars ))
10390 # The clone isn't needed as we did a deepcopy and we cloning will
@@ -120,8 +107,7 @@ def is_same_graph_with_merge(
120107 # Comparing two single-Variable graphs: they are equal if they are
121108 # the same Variable.
122109 return vars_replaced [0 ] == vars_replaced [1 ]
123- else :
124- return o1 is o2
110+ return o1 is o2
125111
126112
127113def is_same_graph (
@@ -171,71 +157,58 @@ def is_same_graph(
171157 ====== ====== ====== ======
172158
173159 """
174- use_equal_computations = True
175-
176- if givens is None :
177- givens = {}
178- givens = dict (givens )
160+ givens = {} if givens is None else dict (givens )
179161
180162 # Get result from the merge-based function.
181163 rval1 = is_same_graph_with_merge (var1 = var1 , var2 = var2 , givens = givens )
182164
183- if givens :
184- # We need to build the `in_xs` and `in_ys` lists. To do this, we need
185- # to be able to tell whether a variable belongs to the computational
186- # graph of `var1` or `var2`.
187- # The typical case we want to handle is when `to_replace` belongs to
188- # one of these graphs, and `replace_by` belongs to the other one. In
189- # other situations, the current implementation of `equal_computations`
190- # is probably not appropriate, so we do not call it.
191- ok = True
192- in_xs = []
193- in_ys = []
194- # Compute the sets of all variables found in each computational graph.
195- inputs_var1 = graph_inputs ([var1 ])
196- inputs_var2 = graph_inputs ([var2 ])
197- all_vars = [
198- set (vars_between (v_i , v_o ))
199- for v_i , v_o in ((inputs_var1 , [var1 ]), (inputs_var2 , [var2 ]))
200- ]
201-
202- def in_var (x , k ):
203- # Return True iff `x` is in computation graph of variable `vark`.
204- return x in all_vars [k - 1 ]
165+ if not givens :
166+ rval2 = equal_computations (xs = [var1 ], ys = [var2 ])
167+ assert rval1 == rval2
168+ return rval1
169+
170+ # We need to build the `in_xs` and `in_ys` lists. To do this, we need
171+ # to be able to tell whether a variable belongs to the computational
172+ # graph of `var1` or `var2`.
173+ # The typical case we want to handle is when `to_replace` belongs to
174+ # one of these graphs, and `replace_by` belongs to the other one. In
175+ # other situations, the current implementation of `equal_computations`
176+ # is probably not appropriate, so we do not call it.
177+ use_equal_computations = True
178+ in_xs = []
179+ in_ys = []
180+ # Compute the sets of all variables found in each computational graph.
181+ inputs_var1 = graph_inputs ([var1 ])
182+ inputs_var2 = graph_inputs ([var2 ])
183+ all_vars1 = set (vars_between (inputs_var1 , [var1 ]))
184+ all_vars2 = set (vars_between (inputs_var2 , [var2 ]))
205185
206- for to_replace , replace_by in givens .items ():
207- # Map a substitution variable to the computational graphs it
208- # belongs to.
209- inside = {
210- v : [in_var (v , k ) for k in (1 , 2 )] for v in (to_replace , replace_by )
211- }
212- if (
213- inside [to_replace ][0 ]
214- and not inside [to_replace ][1 ]
215- and inside [replace_by ][1 ]
216- and not inside [replace_by ][0 ]
217- ):
218- # Substitute variable in `var1` by one from `var2`.
219- in_xs .append (to_replace )
220- in_ys .append (replace_by )
221- elif (
222- inside [to_replace ][1 ]
223- and not inside [to_replace ][0 ]
224- and inside [replace_by ][0 ]
225- and not inside [replace_by ][1 ]
226- ):
227- # Substitute variable in `var2` by one from `var1`.
228- in_xs .append (replace_by )
229- in_ys .append (to_replace )
230- else :
231- ok = False
232- break
233- if not ok :
234- # We cannot directly use `equal_computations`.
186+ for to_replace , replace_by in givens .items ():
187+ # Map a substitution variable to the computational graphs it
188+ # belongs to.
189+ inside = {v : [v in all_vars1 , v in all_vars2 ] for v in (to_replace , replace_by )}
190+ if (
191+ inside [to_replace ][0 ]
192+ and not inside [to_replace ][1 ]
193+ and inside [replace_by ][1 ]
194+ and not inside [replace_by ][0 ]
195+ ):
196+ # Substitute variable in `var1` by one from `var2`.
197+ in_xs .append (to_replace )
198+ in_ys .append (replace_by )
199+ elif (
200+ inside [to_replace ][1 ]
201+ and not inside [to_replace ][0 ]
202+ and inside [replace_by ][0 ]
203+ and not inside [replace_by ][1 ]
204+ ):
205+ # Substitute variable in `var2` by one from `var1`.
206+ in_xs .append (replace_by )
207+ in_ys .append (to_replace )
208+ else :
235209 use_equal_computations = False
236- else :
237- in_xs = None
238- in_ys = None
210+ break
211+
239212 if use_equal_computations :
240213 rval2 = equal_computations (xs = [var1 ], ys = [var2 ], in_xs = in_xs , in_ys = in_ys )
241214 assert rval2 == rval1
0 commit comments