Skip to content

Commit 6ac1b0a

Browse files
committed
improve CLI for TIS networks
1 parent 4a9d497 commit 6ac1b0a

File tree

3 files changed

+187
-111
lines changed

3 files changed

+187
-111
lines changed

paths_cli/compiling/networks.py

Lines changed: 77 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from paths_cli.compiling.core import (
2-
InstanceCompilerPlugin, Builder, Parameter
2+
InstanceCompilerPlugin, Builder, Parameter,
3+
listify, unlistify
34
)
45
from paths_cli.compiling.tools import custom_eval
56
from paths_cli.compiling.plugins import (
@@ -38,53 +39,76 @@
3839
description="final state for this transition",
3940
)
4041

42+
def mistis_trans_info_param_builder(dcts):
43+
default = 'volume-interface-set' # TODO: make this flexible?
44+
trans_info = []
45+
volume_compiler = compiler_for("volume")
46+
interface_set_compiler = compiler_for('interface_set')
47+
for dct in dcts:
48+
dct['type'] = dct.get('type', default)
49+
initial_state = volume_compiler(dct.pop('initial_state'))
50+
final_state = volume_compiler(dct.pop('final_state'))
51+
interface_set = interface_set_compiler(dct)
52+
trans_info.append((initial_state, interface_set, final_state))
53+
54+
return trans_info
55+
56+
57+
MISTIS_INTERFACE_SETS_PARAM = Parameter(
58+
'interface_sets', mistis_trans_info_param_builder,
59+
json_type=json_type_list(json_type_ref('interface-set')),
60+
description='interface sets for MISTIS'
61+
)
4162

42-
build_interface_set = InterfaceSetPlugin(
63+
# this is reused in the simple single TIS setup
64+
VOLUME_INTERFACE_SET_PARAMS = [
65+
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
66+
description=("the collective variable for this interface "
67+
"set")),
68+
Parameter('minvals', custom_eval,
69+
json_type=json_type_list(json_type_eval("Float")),
70+
description=("minimum value(s) for interfaces in this"
71+
"interface set")),
72+
Parameter('maxvals', custom_eval,
73+
json_type=json_type_list(json_type_eval("Float")),
74+
description=("maximum value(s) for interfaces in this"
75+
"interface set")),
76+
]
77+
78+
79+
VOLUME_INTERFACE_SET_PLUGIN = InterfaceSetPlugin(
4380
builder=Builder('openpathsampling.VolumeInterfaceSet'),
44-
parameters=[
45-
Parameter('cv', compiler_for('cv'), json_type=json_type_ref('cv'),
46-
description=("the collective variable for this interface "
47-
"set")),
48-
Parameter('minvals', custom_eval,
49-
json_type=json_type_list(json_type_eval("Float")),
50-
description=("minimum value(s) for interfaces in this"
51-
"interface set")),
52-
Parameter('maxvals', custom_eval,
53-
json_type=json_type_list(json_type_eval("Float")),
54-
description=("maximum value(s) for interfaces in this"
55-
"interface set")),
56-
],
57-
name='interface-set',
81+
parameters=VOLUME_INTERFACE_SET_PARAMS,
82+
name='volume-interface-set',
5883
description="Interface set used in transition interface sampling.",
5984
)
6085

6186

6287
def mistis_trans_info(dct):
6388
dct = dct.copy()
64-
transitions = dct.pop('transitions')
65-
volume_compiler = compiler_for('volume')
66-
trans_info = [
67-
(
68-
volume_compiler(trans['initial_state']),
69-
build_interface_set(trans['interfaces']),
70-
volume_compiler(trans['final_state'])
71-
)
72-
for trans in transitions
73-
]
74-
dct['trans_info'] = trans_info
89+
dct['trans_info'] = dct.pop('interface_sets')
7590
return dct
91+
# interface_sets = dct.pop('interface_sets')
92+
# volume_compiler = compiler_for('volume')
93+
# interface_set_compiler = compiler_for('interface_set')
94+
95+
# trans_info = []
96+
# for iset in interface_sets:
97+
# initial_state = volume_compiler(iset.pop("initial_state"))
98+
# final_state = volume_compiler(iset.pop("final_state"))
99+
# iset['type'] = iset.get('type', 'volume-interface-set')
100+
# interface_set = interface_set_compiler(iset)
101+
# trans_info.append((initial_state, interface_set, final_state))
102+
103+
# dct['trans_info'] = trans_info
104+
# return dct
76105

77106

78107
def tis_trans_info(dct):
79108
# remap TIS into MISTIS format
80109
dct = dct.copy()
81-
initial_state = dct.pop('initial_state')
82-
final_state = dct.pop('final_state')
83-
interface_set = dct.pop('interfaces')
84-
dct['transitions'] = [{'initial_state': initial_state,
85-
'final_state': final_state,
86-
'interfaces': interface_set}]
87-
return mistis_trans_info(dct)
110+
remapped = {'interface_sets': [dct]}
111+
return mistis_trans_info(remapped)
88112

89113

90114
TPS_NETWORK_PLUGIN = NetworkCompilerPlugin(
@@ -96,18 +120,27 @@ def tis_trans_info(dct):
96120
)
97121

98122

99-
# MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
100-
# parameters=[Parameter('trans_info', mistis_trans_info)],
101-
# builder=Builder('openpathsampling.MISTISNetwork'),
102-
# name='mistis'
103-
# )
123+
MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
124+
parameters=[MISTIS_INTERFACE_SETS_PARAM],
125+
builder=Builder('openpathsampling.MISTISNetwork',
126+
remapper=mistis_trans_info),
127+
name='mistis'
128+
)
104129

130+
def single_tis_builder(initial_state, final_state, cv, minvals, maxvals):
131+
import openpathsampling as paths
132+
interface_set = paths.VolumeInterfaceSet(cv, minvals, maxvals)
133+
return paths.MISTISNetwork([
134+
(initial_state, interface_set, final_state)
135+
])
136+
137+
TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
138+
builder=single_tis_builder,
139+
parameters=([INITIAL_STATE_PARAM, FINAL_STATE_PARAM]
140+
+ VOLUME_INTERFACE_SET_PARAMS),
141+
name='tis'
142+
)
105143

106-
# TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
107-
# builder=Builder('openpathsampling.MISTISNetwork'),
108-
# parameters=[Parameter('trans_info', tis_trans_info)],
109-
# name='tis'
110-
# )
111144

112145
# old names not yet replaced in testing THESE ARE WHY WE'RE DOUBLING! GET
113146
# RID OF THEM! (also, use an is-check)

paths_cli/compiling/schemes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,11 @@ def __call__(self, **dct):
9999
"that type (i.e., ``OrganizeByMoveGroupStrategy``)"),
100100
)
101101

102+
DEFAULT_TIS_SCHEME_PLUGIN = SchemeCompilerPlugin(
103+
builder=Builder('openpathsampling.DefaultScheme'),
104+
parameters=[NETWORK_PARAMETER, ENGINE_PARAMETER],
105+
name='default-tis',
106+
description="",
107+
)
108+
102109
SCHEME_COMPILER = CategoryPlugin(SchemeCompilerPlugin, aliases=['schemes'])

paths_cli/tests/compiling/test_networks.py

Lines changed: 103 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,74 +10,77 @@
1010

1111
_COMPILERS_LOC = 'paths_cli.compiling.root_compiler._COMPILERS'
1212

13-
14-
def check_unidirectional_tis(results, state_A, state_B, cv):
15-
assert len(results) == 1
16-
trans_info = results['trans_info']
17-
assert len(trans_info) == 1
18-
assert len(trans_info[0]) == 3
19-
trans = trans_info[0]
20-
assert isinstance(trans, tuple)
21-
assert trans[0] == state_A
22-
assert trans[2] == state_B
23-
assert isinstance(trans[1], paths.VolumeInterfaceSet)
24-
ifaces = trans[1]
25-
assert ifaces.cv == cv
26-
assert ifaces.minvals == float("-inf")
27-
np.testing.assert_allclose(ifaces.maxvals,
28-
[0, np.pi / 10.0, np.pi / 5.0])
29-
30-
31-
def test_mistis_trans_info(cv_and_states):
13+
@pytest.fixture
14+
def unidirectional_tis_compiler(cv_and_states):
3215
cv, state_A, state_B = cv_and_states
33-
dct = {
34-
'transitions': [{
35-
'initial_state': "A",
36-
'final_state': "B",
37-
'interfaces': {
38-
'cv': 'cv',
39-
'minvals': 'float("-inf")',
40-
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
41-
}
42-
}]
43-
}
44-
patch_base = 'paths_cli.compiling.networks'
45-
compiler = {
46-
'cv': mock_compiler('cv', named_objs={'cv': cv}),
47-
'volume': mock_compiler('volume', named_objs={
48-
"A": state_A, "B": state_B
49-
}),
50-
}
51-
with mock.patch.dict(_COMPILERS_LOC, compiler):
52-
results = mistis_trans_info(dct)
53-
54-
check_unidirectional_tis(results, state_A, state_B, cv)
55-
paths.InterfaceSet._reset()
56-
57-
58-
def test_tis_trans_info(cv_and_states):
59-
cv, state_A, state_B = cv_and_states
60-
dct = {
61-
'initial_state': "A",
62-
'final_state': "B",
63-
'interfaces': {
64-
'cv': 'cv',
65-
'minvals': 'float("-inf")',
66-
'maxvals': 'np.array([0, 0.1, 0.2]) * np.pi',
67-
}
68-
}
69-
70-
compiler = {
16+
return {
7117
'cv': mock_compiler('cv', named_objs={'cv': cv}),
7218
'volume': mock_compiler('volume', named_objs={
7319
"A": state_A, "B": state_B
7420
}),
21+
'interface_set': mock_compiler(
22+
'interface_set',
23+
type_dispatch={
24+
'volume-interface-set': VOLUME_INTERFACE_SET_PLUGIN
25+
}
26+
),
7527
}
76-
with mock.patch.dict(_COMPILERS_LOC, compiler):
77-
results = tis_trans_info(dct)
7828

79-
check_unidirectional_tis(results, state_A, state_B, cv)
80-
paths.InterfaceSet._reset()
29+
# def check_unidirectional_tis(results, state_A, state_B, cv):
30+
# assert len(results) == 1
31+
# trans_info = results['trans_info']
32+
# assert len(trans_info) == 1
33+
# assert len(trans_info[0]) == 3
34+
# trans = trans_info[0]
35+
# assert isinstance(trans, tuple)
36+
# assert trans[0] == state_A
37+
# assert trans[2] == state_B
38+
# assert isinstance(trans[1], paths.VolumeInterfaceSet)
39+
# ifaces = trans[1]
40+
# assert ifaces.cv == cv
41+
# assert ifaces.minvals == float("-inf")
42+
# np.testing.assert_allclose(ifaces.maxvals,
43+
# [0, np.pi / 10.0, np.pi / 5.0])
44+
45+
46+
# def test_mistis_trans_info(cv_and_states, mistis_dict,
47+
# unidirectional_tis_compiler):
48+
# cv, state_A, state_B = cv_and_states
49+
# patch_base = 'paths_cli.compiling.networks'
50+
# with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
51+
# results = mistis_trans_info(mistis_dict)
52+
53+
# check_unidirectional_tis(results, state_A, state_B, cv)
54+
# paths.InterfaceSet._reset()
55+
56+
57+
# def test_tis_trans_info(cv_and_states):
58+
# cv, state_A, state_B = cv_and_states
59+
# dct = {
60+
# 'initial_state': "A",
61+
# 'final_state': "B",
62+
# 'cv': 'cv',
63+
# 'minvals': 'float("-inf")',
64+
# 'maxvals': 'np.array([0, 0.1, 0.2]) * np.pi',
65+
# }
66+
67+
# compiler = {
68+
# 'cv': mock_compiler('cv', named_objs={'cv': cv}),
69+
# 'volume': mock_compiler('volume', named_objs={
70+
# "A": state_A, "B": state_B
71+
# }),
72+
# 'interface_set': mock_compiler(
73+
# 'interface_set',
74+
# type_dispatch={
75+
# 'volume-interface-set': VOLUME_INTERFACE_SET_PLUGIN
76+
# }
77+
# ),
78+
# }
79+
# with mock.patch.dict(_COMPILERS_LOC, compiler):
80+
# results = tis_trans_info(dct)
81+
82+
# check_unidirectional_tis(results, state_A, state_B, cv)
83+
# paths.InterfaceSet._reset()
8184

8285

8386
def test_build_tps_network(cv_and_states):
@@ -86,17 +89,50 @@ def test_build_tps_network(cv_and_states):
8689
dct = yaml.load(yml, yaml.FullLoader)
8790
compiler = {
8891
'volume': mock_compiler('volume', named_objs={"A": state_A,
89-
"B": state_B}),
92+
"B": state_B}),
9093
}
9194
with mock.patch.dict(_COMPILERS_LOC, compiler):
92-
network = build_tps_network(dct)
95+
network = TPS_NETWORK_PLUGIN(dct)
9396
assert isinstance(network, paths.TPSNetwork)
9497
assert len(network.initial_states) == len(network.final_states) == 1
9598
assert network.initial_states[0] == state_A
9699
assert network.final_states[0] == state_B
97100

98-
def test_build_mistis_network():
99-
pytest.skip()
100101

101-
def test_build_tis_network():
102-
pytest.skip()
102+
def test_build_mistis_network(cv_and_states, unidirectional_tis_compiler):
103+
cv, state_A, state_B = cv_and_states
104+
mistis_dict = {
105+
'interface_sets': [{
106+
'initial_state': "A",
107+
'final_state': "B",
108+
'cv': 'cv',
109+
'minvals': 'float("-inf")',
110+
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi"
111+
}]
112+
}
113+
114+
with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
115+
network = MISTIS_NETWORK_PLUGIN(mistis_dict)
116+
117+
assert isinstance(network, paths.MISTISNetwork)
118+
assert len(network.sampling_transitions) == 1
119+
assert len(network.transitions) == 1
120+
assert list(network.transitions) == [(state_A, state_B)]
121+
122+
def test_build_tis_network(cv_and_states, unidirectional_tis_compiler):
123+
cv, state_A, state_B = cv_and_states
124+
tis_dict = {
125+
'initial_state': "A",
126+
'final_state': "B",
127+
'cv': "cv",
128+
'minvals': 'float("inf")',
129+
'maxvals': "np.array([0, 0.1, 0.2]) * np.pi",
130+
}
131+
132+
with mock.patch.dict(_COMPILERS_LOC, unidirectional_tis_compiler):
133+
network = TIS_NETWORK_PLUGIN(tis_dict)
134+
135+
assert isinstance(network, paths.MISTISNetwork)
136+
assert len(network.sampling_transitions) == 1
137+
assert len(network.transitions) == 1
138+
assert list(network.transitions) == [(state_A, state_B)]

0 commit comments

Comments
 (0)