88import typing
99import warnings
1010from abc import ABC , abstractmethod
11+
12+ if sys .version_info >= (3 , 9 ):
13+ from argparse import BooleanOptionalAction
1114from argparse import SUPPRESS , ArgumentParser , Namespace , RawDescriptionHelpFormatter , _SubParsersAction
1215from collections import deque
1316from dataclasses import is_dataclass
@@ -124,6 +127,14 @@ class _CliPositionalArg:
124127 pass
125128
126129
130+ class _CliImplicitFlag :
131+ pass
132+
133+
134+ class _CliExplicitFlag :
135+ pass
136+
137+
127138class _CliInternalArgParser (ArgumentParser ):
128139 def __init__ (self , cli_exit_on_error : bool = True , ** kwargs : Any ) -> None :
129140 super ().__init__ (** kwargs )
@@ -138,6 +149,9 @@ def error(self, message: str) -> NoReturn:
138149T = TypeVar ('T' )
139150CliSubCommand = Annotated [Union [T , None ], _CliSubCommand ]
140151CliPositionalArg = Annotated [T , _CliPositionalArg ]
152+ _CliBoolFlag = TypeVar ('_CliBoolFlag' , bound = bool )
153+ CliImplicitFlag = Annotated [_CliBoolFlag , _CliImplicitFlag ]
154+ CliExplicitFlag = Annotated [_CliBoolFlag , _CliExplicitFlag ]
141155
142156
143157class EnvNoneType (str ):
@@ -905,6 +919,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]):
905919 cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
906920 Defaults to `True`.
907921 cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "".
922+ cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
923+ (e.g. --flag, --no-flag). Defaults to `False`.
908924 case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`.
909925 Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI
910926 subcommands.
@@ -932,6 +948,7 @@ def __init__(
932948 cli_use_class_docs_for_groups : bool | None = None ,
933949 cli_exit_on_error : bool | None = None ,
934950 cli_prefix : str | None = None ,
951+ cli_implicit_flags : bool | None = None ,
935952 case_sensitive : bool | None = True ,
936953 root_parser : Any = None ,
937954 parse_args_method : Callable [..., Any ] | None = ArgumentParser .parse_args ,
@@ -975,6 +992,11 @@ def __init__(
975992 if cli_prefix .startswith ('.' ) or cli_prefix .endswith ('.' ) or not cli_prefix .replace ('.' , '' ).isidentifier (): # type: ignore
976993 raise SettingsError (f'CLI settings source prefix is invalid: { cli_prefix } ' )
977994 self .cli_prefix += '.'
995+ self .cli_implicit_flags = (
996+ cli_implicit_flags
997+ if cli_implicit_flags is not None
998+ else settings_cls .model_config .get ('cli_implicit_flags' , False )
999+ )
9781000
9791001 case_sensitive = case_sensitive if case_sensitive is not None else True
9801002 if not case_sensitive and root_parser is not None :
@@ -1281,6 +1303,23 @@ def _get_resolved_names(
12811303 resolved_names = [resolved_name .lower () for resolved_name in resolved_names ]
12821304 return tuple (dict .fromkeys (resolved_names )), is_alias_path_only
12831305
1306+ def _verify_cli_flag_annotations (self , model : type [BaseModel ], field_name : str , field_info : FieldInfo ) -> None :
1307+ if _CliImplicitFlag in field_info .metadata :
1308+ cli_flag_name = 'CliImplicitFlag'
1309+ elif _CliExplicitFlag in field_info .metadata :
1310+ cli_flag_name = 'CliExplicitFlag'
1311+ else :
1312+ return
1313+
1314+ if field_info .annotation is not bool :
1315+ raise SettingsError (f'{ cli_flag_name } argument { model .__name__ } .{ field_name } is not of type bool' )
1316+ elif sys .version_info < (3 , 9 ) and (
1317+ field_info .default is PydanticUndefined and field_info .default_factory is None
1318+ ):
1319+ raise SettingsError (
1320+ f'{ cli_flag_name } argument { model .__name__ } .{ field_name } must have default for python versions < 3.9'
1321+ )
1322+
12841323 def _sort_arg_fields (self , model : type [BaseModel ]) -> list [tuple [str , FieldInfo ]]:
12851324 positional_args , subcommand_args , optional_args = [], [], []
12861325 fields = (
@@ -1310,6 +1349,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
13101349 raise SettingsError (f'positional argument { model .__name__ } .{ field_name } has an alias' )
13111350 positional_args .append ((field_name , field_info ))
13121351 else :
1352+ self ._verify_cli_flag_annotations (model , field_name , field_info )
13131353 optional_args .append ((field_name , field_info ))
13141354 return positional_args + subcommand_args + optional_args
13151355
@@ -1457,6 +1497,8 @@ def _add_parser_args(
14571497 del kwargs ['required' ]
14581498 arg_flag = ''
14591499
1500+ self ._convert_bool_flag (kwargs , field_info , model_default )
1501+
14601502 if sub_models and kwargs .get ('action' ) != 'append' :
14611503 self ._add_parser_submodels (
14621504 parser ,
@@ -1486,6 +1528,22 @@ def _add_parser_args(
14861528 self ._add_parser_alias_paths (parser , alias_path_args , added_args , arg_prefix , subcommand_prefix , group )
14871529 return parser
14881530
1531+ def _convert_bool_flag (self , kwargs : dict [str , Any ], field_info : FieldInfo , model_default : Any ) -> None :
1532+ if kwargs ['metavar' ] == 'bool' :
1533+ default = None
1534+ if field_info .default is not PydanticUndefined :
1535+ default = field_info .default
1536+ if model_default is not PydanticUndefined :
1537+ default = model_default
1538+ if sys .version_info >= (3 , 9 ) or isinstance (default , bool ):
1539+ if (self .cli_implicit_flags or _CliImplicitFlag in field_info .metadata ) and (
1540+ _CliExplicitFlag not in field_info .metadata
1541+ ):
1542+ del kwargs ['metavar' ]
1543+ kwargs ['action' ] = (
1544+ BooleanOptionalAction if sys .version_info >= (3 , 9 ) else f'store_{ str (not default ).lower ()} '
1545+ )
1546+
14891547 def _get_arg_names (
14901548 self , arg_prefix : str , subcommand_prefix : str , alias_prefixes : list [str ], resolved_names : tuple [str , ...]
14911549 ) -> list [str ]:
0 commit comments