Skip to content

Commit 02b3374

Browse files
authored
Merge pull request #88 from dwhswenson/cmd-bootstrap-init
New command: `bootstrap-init`
2 parents 6f0e2d6 + 8352adc commit 02b3374

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import click
2+
from paths_cli import OPSCommandPlugin
3+
from paths_cli.parameters import (
4+
INIT_SNAP, SCHEME, ENGINE, OUTPUT_FILE, INPUT_FILE
5+
)
6+
from paths_cli.param_core import OPSStorageLoadSingle, Option
7+
8+
INIT_STATE = OPSStorageLoadSingle(
9+
param=Option("--initial-state", help="initial state"),
10+
store='volumes'
11+
)
12+
FINAL_STATE = OPSStorageLoadSingle(
13+
param=Option("--final-state", help="final state"),
14+
store='volumes'
15+
)
16+
17+
@click.command(
18+
"bootstrap-init",
19+
short_help="TIS interface set initial trajectories from a snapshot",
20+
)
21+
@INIT_STATE.clicked(required=True)
22+
@FINAL_STATE.clicked(required=True)
23+
@INIT_SNAP.clicked()
24+
@SCHEME.clicked()
25+
@ENGINE.clicked()
26+
@OUTPUT_FILE.clicked()
27+
@INPUT_FILE.clicked()
28+
def bootstrap_init(initial_state, final_state, scheme, engine, init_frame,
29+
output_file, input_file):
30+
"""Use ``FullBootstrapping`` to create initial conditions for TIS.
31+
32+
This approach starts from a snapshot, runs MD to generate an initial
33+
path for the innermost ensemble, and then performs one-way shooting
34+
moves within each ensemble until the next ensemble as reached. This
35+
continues until all ensembles have valid trajectories.
36+
37+
Note that intermediate sampling in this is not saved to disk.
38+
"""
39+
storage = INPUT_FILE.get(input_file)
40+
scheme = SCHEME.get(storage, scheme)
41+
network = scheme.network
42+
engine = ENGINE.get(storage, engine)
43+
init_state = INIT_STATE.get(storage, initial_state)
44+
final_state = FINAL_STATE.get(storage, final_state)
45+
transition = network.transitions[(init_state, final_state)]
46+
bootstrap_init_main(
47+
init_frame=INIT_SNAP.get(storage, init_frame),
48+
network=network,
49+
engine=engine,
50+
transition=transition,
51+
output_storage=OUTPUT_FILE.get(output_file)
52+
)
53+
54+
55+
def bootstrap_init_main(init_frame, network, engine, transition,
56+
output_storage):
57+
import openpathsampling as paths
58+
all_states = set(network.initial_states) | set(network.final_states)
59+
allowed_states = {transition.stateA, transition.stateB}
60+
forbidden_states = list(all_states - allowed_states)
61+
try:
62+
extra_ensembles = network.ms_outers
63+
except KeyError:
64+
extra_ensembles = None
65+
66+
bootstrapper = paths.FullBootstrapping(
67+
transition=transition,
68+
snapshot=init_frame,
69+
engine=engine,
70+
forbidden_states=forbidden_states,
71+
extra_ensembles=extra_ensembles,
72+
)
73+
init_conds = bootstrapper.run()
74+
if output_storage:
75+
output_storage.tags['final_conditions'] = init_conds
76+
77+
return init_conds, bootstrapper
78+
79+
80+
PLUGIN = OPSCommandPlugin(
81+
command=bootstrap_init,
82+
section="Simulation",
83+
requires_ops=(1, 0),
84+
requires_cli=(0, 4),
85+
)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import pytest
2+
from click.testing import CliRunner
3+
from unittest.mock import patch
4+
import numpy as np
5+
6+
from paths_cli.commands.bootstrap_init import *
7+
import openpathsampling as paths
8+
from openpathsampling.engines import toy
9+
from openpathsampling.tests.test_helpers import make_1d_traj
10+
11+
@pytest.fixture
12+
def toy_2_state_engine():
13+
pes = (
14+
toy.OuterWalls([1.0, 1.0], [0.0, 0.0]) +
15+
toy.Gaussian(-1.0, [12.0, 12.0], [-0.5, 0.0]) +
16+
toy.Gaussian(-1.0, [12.0, 12.0], [0.5, 0.0])
17+
)
18+
topology=toy.Topology(
19+
n_spatial = 2,
20+
masses =[1.0, 1.0],
21+
pes = pes
22+
)
23+
integ = toy.LangevinBAOABIntegrator(dt=0.02, temperature=0.1, gamma=2.5)
24+
options = {
25+
'integ': integ,
26+
'n_frames_max': 5000,
27+
'n_steps_per_frame': 1
28+
}
29+
30+
engine = toy.Engine(
31+
options=options,
32+
topology=topology
33+
)
34+
return engine
35+
36+
@pytest.fixture
37+
def toy_2_state_cv():
38+
return paths.FunctionCV("x", lambda s: s.xyz[0][0])
39+
40+
@pytest.fixture
41+
def toy_2_state_volumes(toy_2_state_cv):
42+
state_A = paths.CVDefinedVolume(
43+
toy_2_state_cv,
44+
float("-inf"),
45+
-0.3,
46+
).named("A")
47+
state_B = paths.CVDefinedVolume(
48+
toy_2_state_cv,
49+
0.3,
50+
float("inf"),
51+
).named("B")
52+
return state_A, state_B
53+
54+
@pytest.fixture
55+
def toy_2_state_tis(toy_2_state_cv, toy_2_state_volumes):
56+
state_A, state_B = toy_2_state_volumes
57+
interfaces = paths.VolumeInterfaceSet(
58+
toy_2_state_cv,
59+
float("-inf"),
60+
[-0.3, -0.2, -0.1],
61+
)
62+
tis = paths.MISTISNetwork(
63+
[(state_A, interfaces, state_B)],
64+
)
65+
return tis
66+
67+
68+
def print_test(init_frame, network, engine, transition, output_storage):
69+
print(init_frame.__uuid__)
70+
print(network.__uuid__)
71+
print(engine.__uuid__)
72+
# apparently transition UUID isn't preserved, but these are?
73+
print(transition.stateA.__uuid__)
74+
print(transition.stateB.__uuid__)
75+
print([e.__uuid__ for e in transition.ensembles])
76+
print(isinstance(output_storage, paths.Storage))
77+
78+
@patch('paths_cli.commands.bootstrap_init.bootstrap_init_main', print_test)
79+
def test_bootstrap_init(tis_fixture):
80+
scheme, network, engine, init_conds = tis_fixture
81+
runner = CliRunner()
82+
with runner.isolated_filesystem():
83+
storage = paths.Storage("setup.nc", 'w')
84+
storage.save(init_conds)
85+
for obj in tis_fixture:
86+
storage.save(obj)
87+
88+
storage.tags["init_snap"] = init_conds[0][0]
89+
storage.close()
90+
91+
results = runner.invoke(bootstrap_init, [
92+
'setup.nc',
93+
'-o', 'foo.nc',
94+
'--initial-state', "A",
95+
'--final-state', "B",
96+
'--init-frame', 'init_snap',
97+
])
98+
99+
transitions = list(network.transitions.values())
100+
assert len(transitions) == 1
101+
transition = transitions[0]
102+
stateA = transition.stateA
103+
stateB = transition.stateB
104+
ensembles = transition.ensembles
105+
106+
expected_output = (
107+
f"{init_conds[0][0].__uuid__}\n{network.__uuid__}\n"
108+
f"{engine.__uuid__}\n"
109+
f"{stateA.__uuid__}\n{stateB.__uuid__}\n"
110+
f"{[e.__uuid__ for e in ensembles]}\n"
111+
"True\n"
112+
)
113+
assert results.exit_code == 0
114+
assert results.output == expected_output
115+
116+
117+
def test_bootstrap_init_main(toy_2_state_tis, toy_2_state_engine, tmp_path):
118+
network = toy_2_state_tis
119+
engine = toy_2_state_engine
120+
scheme = paths.DefaultScheme(network, engine)
121+
init_frame = toy.Snapshot(
122+
coordinates=np.array([[-0.5, -0.5]]),
123+
velocities=np.array([[0.0,0.0]]),
124+
engine=engine
125+
)
126+
assert len(network.transitions) == 1
127+
transition = list(network.transitions.values())[0]
128+
output_storage = paths.Storage(tmp_path / "output.nc", mode='w')
129+
init_conds, bootstrapper = bootstrap_init_main(init_frame, network,
130+
engine, transition,
131+
output_storage)
132+
init_conds.sanity_check()

paths_cli/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,11 @@ def tis_network(cv_and_states):
5858
[0.0, 0.1, 0.2])
5959
network = paths.MISTISNetwork([(state_A, interfaces, state_B)])
6060
return network
61+
62+
@pytest.fixture
63+
def tis_fixture(flat_engine, tis_network, transition_traj):
64+
paths.InterfaceSet._reset()
65+
scheme = paths.DefaultScheme(network=tis_network,
66+
engine=flat_engine)
67+
init_conds = scheme.initial_conditions_from_trajectories(transition_traj)
68+
return (scheme, tis_network, flat_engine, init_conds)

0 commit comments

Comments
 (0)