-
Notifications
You must be signed in to change notification settings - Fork 1
[DNM] Refactor reader.py #77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4b257b2
236ff72
d274b0c
f3634dd
e92adb3
b528ca2
a477bc1
ad84082
8ba8087
ead7951
f898a35
c675a5c
88e456d
73a8e4d
43aaca2
c165525
5f17770
c28286e
197b6ba
b45390a
1d70936
20084c3
1a1c916
59c7392
5e135ab
a9a8780
92af45b
8ea3585
220d504
8c44cb2
d13495f
c34c97c
0161673
9b6ca69
1d5c849
f4e88e2
bd0c8ee
ba4c912
157c02f
98ea023
c5b2d70
54576ab
3aa52a5
ff6991a
7a30f69
ac1fe7b
67a0913
e706b11
7277f20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,6 @@ | |
| from .reader import FEReader | ||
| from .transformations import ( | ||
| Aligner, | ||
| Minimiser, | ||
| ClosestImageShift, | ||
| NoJump, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from typing import Optional | ||
| import pathlib | ||
| from typing import Literal, Optional | ||
|
|
||
| import netCDF4 as nc | ||
| import numpy as np | ||
|
|
@@ -52,16 +53,20 @@ def _determine_iteration_dt(dataset) -> float: | |
|
|
||
|
|
||
| class FEReader(ReaderBase): | ||
| """A MDAnalysis Reader for NetCDF files created by | ||
| """ | ||
| MDAnalysis Reader for NetCDF files created by | ||
| `openmmtools.multistate.MultiStateReporter` | ||
|
|
||
| Looks along a multistate NetCDF file along one of two axes: | ||
| - constant state/lambda (varying replica) | ||
| - constant replica (varying lambda) | ||
| Provides a 1D trajectory view along either: | ||
|
|
||
| - constant Hamiltonian state (`view="state"`) | ||
| - constant replica (`view="replica"`) | ||
|
|
||
| selected via the `index` argument. | ||
| """ | ||
|
|
||
| _state_id: Optional[int] | ||
| _replica_id: Optional[int] | ||
| _index: Optional[int] | ||
| _view: Optional[str] | ||
| _frame_index: int | ||
| _dataset: nc.Dataset | ||
| _dataset_owner: bool | ||
|
|
@@ -70,35 +75,27 @@ class FEReader(ReaderBase): | |
|
|
||
| units = {"time": "ps", "length": "nanometer"} | ||
|
|
||
| def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, **kwargs): | ||
| def __init__( | ||
| self, | ||
| filename: str | pathlib.Path | nc.Dataset, | ||
| *, | ||
| index: int, | ||
| view: Literal["state", "replica"] = "state", | ||
| convert_units: bool = True, | ||
| **kwargs, | ||
| ): | ||
| """ | ||
| Parameters | ||
| ---------- | ||
| filename : pathlike or nc.Dataset | ||
| path to the .nc file | ||
| Path to the .nc file or an open Dataset. | ||
| index : int | ||
| Index of the state or replica to extract. May be negative. | ||
| view : {"state", "replica"}, default "state" | ||
| Whether `index` refers to a Hamiltonian state or a replica. | ||
| convert_units : bool | ||
| convert positions to Angstrom | ||
| state_id : Optional[int] | ||
| The Hamiltonian state index to extract. Must be defined if | ||
| ``replica_id`` is not defined. May be negative (see notes below). | ||
| replica_id : Optional[int] | ||
| The replica index to extract. Must be defined if ``state_id`` | ||
| is not defined. May be negative (see notes below). | ||
|
|
||
| Notes | ||
| ----- | ||
| A negative index may be passed to either ``state_id`` or | ||
| ``replica_id``. This will be interpreted as indexing in reverse | ||
| starting from the last state/replica. For example, passing a | ||
| value of -2 for ``replica_id`` will select the before last replica. | ||
| Convert positions to Angstrom. | ||
| """ | ||
| if not ((state_id is None) ^ (replica_id is None)): | ||
| raise ValueError( | ||
| "Specify one and only one of state or replica, " | ||
| f"got state id={state_id} " | ||
| f"replica_id={replica_id}" | ||
| ) | ||
|
|
||
| super().__init__(filename, convert_units, **kwargs) | ||
|
|
||
| if isinstance(filename, nc.Dataset): | ||
|
|
@@ -108,29 +105,37 @@ def __init__(self, filename, convert_units=True, state_id=None, replica_id=None, | |
| self._dataset = nc.Dataset(filename) | ||
| self._dataset_owner = True | ||
|
|
||
| # Handle the negative ID case | ||
| if state_id is not None and state_id < 0: | ||
| state_id = range(self._dataset.dimensions["state"].size)[state_id] | ||
| if view not in {"state", "replica"}: | ||
| raise ValueError(f"View must be 'state' or 'replica', got {view}") | ||
|
|
||
| if replica_id is not None and replica_id < 0: | ||
| replica_id = range(self._dataset.dimensions["replica"].size)[replica_id] | ||
| self._view = view | ||
|
|
||
| self._state_id = state_id | ||
| self._replica_id = replica_id | ||
| # Handle the negative ID case | ||
| if view == "state": | ||
| size = self._dataset.dimensions["state"].size | ||
| else: | ||
| size = self._dataset.dimensions["replica"].size | ||
|
|
||
| self._index = index % size | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would call this |
||
|
|
||
| self._n_atoms = self._dataset.dimensions["atom"].size | ||
| self.ts = Timestep(self._n_atoms) | ||
| self._frames = _determine_position_indices(self._dataset) | ||
| # The MDAnalysis trajectory "dt" is the iteration dt | ||
| # multiplied by the number of iterations between frames. | ||
| self._dt = _determine_iteration_dt(self._dataset) * np.diff(self._frames)[0] | ||
| self._frame_index = -1 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wouldn't do this, let |
||
| self._read_frame(0) | ||
|
|
||
| @staticmethod | ||
| def _format_hint(thing) -> bool: | ||
| # can pass raw nc datasets through to reduce open/close operations | ||
| return isinstance(thing, nc.Dataset) | ||
|
|
||
| @property | ||
| def index(self) -> int: | ||
| return self._index | ||
|
|
||
| @property | ||
| def n_atoms(self) -> int: | ||
| return self._n_atoms | ||
|
|
@@ -139,6 +144,10 @@ def n_atoms(self) -> int: | |
| def n_frames(self) -> int: | ||
| return len(self._frames) | ||
|
|
||
| @property | ||
| def view(self) -> str: | ||
| return self._view | ||
|
|
||
| @staticmethod | ||
| def parse_n_atoms(filename, **kwargs) -> int: | ||
| with nc.Dataset(filename) as ds: | ||
|
|
@@ -153,17 +162,19 @@ def _read_next_timestep(self, ts=None) -> Timestep: | |
| def _read_frame(self, frame: int) -> Timestep: | ||
| self._frame_index = frame | ||
|
|
||
| if self._state_id is not None: | ||
| frame = self._frames[self._frame_index] | ||
|
|
||
| if self._view == "state": | ||
| rep = multistate._state_to_replica( | ||
| self._dataset, self._state_id, self._frames[self._frame_index] | ||
| self._dataset, | ||
| self._index, | ||
| frame, | ||
| ) | ||
| else: | ||
| rep = self._replica_id | ||
| rep = self._index | ||
|
|
||
| pos = multistate._replica_positions_at_frame( | ||
| self._dataset, rep, self._frames[self._frame_index] | ||
| ) | ||
| dim = multistate._get_unitcell(self._dataset, rep, self._frames[self._frame_index]) | ||
| pos = multistate._replica_positions_at_frame(self._dataset, rep, frame) | ||
| dim = multistate._get_unitcell(self._dataset, rep, frame) | ||
|
|
||
| if pos is None: | ||
| errmsg = ( | ||
|
|
@@ -193,5 +204,7 @@ def _reopen(self): | |
| self._frame_index = -1 | ||
|
|
||
| def close(self): | ||
| if self._dataset_owner: | ||
| self._dataset.close() | ||
| if self._dataset is not None: | ||
| if self._dataset_owner: | ||
| self._dataset.close() | ||
| self._dataset = None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`view` is a bit unclear; I would maybe call it something like `index_style` or `index_method`? That way it's clear you're talking about how it's being indexed.