Skip to content

Commit de53902

Browse files
authored
Merge pull request #86 from dwhswenson/tis-compiling
TIS compiling
2 parents e996e00 + e3bd443 commit de53902

File tree

3 files changed

+125
-115
lines changed

3 files changed

+125
-115
lines changed

paths_cli/compiling/networks.py

Lines changed: 60 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -38,55 +38,58 @@
3838
description="final state for this transition",
3939
)
4040

41+
def mistis_trans_info_param_builder(dcts):
42+
default = 'volume-interface-set' # TODO: make this flexible?
43+
trans_info = []
44+
volume_compiler = compiler_for("volume")
45+
interface_set_compiler = compiler_for('interface_set')
46+
for dct in dcts:
47+
dct = dct.copy()
48+
dct['type'] = dct.get('type', default)
49+
initial_state = volume_compiler(dct.pop('initial_state'))
50+
final_state = volume_compiler(dct.pop('final_state'))
51+
interface_set = interface_set_compiler(dct)
52+
trans_info.append((initial_state, interface_set, final_state))
53+
54+
return trans_info
55+
56+
57+
MISTIS_INTERFACE_SETS_PARAM = Parameter(
58+
'interface_sets', mistis_trans_info_param_builder,
59+
json_type=json_type_list(json_type_ref('interface-set')),
60+
description='interface sets for MISTIS'
61+
)
4162

42-
build_interface_set = InterfaceSetPlugin(
63+
# this is reused in the simple single TIS setup
64+
VOLUME_INTERFACE_SET_PARAMS = [
65+
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
66+
description=("the collective variable for this interface "
67+
"set")),
68+
Parameter('minvals', custom_eval,
69+
json_type=json_type_list(json_type_eval("Float")),
70+
description=("minimum value(s) for interfaces in this"
71+
"interface set")),
72+
Parameter('maxvals', custom_eval,
73+
json_type=json_type_list(json_type_eval("Float")),
74+
description=("maximum value(s) for interfaces in this"
75+
"interface set")),
76+
]
77+
78+
79+
VOLUME_INTERFACE_SET_PLUGIN = InterfaceSetPlugin(
4380
builder=Builder('openpathsampling.VolumeInterfaceSet'),
44-
parameters=[
45-
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
46-
description=("the collective variable for this interface "
47-
"set")),
48-
Parameter('minvals', custom_eval,
49-
json_type=json_type_list(json_type_eval("Float")),
50-
description=("minimum value(s) for interfaces in this"
51-
"interface set")),
52-
Parameter('maxvals', custom_eval,
53-
json_type=json_type_list(json_type_eval("Float")),
54-
description=("maximum value(s) for interfaces in this"
55-
"interface set")),
56-
],
57-
name='interface-set',
81+
parameters=VOLUME_INTERFACE_SET_PARAMS,
82+
name='volume-interface-set',
5883
description="Interface set used in transition interface sampling.",
5984
)
6085

6186

6287
def mistis_trans_info(dct):
6388
dct = dct.copy()
64-
transitions = dct.pop('transitions')
65-
volume_compiler = compiler_for('volume')
66-
trans_info = [
67-
(
68-
volume_compiler(trans['initial_state']),
69-
build_interface_set(trans['interfaces']),
70-
volume_compiler(trans['final_state'])
71-
)
72-
for trans in transitions
73-
]
74-
dct['trans_info'] = trans_info
89+
dct['trans_info'] = dct.pop('interface_sets')
7590
return dct
7691

7792

78-
def tis_trans_info(dct):
79-
# remap TIS into MISTIS format
80-
dct = dct.copy()
81-
initial_state = dct.pop('initial_state')
82-
final_state = dct.pop('final_state')
83-
interface_set = dct.pop('interfaces')
84-
dct['transitions'] = [{'initial_state': initial_state,
85-
'final_state': final_state,
86-
'interfaces': interface_set}]
87-
return mistis_trans_info(dct)
88-
89-
9093
TPS_NETWORK_PLUGIN = NetworkCompilerPlugin(
9194
builder=Builder('openpathsampling.TPSNetwork'),
9295
parameters=[INITIAL_STATES_PARAM, FINAL_STATES_PARAM],
@@ -96,18 +99,27 @@ def tis_trans_info(dct):
9699
)
97100

98101

99-
# MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
100-
# parameters=[Parameter('trans_info', mistis_trans_info)],
101-
# builder=Builder('openpathsampling.MISTISNetwork'),
102-
# name='mistis'
103-
# )
102+
MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
103+
parameters=[MISTIS_INTERFACE_SETS_PARAM],
104+
builder=Builder('openpathsampling.MISTISNetwork',
105+
remapper=mistis_trans_info),
106+
name='mistis'
107+
)
104108

109+
def single_tis_builder(initial_state, final_state, cv, minvals, maxvals):
110+
import openpathsampling as paths
111+
interface_set = paths.VolumeInterfaceSet(cv, minvals, maxvals)
112+
return paths.MISTISNetwork([
113+
(initial_state, interface_set, final_state)
114+
])
115+
116+
TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
117+
builder=single_tis_builder,
118+
parameters=([INITIAL_STATE_PARAM, FINAL_STATE_PARAM]
119+
+ VOLUME_INTERFACE_SET_PARAMS),
120+
name='tis'
121+
)
105122

106-
# TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
107-
# builder=Builder('openpathsampling.MISTISNetwork'),
108-
# parameters=[Parameter('trans_info', tis_trans_info)],
109-
# name='tis'
110-
# )
111123

112124
# old names not yet replaced in testing THESE ARE WHY WE'RE DOUBLING! GET
113125
# RID OF THEM! (also, use an is-check)

paths_cli/compiling/schemes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,11 @@ def __call__(self, **dct):
9999
"that type (i.e., ``OrganizeByMoveGroupStrategy``)"),
100100
)
101101

102+
DEFAULT_TIS_SCHEME_PLUGIN = SchemeCompilerPlugin(
103+
builder=Builder('openpathsampling.DefaultScheme'),
104+
parameters=[NETWORK_PARAMETER, ENGINE_PARAMETER],
105+
name='default-tis',
106+
description="",
107+
)
108+
102109
SCHEME_COMPILER = CategoryPlugin(SchemeCompilerPlugin, aliases=['schemes'])

paths_cli/tests/compiling/test_networks.py

Lines changed: 58 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,74 +10,22 @@
1010

1111
_COMPILERS_LOC = 'paths_cli.compiling.root_compiler._COMPILERS'
1212

13-
14-
def check_unidirectional_tis(results, state_A, state_B, cv):
15-
assert len(results) == 1
16-
trans_info = results['trans_info']
17-
assert len(trans_info) == 1
18-
assert len(trans_info[0]) == 3
19-
trans = trans_info[0]
20-
assert isinstance(trans, tuple)
21-
assert trans[0] == state_A
22-
assert trans[2] == state_B
23-
assert isinstance(trans[1], paths.VolumeInterfaceSet)
24-
ifaces = trans[1]
25-
assert ifaces.cv == cv
26-
assert ifaces.minvals == float("-inf")
27-
np.testing.assert_allclose(ifaces.maxvals,
28-
[0, np.pi / 10.0, np.pi / 5.0])
29-
30-
31-
def test_mistis_trans_info(cv_and_states):
32-
cv, state_A, state_B = cv_and_states
33-
dct = {
34-
'transitions': [{
35-
'initial_state': "A",
36-
'final_state': "B",
37-
'interfaces': {
38-
'cv': 'cv',
39-
'minvals': 'float("-inf")',
40-
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
41-
}
42-
}]
43-
}
44-
patch_base = 'paths_cli.compiling.networks'
45-
compiler = {
46-
'cv': mock_compiler('cv', named_objs={'cv': cv}),
47-
'volume': mock_compiler('volume', named_objs={
48-
"A": state_A, "B": state_B
49-
}),
50-
}
51-
with mock.patch.dict(_COMPILERS_LOC, compiler):
52-
results = mistis_trans_info(dct)
53-
54-
check_unidirectional_tis(results, state_A, state_B, cv)
13+
@pytest.fixture
14+
def unidirectional_tis_compiler(cv_and_states):
5515
paths.InterfaceSet._reset()
56-
57-
58-
def test_tis_trans_info(cv_and_states):
5916
cv, state_A, state_B = cv_and_states
60-
dct = {
61-
'initial_state': "A",
62-
'final_state': "B",
63-
'interfaces': {
64-
'cv': 'cv',
65-
'minvals': 'float("-inf")',
66-
'maxvals': 'np.array([0, 0.1, 0.2]) * np.pi',
67-
}
68-
}
69-
70-
compiler = {
17+
return {
7118
'cv': mock_compiler('cv', named_objs={'cv': cv}),
7219
'volume': mock_compiler('volume', named_objs={
7320
"A": state_A, "B": state_B
7421
}),
22+
'interface_set': mock_compiler(
23+
'interface_set',
24+
type_dispatch={
25+
'volume-interface-set': VOLUME_INTERFACE_SET_PLUGIN
26+
}
27+
),
7528
}
76-
with mock.patch.dict(_COMPILERS_LOC, compiler):
77-
results = tis_trans_info(dct)
78-
79-
check_unidirectional_tis(results, state_A, state_B, cv)
80-
paths.InterfaceSet._reset()
8129

8230

8331
def test_build_tps_network(cv_and_states):
@@ -86,17 +34,60 @@ def test_build_tps_network(cv_and_states):
8634
dct = yaml.load(yml, yaml.FullLoader)
8735
compiler = {
8836
'volume': mock_compiler('volume', named_objs={"A": state_A,
89-
"B": state_B}),
37+
"B": state_B}),
9038
}
9139
with mock.patch.dict(_COMPILERS_LOC, compiler):
92-
network = build_tps_network(dct)
40+
network = TPS_NETWORK_PLUGIN(dct)
9341
assert isinstance(network, paths.TPSNetwork)
9442
assert len(network.initial_states) == len(network.final_states) == 1
9543
assert network.initial_states[0] == state_A
9644
assert network.final_states[0] == state_B
9745

98-
def test_build_mistis_network():
99-
pytest.skip()
10046

101-
def test_build_tis_network():
102-
pytest.skip()
47+
def test_build_mistis_network(cv_and_states, unidirectional_tis_compiler):
48+
cv, state_A, state_B = cv_and_states
49+
mistis_dict = {
50+
'interface_sets': [
51+
{
52+
'initial_state': "A",
53+
'final_state': "B",
54+
'cv': 'cv',
55+
'minvals': 'float("-inf")',
56+
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
57+
},
58+
{
59+
'initial_state': "B",
60+
'final_state': "A",
61+
'cv': 'cv',
62+
'minvals': "np.array([1.0, 0.9, 0.8])",
63+
'maxvals': "float('inf')",
64+
}
65+
]
66+
}
67+
68+
with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
69+
network = MISTIS_NETWORK_PLUGIN(mistis_dict)
70+
71+
assert isinstance(network, paths.MISTISNetwork)
72+
assert len(network.sampling_transitions) == 2
73+
assert len(network.transitions) == 2
74+
assert list(network.transitions) == [(state_A, state_B),
75+
(state_B, state_A)]
76+
77+
def test_build_tis_network(cv_and_states, unidirectional_tis_compiler):
78+
cv, state_A, state_B = cv_and_states
79+
tis_dict = {
80+
'initial_state': "A",
81+
'final_state': "B",
82+
'cv': "cv",
83+
'minvals': 'float("inf")',
84+
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi",
85+
}
86+
87+
with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
88+
network = TIS_NETWORK_PLUGIN(tis_dict)
89+
90+
assert isinstance(network, paths.MISTISNetwork)
91+
assert len(network.sampling_transitions) == 1
92+
assert len(network.transitions) == 1
93+
assert list(network.transitions) == [(state_A, state_B)]

0 commit comments

Comments
 (0)