diff --git a/changelog.md b/changelog.md index 511f2438..7b943c81 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ Features * More complete and up-to-date set of MySQL reserved words for completions. * Place exact-leading completions first. * Allow history file location to be configured. +* Make destructive-warning keywords configurable. Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index a785146f..8c82bba9 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -220,6 +220,10 @@ def __init__( self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt self.multiline_continuation_char = c["main"]["prompt_continuation"] self.prompt_app = None + self.destructive_keywords = [ + keyword for keyword in c["main"].get("destructive_keywords", "DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE").split(' ') if keyword + ] + special.set_destructive_keywords(self.destructive_keywords) def close(self) -> None: if self.sqlexecute is not None: @@ -346,7 +350,7 @@ def execute_from_file(self, arg: str, **_) -> Iterable[SQLResult]: except IOError as e: return [SQLResult(status=str(e))] - if self.destructive_warning and confirm_destructive_query(query) is False: + if self.destructive_warning and confirm_destructive_query(self.destructive_keywords, query) is False: message = "Wise choice. Command execution stopped." return [SQLResult(status=message)] @@ -977,7 +981,7 @@ def one_iteration(text: str | None = None) -> None: return if self.destructive_warning: - destroy = confirm_destructive_query(text) + destroy = confirm_destructive_query(self.destructive_keywords, text) if destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: @@ -1852,10 +1856,10 @@ def cli( click.secho("Sorry... :(", err=True, fg="red") sys.exit(1) - if mycli.destructive_warning and is_destructive(stdin_text): + if mycli.destructive_warning and is_destructive(mycli.destructive_keywords, stdin_text): try: sys.stdin = open("/dev/tty") - warn_confirmed = confirm_destructive_query(stdin_text) + warn_confirmed = confirm_destructive_query(mycli.destructive_keywords, stdin_text) except (IOError, OSError): mycli.logger.warning("Unable to open TTY as stdin.") if not warn_confirmed: diff --git a/mycli/myclirc b/mycli/myclirc index b49b81a6..62113850 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -27,6 +27,11 @@ multi_line = False # or "shutdown". destructive_warning = True +# Queries starting with these keywords will activate the destructive warning. +# UPDATE will not activate the warning if the statement includes a WHERE +# clause. +destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE + # interactive query history location. history_file = ~/.mycli-history diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index b29e7cbd..051a9826 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -264,15 +264,19 @@ def query_has_where_clause(query: str) -> bool: return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list) -def is_destructive(queries: str) -> bool: - """Returns if any of the queries in *queries* is destructive.""" - keywords = ("drop", "shutdown", "delete", "truncate", "alter") +def is_destructive(keywords: list[str], queries: str) -> bool: + """Returns True if any of the queries in *queries* is destructive.""" for query in sqlparse.split(queries): - if query: - if query_starts_with(query, list(keywords)) is True: - return True - elif query_starts_with(query, ["update"]) is True and not query_has_where_clause(query): + if not query: + continue + # subtle: if "UPDATE" is one of our keywords AND "query" starts with "UPDATE" + if query_starts_with(query, keywords) and query_starts_with(query, ["update"]): + if query_has_where_clause(query): + return False + else: return True + if query_starts_with(query, keywords): + return True return False diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py index 839fdcf6..68c468f6 100644 --- a/mycli/packages/prompt_utils.py +++ b/mycli/packages/prompt_utils.py @@ -25,7 +25,7 @@ def __repr__(self): BOOLEAN_TYPE = ConfirmBoolParamType() -def confirm_destructive_query(queries: str) -> bool | None: +def confirm_destructive_query(keywords: list[str], queries: str) -> bool | None: """Check if the query is destructive and prompts the user to confirm. Returns: @@ -35,7 +35,7 @@ def confirm_destructive_query(queries: str) -> bool | None: """ prompt_text = "You're about to run a destructive command.\nDo you want to proceed? (y/n)" - if is_destructive(queries) and sys.stdin.isatty(): + if is_destructive(keywords, queries) and sys.stdin.isatty(): return prompt(prompt_text, type=BOOLEAN_TYPE) else: return None diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py index c96ffcb5..d3b60b7f 100644 --- a/mycli/packages/special/__init__.py +++ b/mycli/packages/special/__init__.py @@ -22,6 +22,7 @@ is_timing_enabled, open_external_editor, set_delimiter, + set_destructive_keywords, set_expanded_output, set_favorite_queries, set_forced_horizontal_output, @@ -77,6 +78,7 @@ 'parse_special_command', 'register_special_command', 'set_delimiter', + 'set_destructive_keywords', 'set_expanded_output', 'set_favorite_queries', 'set_forced_horizontal_output', diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index f9d3a94b..14437b5d 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -42,6 +42,7 @@ } delimiter_command = DelimiterCommand() favoritequeries = FavoriteQueries(ConfigObj()) +DESTRUCTIVE_KEYWORDS: list[str] = [] def set_favorite_queries(config): @@ -72,6 +73,11 @@ def is_show_favorite_query() -> bool: return SHOW_FAVORITE_QUERY +def set_destructive_keywords(val: list[str]) -> None: + global DESTRUCTIVE_KEYWORDS + DESTRUCTIVE_KEYWORDS = val + + @special_command( "pager", "\\P [command]", @@ -562,7 +568,7 @@ def watch_query(arg: str, **kwargs) -> Generator[SQLResult, None, None]: clear_screen = True continue statement = f"{left_arg} {arg}" - destructive_prompt = confirm_destructive_query(statement) + destructive_prompt = confirm_destructive_query(DESTRUCTIVE_KEYWORDS, statement) if destructive_prompt is False: click.secho("Wise choice!") return diff --git a/test/myclirc b/test/myclirc index 5f3c5a01..d3cdd4e9 100644 --- a/test/myclirc +++ b/test/myclirc @@ -27,6 +27,11 @@ multi_line = False # or "shutdown". destructive_warning = True +# Queries starting with these keywords will activate the destructive warning. +# UPDATE will not activate the warning if the statement includes a WHERE +# clause. +destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE + # interactive query history location. history_file = ~/.mycli-history diff --git a/test/test_parseutils.py b/test/test_parseutils.py index 4b06a07a..eb3972c1 100644 --- a/test/test_parseutils.py +++ b/test/test_parseutils.py @@ -149,17 +149,17 @@ def test_queries_start_with(): def test_is_destructive(): sql = "use test;\nshow databases;\ndrop database foo;" - assert is_destructive(sql) is True + assert is_destructive(["drop"], sql) is True def test_is_destructive_update_with_where_clause(): sql = "use test;\nshow databases;\nUPDATE test SET x = 1 WHERE id = 1;" - assert is_destructive(sql) is False + assert is_destructive(["update"], sql) is False def test_is_destructive_update_without_where_clause(): sql = "use test;\nshow databases;\nUPDATE test SET x = 1;" - assert is_destructive(sql) is True + assert is_destructive(["update"], sql) is True @pytest.mark.parametrize( diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py index 64e4ef31..236b7969 100644 --- a/test/test_prompt_utils.py +++ b/test/test_prompt_utils.py @@ -10,4 +10,4 @@ def test_confirm_destructive_query_notty() -> None: assert stdin.isatty() is False sql = "drop database foo;" - assert confirm_destructive_query(sql) is None + assert confirm_destructive_query(["drop"], sql) is None