Skip to content

Commit 5916868

Browse files
Env template generator (#847)
Co-authored-by: openhands <openhands@all-hands.dev>
1 parent 1a76d34 commit 5916868

File tree

3 files changed

+575
-40
lines changed

3 files changed

+575
-40
lines changed

openhands-agent-server/openhands/agent_server/env_parser.py

Lines changed: 171 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass
1010
from datetime import datetime
11+
from enum import Enum
12+
from io import StringIO
1113
from pathlib import Path
1214
from types import UnionType
13-
from typing import Annotated, Literal, Union, get_args, get_origin
15+
from typing import IO, Annotated, Any, Literal, Union, cast, get_args, get_origin
1416
from uuid import UUID
1517

1618
from pydantic import BaseModel, SecretStr, TypeAdapter
@@ -34,13 +36,22 @@ class EnvParser(ABC):
3436
def from_env(self, key: str) -> JsonType:
3537
"""Parse environment variables into a json like structure"""
3638

39+
def to_env(self, key: str, value: Any, output: IO):
40+
"""Produce a template based on this parser"""
41+
if value is None:
42+
value = ""
43+
output.write(f"{key}={value}\n")
44+
3745

3846
class BoolEnvParser(EnvParser):
3947
def from_env(self, key: str) -> bool | MissingType:
4048
if key not in os.environ:
4149
return MISSING
4250
return os.environ[key].upper() in ["1", "TRUE"] # type: ignore
4351

52+
def to_env(self, key: str, value: Any, output: IO):
53+
output.write(f"{key}={1 if value else 0}\n")
54+
4455

4556
class IntEnvParser(EnvParser):
4657
def from_env(self, key: str) -> int | MissingType:
@@ -65,15 +76,40 @@ def from_env(self, key: str) -> str | MissingType:
6576

6677
class NoneEnvParser(EnvParser):
6778
def from_env(self, key: str) -> None | MissingType:
68-
# Perversely, if the key is present this is not none so we consider it missing
69-
if key in os.environ:
79+
key = f"{key}_IS_NONE"
80+
value = (os.getenv(key) or "").upper()
81+
if value in ["1", "TRUE"]:
82+
return None
83+
return MISSING
84+
85+
def to_env(self, key: str, value: Any, output: IO):
86+
if value is None:
87+
output.write(f"{key}_IS_NONE=1\n")
88+
89+
90+
@dataclass
91+
class LiteralEnvParser(EnvParser):
92+
values: tuple[str, ...]
93+
94+
def from_env(self, key: str) -> str | MissingType:
95+
value = os.getenv(key)
96+
if value not in self.values:
7097
return MISSING
71-
return None
98+
return value
99+
100+
def to_env(self, key: str, value: Any, output: IO):
101+
output.write(f"# Permitted Values: {', '.join(self.values)}\n")
102+
# For enums, use the value instead of the string representation
103+
if hasattr(value, "value"):
104+
output.write(f"{key}={value.value}\n")
105+
else:
106+
output.write(f"{key}={value}\n")
72107

73108

74109
@dataclass
75110
class ModelEnvParser(EnvParser):
76111
parsers: dict[str, EnvParser]
112+
descriptions: dict[str, str]
77113

78114
def from_env(self, key: str) -> dict | MissingType:
79115
# First we see is there a base value defined as json...
@@ -108,6 +144,19 @@ def from_env(self, key: str) -> dict | MissingType:
108144

109145
return result
110146

147+
def to_env(self, key: str, value: Any, output: IO):
148+
for field_name, parser in self.parsers.items():
149+
field_description = self.descriptions.get(field_name)
150+
if field_description:
151+
for line in field_description.split("\n"):
152+
output.write("# ")
153+
output.write(line)
154+
output.write("\n")
155+
field_key = key + "_" + field_name.upper()
156+
field_value = getattr(value, field_name)
157+
parser.to_env(field_key, field_value, output)
158+
output.write("\n")
159+
111160

112161
class DictEnvParser(EnvParser):
113162
def from_env(self, key: str) -> dict | MissingType:
@@ -125,6 +174,7 @@ def from_env(self, key: str) -> dict | MissingType:
125174
@dataclass
126175
class ListEnvParser(EnvParser):
127176
item_parser: EnvParser
177+
item_type: type
128178

129179
def from_env(self, key: str) -> list | MissingType:
130180
if key not in os.environ:
@@ -162,18 +212,61 @@ def from_env(self, key: str) -> list | MissingType:
162212

163213
return result
164214

215+
def to_env(self, key: str, value: Any, output: IO):
216+
if len(value):
217+
for index, sub_value in enumerate(value):
218+
sub_key = f"{key}_{index}"
219+
self.item_parser.to_env(sub_key, sub_value, output)
220+
else:
221+
# Try to produce a sample value based on the defaults...
222+
try:
223+
sub_key = f"{key}_0"
224+
sample_output = StringIO()
225+
self.item_parser.to_env(
226+
sub_key, _create_sample(self.item_type), sample_output
227+
)
228+
for line in sample_output.getvalue().strip().split("\n"):
229+
output.write("# ")
230+
output.write(line)
231+
output.write("\n")
232+
except Exception:
233+
# Couldn't create a sample value. Skip
234+
pass
235+
165236

166237
@dataclass
167238
class UnionEnvParser(EnvParser):
168-
parsers: list[EnvParser]
239+
parsers: dict[type, EnvParser]
169240

170241
def from_env(self, key: str) -> JsonType:
171242
result = MISSING
172-
for parser in self.parsers:
243+
for parser in self.parsers.values():
173244
parser_result = parser.from_env(key)
174245
result = merge(result, parser_result)
175246
return result
176247

248+
def to_env(self, key: str, value: Any, output: IO):
249+
for type_, parser in self.parsers.items():
250+
if not isinstance(value, type_):
251+
# Try to produce a sample value based on the defaults...
252+
try:
253+
sample_value = _create_sample(type_)
254+
sample_output = StringIO()
255+
sample_output.write(f"{sample_value.__class__.__name__}\n")
256+
parser.to_env(key, sample_value, sample_output)
257+
for line in sample_output.getvalue().split("\n"):
258+
output.write("# ")
259+
output.write(line)
260+
output.write("\n")
261+
except Exception:
262+
# Couldn't create a sample value. Skip
263+
pass
264+
for type_, parser in self.parsers.items():
265+
if isinstance(value, type_):
266+
output.write(f"# {value.__class__.__name__}\n")
267+
parser.to_env(key, value, output)
268+
output.write("\n")
269+
177270

178271
@dataclass
179272
class DelayedParser(EnvParser):
@@ -185,6 +278,10 @@ def from_env(self, key: str) -> JsonType:
185278
assert self.parser is not None
186279
return self.parser.from_env(key)
187280

281+
def to_env(self, key: str, value: Any, output: IO):
282+
assert self.parser is not None
283+
return self.parser.to_env(key, value, output)
284+
188285

189286
def merge(a, b):
190287
if a is MISSING:
@@ -222,22 +319,23 @@ def get_env_parser(target_type: type, parsers: dict[type, EnvParser]) -> EnvPars
222319
# Strip annotations...
223320
return get_env_parser(get_args(target_type)[0], parsers)
224321
if origin is UnionType or origin is Union:
225-
union_parsers = [
226-
get_env_parser(t, parsers) # type: ignore
322+
union_parsers = {
323+
t: get_env_parser(t, parsers) # type: ignore
227324
for t in get_args(target_type)
228-
]
325+
}
229326
return UnionEnvParser(union_parsers)
230327
if origin is list:
231-
parser = get_env_parser(get_args(target_type)[0], parsers)
232-
return ListEnvParser(parser)
328+
item_type = get_args(target_type)[0]
329+
parser = get_env_parser(item_type, parsers)
330+
return ListEnvParser(parser, item_type)
233331
if origin is dict:
234332
args = get_args(target_type)
235333
assert args[0] is str
236334
assert args[1] in (str, int, float, bool)
237335
return DictEnvParser()
238336
if origin is Literal:
239-
return StrEnvParser()
240-
337+
args = cast(tuple[str, ...], get_args(target_type))
338+
return LiteralEnvParser(args)
241339
if origin and issubclass(origin, BaseModel):
242340
target_type = origin
243341
if issubclass(target_type, DiscriminatedUnionMixin) and (
@@ -249,34 +347,65 @@ def get_env_parser(target_type: type, parsers: dict[type, EnvParser]) -> EnvPars
249347
if issubclass(target_type, BaseModel): # type: ignore
250348
delayed = DelayedParser()
251349
parsers[target_type] = delayed # Prevent circular dependency
252-
field_parsers = {
253-
name: get_env_parser(field.annotation, parsers) # type: ignore
254-
for name, field in target_type.model_fields.items()
255-
}
256-
parser = ModelEnvParser(field_parsers)
350+
field_parsers = {}
351+
descriptions = {}
352+
for name, field in target_type.model_fields.items():
353+
field_parsers[name] = get_env_parser(field.annotation, parsers) # type: ignore
354+
description = field.description
355+
if description:
356+
descriptions[name] = description
357+
358+
parser = ModelEnvParser(field_parsers, descriptions)
257359
delayed.parser = parser
258360
parsers[target_type] = parser
259361
return parser
362+
if issubclass(target_type, Enum):
363+
values = tuple(e.value for e in target_type)
364+
return LiteralEnvParser(values)
260365
raise ValueError(f"unknown_type:{target_type}")
261366

262367

368+
def _get_default_parsers() -> dict[type, EnvParser]:
369+
return {
370+
str: StrEnvParser(),
371+
int: IntEnvParser(),
372+
float: FloatEnvParser(),
373+
bool: BoolEnvParser(),
374+
type(None): NoneEnvParser(),
375+
UUID: StrEnvParser(),
376+
Path: StrEnvParser(),
377+
datetime: StrEnvParser(),
378+
SecretStr: StrEnvParser(),
379+
}
380+
381+
382+
def _create_sample(type_: type):
383+
if type_ is None:
384+
return None
385+
if type_ is str:
386+
return "..."
387+
if type_ is int:
388+
return 0
389+
if type_ is float:
390+
return 0.0
391+
if type_ is bool:
392+
return False
393+
try:
394+
if issubclass(type_, Enum):
395+
return next(iter(type_))
396+
except Exception:
397+
pass
398+
# Try to initialize and raise exception if failure.
399+
return type_()
400+
401+
263402
def from_env(
264403
target_type: type,
265404
prefix: str = "",
266405
parsers: dict[type, EnvParser] | None = None,
267406
):
268407
if parsers is None:
269-
parsers = {
270-
str: StrEnvParser(),
271-
int: IntEnvParser(),
272-
float: FloatEnvParser(),
273-
bool: BoolEnvParser(),
274-
type(None): NoneEnvParser(),
275-
UUID: StrEnvParser(),
276-
Path: StrEnvParser(),
277-
datetime: StrEnvParser(),
278-
SecretStr: StrEnvParser(),
279-
}
408+
parsers = _get_default_parsers()
280409
parser = get_env_parser(target_type, parsers)
281410
json_data = parser.from_env(prefix)
282411
if json_data is MISSING:
@@ -286,3 +415,16 @@ def from_env(
286415
type_adapter = TypeAdapter(target_type)
287416
result = type_adapter.validate_json(json_str)
288417
return result
418+
419+
420+
def to_env(
421+
value: Any,
422+
prefix: str = "",
423+
parsers: dict[type, EnvParser] | None = None,
424+
) -> str:
425+
if parsers is None:
426+
parsers = _get_default_parsers()
427+
parser = get_env_parser(value.__class__, parsers)
428+
output = StringIO()
429+
parser.to_env(prefix, value, output)
430+
return output.getvalue()

scripts/build_config_template.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Generate a .env file containing all config options
4+
"""
5+
6+
import argparse
7+
8+
from openhands.agent_server.config import get_default_config
9+
from openhands.agent_server.env_parser import to_env
10+
11+
12+
if __name__ == "__main__":
13+
parser = argparse.ArgumentParser(
14+
description="Generate a .env file containing all config options"
15+
)
16+
parser.add_argument("--file", default=".env", help="File path")
17+
args = parser.parse_args()
18+
print(f"🛠️ Building: {args.file}")
19+
with open(args.file, "w") as f:
20+
content = to_env(get_default_config(), "OH")
21+
f.write(content)

0 commit comments

Comments
 (0)