Skip to content

Commit ac173e6

Browse files
authored
configurable write context (#276)
First pass at #256. WriteContext class inspired by imod-python's, to control how MODFLOW 6 input files are written, including float precision, binary vs ASCII format, path handling, etc. There are 2 usage patterns: 1. Pass to `write()`: `sim.write(context=WriteContext(float_precision=4))` 2. Context manager: `with WriteContext(float_precision=8): sim.write()` Nested contexts are supported, with inner contexts overwriting settings from outer ones. The context manager approach is thread-local so should support multi-threaded writing later on. Precision defaults to 8 decimal places to match the current default in the writer jinja filters.
1 parent 41a8948 commit ac173e6

File tree

8 files changed

+457
-34
lines changed

8 files changed

+457
-34
lines changed

flopy4/mf6/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,16 @@ def _load_toml(path: Path) -> Component:
4040
return structure(load_toml(fp), path)
4141

4242

43-
def _write_mf6(component: Component) -> None:
43+
def _write_mf6(component: Component, context=None, **kwargs) -> None:
44+
from flopy4.mf6.write_context import WriteContext
45+
46+
# Use provided context or default
47+
ctx = context if context is not None else WriteContext.default()
48+
4449
with open(component.path, "w") as fp:
4550
data = unstructure(component)
4651
try:
47-
dump_mf6(data, fp)
52+
dump_mf6(data, fp, context=ctx)
4853
except Exception as e:
4954
raise WriteError(
5055
f"Failed to write MF6 format file for component '{component.name}' " # type: ignore

flopy4/mf6/codec/writer/__init__.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,69 @@
1919
_JINJA_ENV.filters["array2string"] = writer_filters.array2string
2020
_JINJA_ENV.filters["data2list"] = writer_filters.data2list
2121
_JINJA_TEMPLATE_NAME = "blocks.jinja"
22-
_PRINT_OPTIONS = {
23-
"precision": 4,
24-
"linewidth": sys.maxsize,
25-
"threshold": sys.maxsize,
26-
}
2722

2823

29-
def dumps(data) -> str:
24+
def _get_print_options(context=None):
25+
"""Get numpy print options from WriteContext."""
26+
if context is not None:
27+
return context.to_numpy_printoptions()
28+
# Default options
29+
return {
30+
"precision": 4,
31+
"linewidth": sys.maxsize,
32+
"threshold": sys.maxsize,
33+
}
34+
35+
36+
def dumps(data, context=None) -> str:
37+
"""
38+
Serialize data to MF6 format string.
39+
40+
Parameters
41+
----------
42+
data : dict
43+
Data to serialize
44+
context : WriteContext, optional
45+
Configuration context for writing
46+
47+
Returns
48+
-------
49+
str
50+
Serialized MF6 format string
51+
"""
52+
from flopy4.mf6.write_context import WriteContext
53+
54+
if context is None:
55+
context = WriteContext.default()
56+
3057
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
31-
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
32-
return template.render(blocks=data)
58+
print_opts = _get_print_options(context)
59+
with np.printoptions(**print_opts): # type: ignore
60+
result = template.render(blocks=data, context=context)
61+
62+
return result
63+
64+
65+
def dump(data, fp: IO[str], context=None) -> None:
66+
"""
67+
Serialize data to MF6 format and write to file.
68+
69+
Parameters
70+
----------
71+
data : dict
72+
Data to serialize
73+
fp : IO[str]
74+
File pointer to write to
75+
context : WriteContext, optional
76+
Configuration context for writing
77+
"""
78+
from flopy4.mf6.write_context import WriteContext
3379

80+
if context is None:
81+
context = WriteContext.default()
3482

35-
def dump(data, fp: IO[str]) -> None:
3683
template = _JINJA_ENV.get_template(_JINJA_TEMPLATE_NAME)
37-
iterator = template.generate(blocks=data)
38-
with np.printoptions(**_PRINT_OPTIONS): # type: ignore
84+
iterator = template.generate(blocks=data, context=context)
85+
print_opts = _get_print_options(context)
86+
with np.printoptions(**print_opts): # type: ignore
3987
fp.writelines(iterator)

flopy4/mf6/codec/writer/filters.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,21 @@ def array_how(value: xr.DataArray) -> ArrayHow:
4141
raise ValueError(f"Arrays with ndim > 3 are not supported, got ndim={value.ndim}")
4242

4343

44-
def array2const(value: xr.DataArray) -> Scalar:
44+
def array2const(value: xr.DataArray, precision: int = 8) -> Scalar:
45+
"""
46+
Convert array to constant scalar value.
47+
48+
Parameters
49+
----------
50+
value : xr.DataArray
51+
Array to convert
52+
precision : int, optional
53+
Number of decimal places for float output. Default is 8.
54+
"""
4555
if np.issubdtype(value.dtype, np.integer):
4656
return value.max().item()
4757
if np.issubdtype(value.dtype, np.floating):
48-
return f"{value.max().item():.8f}"
58+
return f"{value.max().item():.{precision}f}"
4959
return value.ravel()[0]
5060

5161

@@ -95,14 +105,21 @@ def array2chunks(value: xr.DataArray, chunks: Mapping[Hashable, int] | None = No
95105
yield np.squeeze(value.values)
96106

97107

98-
def array2string(value: NDArray) -> str:
108+
def array2string(value: NDArray, precision: int = 9) -> str:
99109
"""
100110
Convert an array to a string. The array can be 1D or 2D.
101111
If the array is 1D, it is converted to a 1-line string,
102112
with elements separated by whitespace. If the array is
103113
2D, each row becomes a line in the string.
104114
105115
Used for writing array-based input to MF6 input files.
116+
117+
Parameters
118+
----------
119+
value : NDArray
120+
Array to convert
121+
precision : int, optional
122+
Number of decimal places for float output. Default is 9.
106123
"""
107124
buffer = StringIO()
108125
value = np.asarray(value)
@@ -112,13 +129,14 @@ def array2string(value: NDArray) -> str:
112129
# add an axis to 1d arrays so np.savetxt writes elements on 1 line
113130
value = value[None]
114131
value = np.atleast_1d(value)
115-
format = (
116-
"%d"
117-
if np.issubdtype(value.dtype, np.integer)
118-
else "%.9e"
119-
if np.issubdtype(value.dtype, np.floating)
120-
else "%s"
121-
)
132+
133+
if np.issubdtype(value.dtype, np.floating):
134+
format = f"%.{precision}e"
135+
elif np.issubdtype(value.dtype, np.integer):
136+
format = "%d"
137+
else:
138+
format = "%s"
139+
122140
np.savetxt(buffer, value, fmt=format, delimiter=" ")
123141
return buffer.getvalue().strip()
124142

flopy4/mf6/codec/writer/templates/macros.jinja

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,23 @@
3232
{{ inset ~ name.upper() }}{% if "layered" in how %} LAYERED{% endif %}
3333

3434
{% if how == "constant" %}
35-
{{ inset }}CONSTANT {{ value|array2const }}
35+
{{ inset }}CONSTANT {{ value|array2const(context.float_precision) }}
3636
{% elif how == "layered constant" %}
3737
{% for layer in value -%}
38-
{{ "\n" ~ inset }}CONSTANT {{ layer|array2const }}
38+
{{ "\n" ~ inset }}CONSTANT {{ layer|array2const(context.float_precision) }}
3939
{%- endfor %}
4040
{% elif how == "layered internal" %}
4141
{% for layer in value %}
4242

4343
{{ inset }}INTERNAL
4444
{% for chunk in layer|array2chunks -%}
45-
{{ (2 * inset) ~ chunk|array2string }}
45+
{{ (2 * inset) ~ chunk|array2string(context.float_precision) }}
4646
{%- endfor %}
4747
{%- endfor %}
4848
{% elif how == "internal" %}
4949
{{ inset }}INTERNAL
5050
{% for chunk in value|array2chunks -%}
51-
{{ (2 * inset) ~ chunk|array2string }}
51+
{{ (2 * inset) ~ chunk|array2string(context.float_precision) }}
5252
{%- endfor %}
5353
{% elif how == "external" %}
5454
OPEN/CLOSE {{ value }}

flopy4/mf6/component.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC
22
from collections.abc import MutableMapping
33
from pathlib import Path
4-
from typing import Any, ClassVar
4+
from typing import Any, ClassVar, Optional
55

66
from attrs import fields
77
from modflow_devtools.dfn import Dfn, Field
@@ -12,6 +12,7 @@
1212
from flopy4.mf6.constants import MF6
1313
from flopy4.mf6.spec import field, fields_dict, to_field
1414
from flopy4.mf6.utils.grid import update_maxbound
15+
from flopy4.mf6.write_context import WriteContext
1516
from flopy4.uio import IO, Loader, Writer
1617

1718
COMPONENTS = {}
@@ -142,16 +143,31 @@ def load(self, format: str = MF6) -> None:
142143
for child in self.children.values(): # type: ignore
143144
child.load(format=format)
144145

145-
def write(self, format: str = MF6) -> None:
146-
"""Write the component and any children."""
146+
def write(self, format: str = MF6, context: Optional[WriteContext] = None) -> None:
147+
"""
148+
Write the component and any children.
149+
150+
Parameters
151+
----------
152+
format : str, optional
153+
Output format. Default is MF6.
154+
context : WriteContext, optional
155+
Configuration context for writing. If not provided,
156+
uses the current context from the context manager stack,
157+
or default settings.
158+
"""
147159
# TODO: setting filename is a temp hack to get the parent's
148160
# name as this component's filename stem, if it has one. an
149161
# actual solution is to auto-set the filename when children
150162
# are attached to parents.
151163
self.filename = self.filename or self.default_filename()
152-
self._write(format=format)
164+
165+
# Determine active context: provided > current > default
166+
active_context = context or WriteContext.current()
167+
168+
self._write(format=format, context=active_context)
153169
for child in self.children.values(): # type: ignore
154-
child.write(format=format)
170+
child.write(format=format, context=context)
155171

156172
def to_dict(self, blocks: bool = False, strict: bool = False) -> dict[str, Any]:
157173
"""

flopy4/mf6/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def load(self, format=MF6):
5252
with cd(self.workspace):
5353
super().load(format=format)
5454

55-
def write(self, format=MF6):
55+
def write(self, format=MF6, context=None):
5656
with cd(self.workspace):
57-
super().write(format=format)
57+
super().write(format=format, context=context)
5858

5959
def to_xarray(self):
6060
return self.data # type: ignore

0 commit comments

Comments
 (0)