Skip to content

Commit d1c7539

Browse files
committed
Convert output of INIT_CONDS parameter to trajs
1 parent de53902 commit d1c7539

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

paths_cli/parameters.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,31 @@
1818
store='schemes',
1919
)
2020

21-
INIT_CONDS = OPSStorageLoadMultiple(
21+
class InitCondsLoader(OPSStorageLoadMultiple):
22+
def _extract_trajectories(self, obj):
23+
import openpathsampling as paths
24+
if isinstance(obj, paths.SampleSet):
25+
yield from (s.trajectory for s in obj)
26+
elif isinstance(obj, paths.Sample):
27+
yield obj.trajectory
28+
elif isinstance(obj, paths.Trajectory):
29+
yield obj
30+
elif isinstance(obj, paths.BaseSnapshot):
31+
yield paths.Trajectory([obj])
32+
elif isinstance(obj, list):
33+
for o in obj:
34+
yield from self._extract_trajectories(o)
35+
else:
36+
raise RuntimeError("Unknown initial conditions type: "
37+
f"{obj} (type: {type(obj)}")
38+
39+
def get(self, storage, names):
40+
results = super().get(storage, names)
41+
final_results = list(self._extract_trajectories(results))
42+
return final_results
43+
44+
45+
INIT_CONDS = InitCondsLoader(
2246
param=Option('-t', '--init-conds', multiple=True,
2347
help=("identifier for initial conditions "
2448
+ "(sample set or trajectory)" + HELP_MULTIPLE)),

paths_cli/tests/test_parameters.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def test_get(self, getter):
243243
storage = paths.Storage(filename, mode='r')
244244
get_type, getter_style = self._parse_getter(getter)
245245
expected = {
246-
'sset': self.sample_set,
247-
'traj': self.traj
246+
'sset': [s.trajectory for s in self.sample_set],
247+
'traj': [self.traj]
248248
}[get_type]
249249
get_arg = {
250250
'name': 'traj',
@@ -277,7 +277,14 @@ def test_get_none(self, num_in_file):
277277

278278
st = paths.Storage(filename, mode='r')
279279
obj = INIT_CONDS.get(st, None)
280-
assert obj == stored_things[num_in_file - 1]
280+
# TODO: fix this for all being trajectories
281+
expected = [
282+
[self.traj],
283+
[s.trajectory for s in self.sample_set],
284+
[s.trajectory for s in self.other_sample_set],
285+
[s.trajectory for s in self.other_sample_set],
286+
]
287+
assert obj == expected[num_in_file - 1]
281288

282289
def test_get_multiple(self):
283290
filename = self.create_file('number-traj')

0 commit comments

Comments
 (0)