Skip to content

Commit 6f0e2d6

Browse files
authored
Merge pull request #87 from dwhswenson/more-flexible-init-conds
Convert output of `INIT_CONDS` parameter to list of trajectories
2 parents f17eebb + ff15749 commit 6f0e2d6

File tree

4 files changed

+61
-11
lines changed

4 files changed

+61
-11
lines changed

paths_cli/parameters.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,29 @@
1818
store='schemes',
1919
)
2020

21-
INIT_CONDS = OPSStorageLoadMultiple(
21+
class InitCondsLoader(OPSStorageLoadMultiple):
22+
def _extract_trajectories(self, obj):
23+
import openpathsampling as paths
24+
if isinstance(obj, paths.SampleSet):
25+
yield from (s.trajectory for s in obj)
26+
elif isinstance(obj, paths.Sample):
27+
yield obj.trajectory
28+
elif isinstance(obj, paths.Trajectory):
29+
yield obj
30+
elif isinstance(obj, list):
31+
for o in obj:
32+
yield from self._extract_trajectories(o)
33+
else:
34+
raise RuntimeError("Unknown initial conditions type: "
35+
f"{obj} (type: {type(obj)}")
36+
37+
def get(self, storage, names):
38+
results = super().get(storage, names)
39+
final_results = list(self._extract_trajectories(results))
40+
return final_results
41+
42+
43+
INIT_CONDS = InitCondsLoader(
2244
param=Option('-t', '--init-conds', multiple=True,
2345
help=("identifier for initial conditions "
2446
+ "(sample set or trajectory)" + HELP_MULTIPLE)),

paths_cli/tests/commands/test_equilibrate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def print_test(output_storage, scheme, init_conds, multiplier, extra_steps):
1212
print(isinstance(output_storage, paths.Storage))
1313
print(scheme.__uuid__)
14-
print(init_conds.__uuid__)
14+
print([o.__uuid__ for o in init_conds])
1515
print(multiplier, extra_steps)
1616

1717

@@ -31,8 +31,10 @@ def test_equilibrate(tps_fixture):
3131
["setup.nc", "-o", "foo.nc"]
3232
)
3333
out_str = "True\n{schemeid}\n{condsid}\n1 0\n"
34-
expected_output = out_str.format(schemeid=scheme.__uuid__,
35-
condsid=init_conds.__uuid__)
34+
expected_output = out_str.format(
35+
schemeid=scheme.__uuid__,
36+
condsid=[o.trajectory.__uuid__ for o in init_conds],
37+
)
3638
assert results.exit_code == 0
3739
assert results.output == expected_output
3840

paths_cli/tests/commands/test_pathsampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def print_test(output_storage, scheme, init_conds, n_steps):
1212
print(isinstance(output_storage, paths.Storage))
1313
print(scheme.__uuid__)
14-
print(init_conds.__uuid__)
14+
print([traj.__uuid__ for traj in init_conds])
1515
print(n_steps)
1616

1717
@patch('paths_cli.commands.pathsampling.pathsampling_main', print_test)
@@ -26,7 +26,8 @@ def test_pathsampling(tps_fixture):
2626

2727
results = runner.invoke(pathsampling, ['setup.nc', '-o', 'foo.nc',
2828
'-n', '1000'])
29-
expected_output = (f"True\n{scheme.__uuid__}\n{init_conds.__uuid__}"
29+
initcondsid = [samp.trajectory.__uuid__ for samp in init_conds]
30+
expected_output = (f"True\n{scheme.__uuid__}\n{initcondsid}"
3031
"\n1000\n")
3132

3233
assert results.output == expected_output

paths_cli/tests/test_parameters.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def create_file(self, getter):
214214
get_type, getter_style = self._parse_getter(getter)
215215
main, other = {
216216
'traj': (self.traj, self.other_traj),
217-
'sset': (self.sample_set, self.other_sample_set)
217+
'sset': (self.sample_set, self.other_sample_set),
218+
'samp': (self.sample_set[0], self.other_sample_set[0]),
218219
}[get_type]
220+
if get_type == 'samp':
221+
storage.save(main)
222+
storage.save(other)
219223
if get_type == 'sset':
220224
storage.save(self.sample_set)
221225
storage.save(self.other_sample_set)
@@ -231,20 +235,23 @@ def create_file(self, getter):
231235

232236
if other_tag:
233237
storage.tags[other_tag] = other
238+
234239
storage.close()
235240
return filename
236241

237242
@pytest.mark.parametrize("getter", [
238243
'name-traj', 'number-traj', 'tag-final-traj', 'tag-initial-traj',
239-
'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset'
244+
'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset',
245+
'name-samp', 'number-samp',
240246
])
241247
def test_get(self, getter):
242248
filename = self.create_file(getter)
243249
storage = paths.Storage(filename, mode='r')
244250
get_type, getter_style = self._parse_getter(getter)
245251
expected = {
246-
'sset': self.sample_set,
247-
'traj': self.traj
252+
'sset': [s.trajectory for s in self.sample_set],
253+
'traj': [self.traj],
254+
'samp': [self.sample_set[0].trajectory],
248255
}[get_type]
249256
get_arg = {
250257
'name': 'traj',
@@ -277,7 +284,13 @@ def test_get_none(self, num_in_file):
277284

278285
st = paths.Storage(filename, mode='r')
279286
obj = INIT_CONDS.get(st, None)
280-
assert obj == stored_things[num_in_file - 1]
287+
expected = [
288+
[self.traj],
289+
[s.trajectory for s in self.sample_set],
290+
[s.trajectory for s in self.other_sample_set],
291+
[s.trajectory for s in self.other_sample_set],
292+
]
293+
assert obj == expected[num_in_file - 1]
281294

282295
def test_get_multiple(self):
283296
filename = self.create_file('number-traj')
@@ -297,6 +310,18 @@ def test_cannot_guess(self):
297310
with pytest.raises(RuntimeError):
298311
self.PARAMETER.get(storage, None)
299312

313+
def test_get_bad_name(self):
314+
filename = self._filename("bad_tag")
315+
storage = paths.Storage(filename, 'w')
316+
storage.save(self.traj)
317+
storage.save(self.other_traj)
318+
storage.tags['bad_tag'] = "foo"
319+
storage.close()
320+
321+
storage = paths.Storage(filename, 'r')
322+
with pytest.raises(RuntimeError, match="initial conditions type"):
323+
self.PARAMETER.get(storage, "bad_tag")
324+
300325

301326
class TestINIT_SNAP(ParamInstanceTest):
302327
PARAMETER = INIT_SNAP

0 commit comments

Comments
 (0)