diff --git a/changelog.md b/changelog.md index 0e69cbe7..eb245d70 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Features +-------- +* Add a `--checkpoint=` argument to log successful queries in batch mode. + + Bug Fixes -------- * Fix timediff output when the result is a negative value (#1113). diff --git a/mycli/main.py b/mycli/main.py index 73593c3d..46893d4e 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1322,7 +1322,12 @@ def get_prompt(self, string: str) -> str: string = string.replace("\\_", " ") return string - def run_query(self, query: str, new_line: bool = True) -> None: + def run_query( + self, + query: str, + checkpoint: TextIOWrapper | None = None, + new_line: bool = True, + ) -> None: """Runs *query*.""" assert self.sqlexecute is not None self.log_query(query) @@ -1362,6 +1367,9 @@ def run_query(self, query: str, new_line: bool = True) -> None: ) for line in output: click.echo(line, nl=new_line) + if checkpoint: + checkpoint.write(query.rstrip('\n') + '\n') + checkpoint.flush() def format_output( self, @@ -1507,6 +1515,9 @@ def get_last_query(self) -> str | None: @click.option("--ssh-warning-off", is_flag=True, help="Suppress the SSH deprecation notice.") @click.option("-R", "--prompt", "prompt", help=f'Prompt format (Default: "{MyCli.default_prompt}").') @click.option("-l", "--logfile", type=click.File(mode="a", encoding="utf-8"), help="Log every query and its results to a file.") +@click.option( + "--checkpoint", type=click.File(mode="a", encoding="utf-8"), help="In batch or --execute mode, log successful queries to a file." +) @click.option("--defaults-group-suffix", type=str, help="Read MySQL config groups with the specified suffix.") @click.option("--defaults-file", type=click.Path(), help="Only read MySQL options from the given file.") @click.option("--myclirc", type=click.Path(), default="~/.myclirc", help="Location of myclirc file.") @@ -1550,6 +1561,7 @@ def cli( verbose: bool, prompt: str | None, logfile: TextIOWrapper | None, + checkpoint: TextIOWrapper | None, defaults_group_suffix: str | None, defaults_file: str | None, login_path: str | None, @@ -1876,7 +1888,7 @@ def cli( else: mycli.main_formatter.format_name = 'tsv' - mycli.run_query(execute) + mycli.run_query(execute, checkpoint=checkpoint) sys.exit(0) except Exception as e: click.secho(str(e), err=True, fg="red") @@ -1919,7 +1931,7 @@ def cli( sys.exit(1) try: if warn_confirmed: - mycli.run_query(stdin_text, new_line=True) + mycli.run_query(stdin_text, checkpoint=checkpoint, new_line=True) except Exception as e: click.secho(str(e), err=True, fg="red") sys.exit(1) diff --git a/test/test_main.py b/test/test_main.py index ebbed6c7..66a2ef85 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -346,6 +346,42 @@ def test_execute_arg(executor): assert expected in result.output +@dbtest +def test_execute_arg_with_checkpoint(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + + with NamedTemporaryFile(mode="w", delete=False) as checkpoint: + checkpoint.close() + + result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql, f"--checkpoint={checkpoint.name}"]) + assert result.exit_code == 0 + + with open(checkpoint.name, 'r') as f: + contents = f.read() + assert sql in contents + os.remove(checkpoint.name) + + sql = 'select 10 from nonexistent_table;' + result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql, f"--checkpoint={checkpoint.name}"]) + assert result.exit_code != 0 + + with open(checkpoint.name, 'r') as f: + contents = f.read() + assert sql not in contents + + # delete=False means we should try to clean up + # we don't really need "try" here as open() would have already failed + try: + if os.path.exists(checkpoint.name): + os.remove(checkpoint.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + @dbtest def test_execute_arg_with_table(executor): run(executor, "create table test (a text)")