@@ -42,56 +42,56 @@ def interpolate_two_points(start, stop, steps):
4242
4343if __name__ == '__main__' :
4444 # variable or fixed length?
45- variable = True
4645 num_paths = 1000
4746
48- save_dir = f"out/baselines/mueller"
49- if variable :
50- save_dir += "-variable"
51-
52- os .makedirs (save_dir , exist_ok = True )
53-
54- xi = 5
55- dt = 1e-4
56- T = 275e-4
57- N = 0 if variable else int (T / dt )
58-
59- system = System .from_name ('mueller_brown' , float ('inf' ))
60- initial_trajectory = [t .reshape (1 , 2 ) for t in interpolate (jnp .array ([system .A , system .B ]), 100 if variable else N )]
61-
62- @jax .jit
63- def step (_x , _key ):
64- """Perform one step of forward euler"""
65- return _x - dt * system .dUdx (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
66-
67-
68- tps_config = tps1 .FirstOrderSystem (
69- jax .jit (lambda s : jnp .linalg .norm (s - system .A ) <= 0.1 ),
70- jax .jit (lambda s : jnp .linalg .norm (s - system .B ) <= 0.1 ),
71- step
72- )
73-
74- for method , name in [
75- (tps1 .one_way_shooting , 'one-way-shooting' ),
76- (tps1 .two_way_shooting , 'two-way-shooting' ),
77- ]:
78- if os .path .exists (f'{ save_dir } /paths-{ name } .npy' ) and os .path .exists (f'{ save_dir } /stats-{ name } .json' ):
79- print (f"Skipping { name } because the results are already present" )
80-
81- paths = np .load (f'{ save_dir } /paths-{ name } .npy' , allow_pickle = True )
82- paths = [jnp .array (p .astype (np .float32 )) for p in paths ]
83- with open (f'{ save_dir } /stats-{ name } .json' , 'r' ) as fp :
84- statistics = json .load (fp )
85- else :
86- print ('Generating paths for' , name )
87- paths , statistics = tps1 .mcmc_shooting (tps_config , method , initial_trajectory , num_paths ,
88- jax .random .PRNGKey (1 ), warmup = 0 , fixed_length = N )
89-
90- paths = [jnp .array (p ) for p in paths ]
91-
92- np .save (f'{ save_dir } /paths-{ name } .npy' , np .array (paths , dtype = object ), allow_pickle = True )
93- with open (f'{ save_dir } /stats-{ name } .json' , 'w' ) as fp :
94- json .dump (statistics , fp )
95-
96- system .plot (trajectories = paths )
97- show_or_save_fig (save_dir , f'mueller-{ name } ' , 'pdf' )
47+ for variable in [False , True ]:
48+ save_dir = f"out/baselines/mueller"
49+ if variable :
50+ save_dir += "-variable"
51+
52+ os .makedirs (save_dir , exist_ok = True )
53+
54+ xi = 5
55+ dt = 1e-4
56+ T = 275e-4
57+ N = 0 if variable else int (T / dt )
58+
59+ system = System .from_name ('mueller_brown' , float ('inf' ))
60+ initial_trajectory = [t .reshape (1 , 2 ) for t in interpolate (jnp .array ([system .A , system .B ]), 100 if variable else N )]
61+
62+ @jax .jit
63+ def step (_x , _key ):
64+ """Perform one step of forward euler"""
65+ return _x - dt * system .dUdx (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
66+
67+
68+ tps_config = tps1 .FirstOrderSystem (
69+ jax .jit (lambda s : jnp .linalg .norm (s - system .A ) <= 0.1 ),
70+ jax .jit (lambda s : jnp .linalg .norm (s - system .B ) <= 0.1 ),
71+ step
72+ )
73+
74+ for method , name in [
75+ (tps1 .one_way_shooting , 'one-way-shooting' ),
76+ (tps1 .two_way_shooting , 'two-way-shooting' ),
77+ ]:
78+ if os .path .exists (f'{ save_dir } /paths-{ name } .npy' ) and os .path .exists (f'{ save_dir } /stats-{ name } .json' ):
79+ print (f"Skipping { name } because the results are already present" )
80+
81+ paths = np .load (f'{ save_dir } /paths-{ name } .npy' , allow_pickle = True )
82+ paths = [jnp .array (p .astype (np .float32 )) for p in paths ]
83+ with open (f'{ save_dir } /stats-{ name } .json' , 'r' ) as fp :
84+ statistics = json .load (fp )
85+ else :
86+ print ('Generating paths for' , name )
87+ paths , statistics = tps1 .mcmc_shooting (tps_config , method , initial_trajectory , num_paths ,
88+ jax .random .PRNGKey (1 ), warmup = 0 , fixed_length = N )
89+
90+ paths = [jnp .array (p ) for p in paths ]
91+
92+ np .save (f'{ save_dir } /paths-{ name } .npy' , np .array (paths , dtype = object ), allow_pickle = True )
93+ with open (f'{ save_dir } /stats-{ name } .json' , 'w' ) as fp :
94+ json .dump (statistics , fp )
95+
96+ system .plot (trajectories = paths )
97+ show_or_save_fig (save_dir , f'mueller-{ name } ' , 'pdf' )
0 commit comments