@@ -310,25 +310,29 @@ class EquilibriumDB(RewriteDatabase):
310310 """
311311
312312 def __init__ (
313- self , ignore_newtrees : bool = True , tracks_on_change_inputs : bool = False
313+ self ,
314+ ignore_newtrees : bool = True ,
315+ tracks_on_change_inputs : bool = False ,
316+ eq_rewriter_class = pytensor_rewriting .EquilibriumGraphRewriter ,
314317 ):
315318 """
316319
317320 Parameters
318321 ----------
319322 ignore_newtrees
320- If ``False``, apply rewrites to new nodes introduced during
321- rewriting.
322-
323+ If ``False``, apply rewrites to new nodes introduced during rewritings.
323324 tracks_on_change_inputs
324325 If ``True``, re-apply rewrites on nodes with changed inputs.
326+ eq_rewriter_class: EquilibriumGraphRewriter class, optional
327+ The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter.
325328
326329 """
327330 super ().__init__ ()
328331 self .ignore_newtrees = ignore_newtrees
329332 self .tracks_on_change_inputs = tracks_on_change_inputs
330333 self .__final__ : dict [str , bool ] = {}
331334 self .__cleanup__ : dict [str , bool ] = {}
335+ self .eq_rewriter_class = eq_rewriter_class
332336
333337 def register (
334338 self ,
@@ -360,7 +364,7 @@ def query(self, *tags, **kwtags):
360364 final_rewriters = None
361365 if len (cleanup_rewriters ) == 0 :
362366 cleanup_rewriters = None
363- return pytensor_rewriting . EquilibriumGraphRewriter (
367+ return self . eq_rewriter_class (
364368 rewriters ,
365369 max_use_ratio = config .optdb__max_use_ratio ,
366370 ignore_newtrees = self .ignore_newtrees ,
0 commit comments