88from abc import ABC , abstractmethod
99from dataclasses import dataclass
1010from datetime import datetime
11+ from enum import Enum
12+ from io import StringIO
1113from pathlib import Path
1214from types import UnionType
13- from typing import Annotated , Literal , Union , get_args , get_origin
15+ from typing import IO , Annotated , Any , Literal , Union , cast , get_args , get_origin
1416from uuid import UUID
1517
1618from pydantic import BaseModel , SecretStr , TypeAdapter
@@ -34,13 +36,22 @@ class EnvParser(ABC):
3436 def from_env (self , key : str ) -> JsonType :
3537 """Parse environment variables into a json like structure"""
3638
39+ def to_env (self , key : str , value : Any , output : IO ):
40+ """Produce a template based on this parser"""
41+ if value is None :
42+ value = ""
43+ output .write (f"{ key } ={ value } \n " )
44+
3745
3846class BoolEnvParser (EnvParser ):
3947 def from_env (self , key : str ) -> bool | MissingType :
4048 if key not in os .environ :
4149 return MISSING
4250 return os .environ [key ].upper () in ["1" , "TRUE" ] # type: ignore
4351
52+ def to_env (self , key : str , value : Any , output : IO ):
53+ output .write (f"{ key } ={ 1 if value else 0 } \n " )
54+
4455
4556class IntEnvParser (EnvParser ):
4657 def from_env (self , key : str ) -> int | MissingType :
@@ -65,15 +76,40 @@ def from_env(self, key: str) -> str | MissingType:
6576
6677class NoneEnvParser (EnvParser ):
6778 def from_env (self , key : str ) -> None | MissingType :
68- # Perversely, if the key is present this is not none so we consider it missing
69- if key in os .environ :
79+ key = f"{ key } _IS_NONE"
80+ value = (os .getenv (key ) or "" ).upper ()
81+ if value in ["1" , "TRUE" ]:
82+ return None
83+ return MISSING
84+
85+ def to_env (self , key : str , value : Any , output : IO ):
86+ if value is None :
87+ output .write (f"{ key } _IS_NONE=1\n " )
88+
89+
90+ @dataclass
91+ class LiteralEnvParser (EnvParser ):
92+ values : tuple [str , ...]
93+
94+ def from_env (self , key : str ) -> str | MissingType :
95+ value = os .getenv (key )
96+ if value not in self .values :
7097 return MISSING
71- return None
98+ return value
99+
100+ def to_env (self , key : str , value : Any , output : IO ):
101+ output .write (f"# Permitted Values: { ', ' .join (self .values )} \n " )
102+ # For enums, use the value instead of the string representation
103+ if hasattr (value , "value" ):
104+ output .write (f"{ key } ={ value .value } \n " )
105+ else :
106+ output .write (f"{ key } ={ value } \n " )
72107
73108
74109@dataclass
75110class ModelEnvParser (EnvParser ):
76111 parsers : dict [str , EnvParser ]
112+ descriptions : dict [str , str ]
77113
78114 def from_env (self , key : str ) -> dict | MissingType :
79115 # First we see is there a base value defined as json...
@@ -108,6 +144,19 @@ def from_env(self, key: str) -> dict | MissingType:
108144
109145 return result
110146
147+ def to_env (self , key : str , value : Any , output : IO ):
148+ for field_name , parser in self .parsers .items ():
149+ field_description = self .descriptions .get (field_name )
150+ if field_description :
151+ for line in field_description .split ("\n " ):
152+ output .write ("# " )
153+ output .write (line )
154+ output .write ("\n " )
155+ field_key = key + "_" + field_name .upper ()
156+ field_value = getattr (value , field_name )
157+ parser .to_env (field_key , field_value , output )
158+ output .write ("\n " )
159+
111160
112161class DictEnvParser (EnvParser ):
113162 def from_env (self , key : str ) -> dict | MissingType :
@@ -125,6 +174,7 @@ def from_env(self, key: str) -> dict | MissingType:
125174@dataclass
126175class ListEnvParser (EnvParser ):
127176 item_parser : EnvParser
177+ item_type : type
128178
129179 def from_env (self , key : str ) -> list | MissingType :
130180 if key not in os .environ :
@@ -162,18 +212,61 @@ def from_env(self, key: str) -> list | MissingType:
162212
163213 return result
164214
215+ def to_env (self , key : str , value : Any , output : IO ):
216+ if len (value ):
217+ for index , sub_value in enumerate (value ):
218+ sub_key = f"{ key } _{ index } "
219+ self .item_parser .to_env (sub_key , sub_value , output )
220+ else :
221+ # Try to produce a sample value based on the defaults...
222+ try :
223+ sub_key = f"{ key } _0"
224+ sample_output = StringIO ()
225+ self .item_parser .to_env (
226+ sub_key , _create_sample (self .item_type ), sample_output
227+ )
228+ for line in sample_output .getvalue ().strip ().split ("\n " ):
229+ output .write ("# " )
230+ output .write (line )
231+ output .write ("\n " )
232+ except Exception :
233+ # Couldn't create a sample value. Skip
234+ pass
235+
165236
166237@dataclass
167238class UnionEnvParser (EnvParser ):
168- parsers : list [ EnvParser ]
239+ parsers : dict [ type , EnvParser ]
169240
170241 def from_env (self , key : str ) -> JsonType :
171242 result = MISSING
172- for parser in self .parsers :
243+ for parser in self .parsers . values () :
173244 parser_result = parser .from_env (key )
174245 result = merge (result , parser_result )
175246 return result
176247
248+ def to_env (self , key : str , value : Any , output : IO ):
249+ for type_ , parser in self .parsers .items ():
250+ if not isinstance (value , type_ ):
251+ # Try to produce a sample value based on the defaults...
252+ try :
253+ sample_value = _create_sample (type_ )
254+ sample_output = StringIO ()
255+ sample_output .write (f"{ sample_value .__class__ .__name__ } \n " )
256+ parser .to_env (key , sample_value , sample_output )
257+ for line in sample_output .getvalue ().split ("\n " ):
258+ output .write ("# " )
259+ output .write (line )
260+ output .write ("\n " )
261+ except Exception :
262+ # Couldn't create a sample value. Skip
263+ pass
264+ for type_ , parser in self .parsers .items ():
265+ if isinstance (value , type_ ):
266+ output .write (f"# { value .__class__ .__name__ } \n " )
267+ parser .to_env (key , value , output )
268+ output .write ("\n " )
269+
177270
178271@dataclass
179272class DelayedParser (EnvParser ):
@@ -185,6 +278,10 @@ def from_env(self, key: str) -> JsonType:
185278 assert self .parser is not None
186279 return self .parser .from_env (key )
187280
281+ def to_env (self , key : str , value : Any , output : IO ):
282+ assert self .parser is not None
283+ return self .parser .to_env (key , value , output )
284+
188285
189286def merge (a , b ):
190287 if a is MISSING :
@@ -222,22 +319,23 @@ def get_env_parser(target_type: type, parsers: dict[type, EnvParser]) -> EnvPars
222319 # Strip annotations...
223320 return get_env_parser (get_args (target_type )[0 ], parsers )
224321 if origin is UnionType or origin is Union :
225- union_parsers = [
226- get_env_parser (t , parsers ) # type: ignore
322+ union_parsers = {
323+ t : get_env_parser (t , parsers ) # type: ignore
227324 for t in get_args (target_type )
228- ]
325+ }
229326 return UnionEnvParser (union_parsers )
230327 if origin is list :
231- parser = get_env_parser (get_args (target_type )[0 ], parsers )
232- return ListEnvParser (parser )
328+ item_type = get_args (target_type )[0 ]
329+ parser = get_env_parser (item_type , parsers )
330+ return ListEnvParser (parser , item_type )
233331 if origin is dict :
234332 args = get_args (target_type )
235333 assert args [0 ] is str
236334 assert args [1 ] in (str , int , float , bool )
237335 return DictEnvParser ()
238336 if origin is Literal :
239- return StrEnvParser ( )
240-
337+ args = cast ( tuple [ str , ...], get_args ( target_type ) )
338+ return LiteralEnvParser ( args )
241339 if origin and issubclass (origin , BaseModel ):
242340 target_type = origin
243341 if issubclass (target_type , DiscriminatedUnionMixin ) and (
@@ -249,34 +347,65 @@ def get_env_parser(target_type: type, parsers: dict[type, EnvParser]) -> EnvPars
249347 if issubclass (target_type , BaseModel ): # type: ignore
250348 delayed = DelayedParser ()
251349 parsers [target_type ] = delayed # Prevent circular dependency
252- field_parsers = {
253- name : get_env_parser (field .annotation , parsers ) # type: ignore
254- for name , field in target_type .model_fields .items ()
255- }
256- parser = ModelEnvParser (field_parsers )
350+ field_parsers = {}
351+ descriptions = {}
352+ for name , field in target_type .model_fields .items ():
353+ field_parsers [name ] = get_env_parser (field .annotation , parsers ) # type: ignore
354+ description = field .description
355+ if description :
356+ descriptions [name ] = description
357+
358+ parser = ModelEnvParser (field_parsers , descriptions )
257359 delayed .parser = parser
258360 parsers [target_type ] = parser
259361 return parser
362+ if issubclass (target_type , Enum ):
363+ values = tuple (e .value for e in target_type )
364+ return LiteralEnvParser (values )
260365 raise ValueError (f"unknown_type:{ target_type } " )
261366
262367
368+ def _get_default_parsers () -> dict [type , EnvParser ]:
369+ return {
370+ str : StrEnvParser (),
371+ int : IntEnvParser (),
372+ float : FloatEnvParser (),
373+ bool : BoolEnvParser (),
374+ type (None ): NoneEnvParser (),
375+ UUID : StrEnvParser (),
376+ Path : StrEnvParser (),
377+ datetime : StrEnvParser (),
378+ SecretStr : StrEnvParser (),
379+ }
380+
381+
382+ def _create_sample (type_ : type ):
383+ if type_ is None :
384+ return None
385+ if type_ is str :
386+ return "..."
387+ if type_ is int :
388+ return 0
389+ if type_ is float :
390+ return 0.0
391+ if type_ is bool :
392+ return False
393+ try :
394+ if issubclass (type_ , Enum ):
395+ return next (iter (type_ ))
396+ except Exception :
397+ pass
398+ # Try to initialize and raise exception if failure.
399+ return type_ ()
400+
401+
263402def from_env (
264403 target_type : type ,
265404 prefix : str = "" ,
266405 parsers : dict [type , EnvParser ] | None = None ,
267406):
268407 if parsers is None :
269- parsers = {
270- str : StrEnvParser (),
271- int : IntEnvParser (),
272- float : FloatEnvParser (),
273- bool : BoolEnvParser (),
274- type (None ): NoneEnvParser (),
275- UUID : StrEnvParser (),
276- Path : StrEnvParser (),
277- datetime : StrEnvParser (),
278- SecretStr : StrEnvParser (),
279- }
408+ parsers = _get_default_parsers ()
280409 parser = get_env_parser (target_type , parsers )
281410 json_data = parser .from_env (prefix )
282411 if json_data is MISSING :
@@ -286,3 +415,16 @@ def from_env(
286415 type_adapter = TypeAdapter (target_type )
287416 result = type_adapter .validate_json (json_str )
288417 return result
418+
419+
420+ def to_env (
421+ value : Any ,
422+ prefix : str = "" ,
423+ parsers : dict [type , EnvParser ] | None = None ,
424+ ) -> str :
425+ if parsers is None :
426+ parsers = _get_default_parsers ()
427+ parser = get_env_parser (value .__class__ , parsers )
428+ output = StringIO ()
429+ parser .to_env (prefix , value , output )
430+ return output .getvalue ()
0 commit comments