@@ -263,16 +263,6 @@ def run_and_combine_outputs(command, *args):
263263 return subprocess .check_output (command_string , stderr = subprocess .STDOUT )
264264
265265
266- def find_endpoint (argv , shortcuts = {}):
267- # endpoint is first positional argument
268- pattern = re .compile ("(.*https?://.*|[a-zA-Z]:\\ .*)" )
269- indices = []
270- for index , arg in enumerate (argv ):
271- if arg in shortcuts or (Endpoint .is_endpoint (arg ) and not pattern .match (arg )):
272- indices .append (index )
273- return - 1 if len (indices ) == 0 else indices [- 1 ]
274-
275-
276266_default_log_levels = (
277267 "NOTSET" ,
278268 "DEBUG" ,
@@ -285,6 +275,45 @@ def find_endpoint(argv, shortcuts={}):
285275)
286276
287277
278+ class CustomArgParser (argparse .ArgumentParser ):
279+ def __init__ (self , * args , ** kwargs ):
280+ super ().__init__ (* args , ** kwargs )
281+ self ._found_unknown_hyphenated_args = False
282+ self ._found_endpoint = False
283+ self ._found_optionals = []
284+
285+ def _match_arguments_partial (self , actions , arg_strings_pattern ):
286+ # Doesnt support --additional-endpoints yet
287+ result = []
288+ args_after_double_equals = len (arg_strings_pattern .partition ("-" )[2 ])
289+ for i , arg_string in enumerate (self ._found_optionals ):
290+ if Endpoint .is_endpoint (arg_string ):
291+ rv = [
292+ i ,
293+ 1 ,
294+ len (self ._found_optionals ) - i - 1 + args_after_double_equals ,
295+ ]
296+ return rv
297+ return result
298+
299+ def _parse_optional (self , arg_string ):
300+ if arg_string .startswith ("-" ) and arg_string not in self ._option_string_actions :
301+ self ._found_unknown_hyphenated_args = True
302+ elif Endpoint .is_endpoint (arg_string ):
303+ self ._found_endpoint = True
304+
305+ if self ._found_unknown_hyphenated_args or self ._found_endpoint :
306+ self ._found_optionals .append (arg_string )
307+ return None
308+
309+ rv = super ()._parse_optional (arg_string )
310+ return rv
311+
312+ def error (self , message ):
313+ if message == "the following arguments are required: <endpoint>" :
314+ raise NoEndpointProvided ([])
315+
316+
288317def jgo_parser (log_levels = _default_log_levels ):
289318 usage = (
290319 "usage: jgo [-v] [-u] [-U] [-m] [-q] [--log-level] [--ignore-jgorc]\n "
@@ -307,7 +336,8 @@ def jgo_parser(log_levels=_default_log_levels):
307336and it will be auto-completed.
308337"""
309338
310- parser = argparse .ArgumentParser (
339+ parser = CustomArgParser (
340+ prog = "jgo" ,
311341 description = "Run Java main class from Maven coordinates." ,
312342 usage = usage [len ("usage: " ) :],
313343 epilog = epilog ,
@@ -376,6 +406,25 @@ def jgo_parser(log_levels=_default_log_levels):
376406 parser .add_argument (
377407 "--log-level" , default = None , type = str , help = "Set log level" , choices = log_levels
378408 )
409+ parser .add_argument (
410+ "jvm_args" ,
411+ help = "JVM arguments" ,
412+ metavar = "jvm-args" ,
413+ nargs = "*" ,
414+ default = [],
415+ )
416+ parser .add_argument (
417+ "endpoint" ,
418+ help = "Endpoint" ,
419+ metavar = "<endpoint>" ,
420+ )
421+ parser .add_argument (
422+ "program_args" ,
423+ help = "Program arguments" ,
424+ metavar = "main-args" ,
425+ nargs = "*" ,
426+ default = [],
427+ )
379428
380429 return parser
381430
@@ -719,15 +768,18 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
719768 repositories = config ["repositories" ]
720769 shortcuts = config ["shortcuts" ]
721770
722- endpoint_index = find_endpoint (argv , shortcuts )
723- if endpoint_index == - 1 :
724- raise HelpRequested (
725- argv
726- ) if "-h" in argv or "--help" in argv else NoEndpointProvided (argv )
771+ if "-h" in argv or "--help" in argv :
772+ raise HelpRequested (argv )
773+
774+ args = parser .parse_args (argv )
775+
776+ if not args .endpoint :
777+ raise NoEndpointProvided (argv )
778+ if args .endpoint in shortcuts and not Endpoint .is_endpoint (args .endpoint ):
779+ raise NoEndpointProvided (argv )
727780
728- args , unknown = parser .parse_known_args (argv [:endpoint_index ])
729- jvm_args = unknown if unknown else []
730- program_args = [] if endpoint_index == - 1 else argv [endpoint_index + 1 :]
781+ jvm_args = args .jvm_args
782+ program_args = args .program_args
731783 if args .log_level :
732784 logging .getLogger ().setLevel (logging .getLevelName (args .log_level ))
733785
@@ -757,7 +809,7 @@ def run(parser, argv=sys.argv[1:], stdout=None, stderr=None):
757809 if args .force_update :
758810 args .update_cache = True
759811
760- endpoint_string = "+" .join ([argv [ endpoint_index ] ] + args .additional_endpoints )
812+ endpoint_string = "+" .join ([args . endpoint ] + args .additional_endpoints )
761813
762814 primary_endpoint , workspace = resolve_dependencies (
763815 endpoint_string ,
0 commit comments