Skip to content

Commit 6d29361

Browse files
committed
better tests for the bootstrapping init
1 parent b18f499 commit 6d29361

File tree

1 file changed

+70
-3
lines changed

1 file changed

+70
-3
lines changed

paths_cli/tests/commands/test_bootstrap_init.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,69 @@
11
import pytest
22
from click.testing import CliRunner
33
from unittest.mock import patch
4+
import numpy as np
45

56
from paths_cli.commands.bootstrap_init import *
67
import openpathsampling as paths
8+
from openpathsampling.engines import toy
9+
from openpathsampling.tests.test_helpers import make_1d_traj
10+
11+
@pytest.fixture
12+
def toy_2_state_engine():
13+
pes = (
14+
toy.OuterWalls([1.0, 1.0], [0.0, 0.0]) +
15+
toy.Gaussian(-1.0, [12.0, 12.0], [-0.5, 0.0]) +
16+
toy.Gaussian(-1.0, [12.0, 12.0], [0.5, 0.0])
17+
)
18+
topology=toy.Topology(
19+
n_spatial = 2,
20+
masses =[1.0, 1.0],
21+
pes = pes
22+
)
23+
integ = toy.LangevinBAOABIntegrator(dt=0.02, temperature=0.1, gamma=2.5)
24+
options = {
25+
'integ': integ,
26+
'n_frames_max': 5000,
27+
'n_steps_per_frame': 1
28+
}
29+
30+
engine = toy.Engine(
31+
options=options,
32+
topology=topology
33+
)
34+
return engine
35+
36+
@pytest.fixture
37+
def toy_2_state_cv():
38+
return paths.FunctionCV("x", lambda s: s.xyz[0][0])
39+
40+
@pytest.fixture
41+
def toy_2_state_volumes(toy_2_state_cv):
42+
state_A = paths.CVDefinedVolume(
43+
toy_2_state_cv,
44+
float("-inf"),
45+
-0.3,
46+
).named("A")
47+
state_B = paths.CVDefinedVolume(
48+
toy_2_state_cv,
49+
0.3,
50+
float("inf"),
51+
).named("B")
52+
return state_A, state_B
53+
54+
@pytest.fixture
55+
def toy_2_state_tis(toy_2_state_cv, toy_2_state_volumes):
56+
state_A, state_B = toy_2_state_volumes
57+
interfaces = paths.VolumeInterfaceSet(
58+
toy_2_state_cv,
59+
float("-inf"),
60+
[-0.3, -0.2, -0.1],
61+
)
62+
tis = paths.MISTISNetwork(
63+
[(state_A, interfaces, state_B)],
64+
)
65+
return tis
66+
767

868
def print_test(init_frame, network, engine, transition, output_storage):
969
print(init_frame.__uuid__)
@@ -54,12 +114,19 @@ def test_bootstrap_init(tis_fixture):
54114
assert results.output == expected_output
55115

56116

57-
def test_bootstrap_init_main(tis_fixture, tmp_path):
58-
scheme, network, engine, init_conds = tis_fixture
59-
init_frame = init_conds[0][0]
117+
def test_bootstrap_init_main(toy_2_state_tis, toy_2_state_engine, tmp_path):
118+
network = toy_2_state_tis
119+
engine = toy_2_state_engine
120+
scheme = paths.DefaultScheme(network, engine)
121+
init_frame = toy.Snapshot(
122+
coordinates=np.array([[-0.5, -0.5]]),
123+
velocities=np.array([[0.0,0.0]]),
124+
engine=engine
125+
)
60126
assert len(network.transitions) == 1
61127
transition = list(network.transitions.values())[0]
62128
output_storage = paths.Storage(tmp_path / "output.nc", mode='w')
63129
init_conds, bootstrapper = bootstrap_init_main(init_frame, network,
64130
engine, transition,
65131
output_storage)
132+
init_conds.sanity_check()

0 commit comments

Comments
 (0)