File tree Expand file tree Collapse file tree 3 files changed +16
-12
lines changed Expand file tree Collapse file tree 3 files changed +16
-12
lines changed Original file line number Diff line number Diff line change @@ -50,4 +50,5 @@ dependencies:
5050 - typing_extensions
5151 # optional
5252 - cython
53-
53+ - graphviz
54+ - pydot
Original file line number Diff line number Diff line change 99from pytensor import compile
1010from pytensor .compile .function import function
1111from pytensor .configdefaults import config
12- from pytensor .d3viz . formatting import pydot_imported , pydot_imported_msg
12+ from pytensor .printing import pydot_imported , pydot_imported_msg
1313from tests .d3viz import models
1414
1515
Original file line number Diff line number Diff line change 22import pytest
33
44from pytensor import config , function
5- from pytensor .d3viz .formatting import PyDotFormatter , pydot_imported , pydot_imported_msg
5+ from pytensor .d3viz .formatting import PyDotFormatter
6+ from pytensor .printing import pydot_imported , pydot_imported_msg
67
78
89if not pydot_imported :
@@ -21,21 +22,23 @@ def node_counts(self, graph):
2122 nc = dict (zip (a , b ))
2223 return nc
2324
24- def test_mlp (self ):
25+ @pytest .mark .parametrize ("mode" , ["FAST_RUN" , "FAST_COMPILE" ])
26+ def test_mlp (self , mode ):
2527 m = models .Mlp ()
26- f = function (m .inputs , m .outputs )
28+ f = function (m .inputs , m .outputs , mode = mode )
2729 pdf = PyDotFormatter ()
2830 graph = pdf (f )
29- expected = 11
30- if config .mode == "FAST_COMPILE" :
31- expected = 12
31+ if mode == "FAST_RUN" :
32+ expected = 13
33+ elif mode == "FAST_COMPILE" :
34+ expected = 14
3235 assert len (graph .get_nodes ()) == expected
3336 nc = self .node_counts (graph )
3437
35- if config . mode == "FAST_COMPILE " :
36- assert nc ["apply" ] == 6
37- else :
38- assert nc ["apply" ] == 5
38+ if mode == "FAST_RUN " :
39+ assert nc ["apply" ] == 7
40+ elif mode == "FAST_COMPILE" :
41+ assert nc ["apply" ] == 8
3942 assert nc ["output" ] == 1
4043
4144 def test_ofg (self ):
You can’t perform that action at this time.
0 commit comments