Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Features
* More complete and up-to-date set of MySQL reserved words for completions.
* Place exact-leading completions first.
* Allow history file location to be configured.
* Make destructive-warning keywords configurable.


Bug Fixes
Expand Down
12 changes: 8 additions & 4 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def __init__(
self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
self.multiline_continuation_char = c["main"]["prompt_continuation"]
self.prompt_app = None
self.destructive_keywords = [
keyword for keyword in c["main"].get("destructive_keywords", "DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE").split(' ') if keyword
]
special.set_destructive_keywords(self.destructive_keywords)

def close(self) -> None:
if self.sqlexecute is not None:
Expand Down Expand Up @@ -346,7 +350,7 @@ def execute_from_file(self, arg: str, **_) -> Iterable[SQLResult]:
except IOError as e:
return [SQLResult(status=str(e))]

if self.destructive_warning and confirm_destructive_query(query) is False:
if self.destructive_warning and confirm_destructive_query(self.destructive_keywords, query) is False:
message = "Wise choice. Command execution stopped."
return [SQLResult(status=message)]

Expand Down Expand Up @@ -977,7 +981,7 @@ def one_iteration(text: str | None = None) -> None:
return

if self.destructive_warning:
destroy = confirm_destructive_query(text)
destroy = confirm_destructive_query(self.destructive_keywords, text)
if destroy is None:
pass # Query was not destructive. Nothing to do here.
elif destroy is True:
Expand Down Expand Up @@ -1852,10 +1856,10 @@ def cli(
click.secho("Sorry... :(", err=True, fg="red")
sys.exit(1)

if mycli.destructive_warning and is_destructive(stdin_text):
if mycli.destructive_warning and is_destructive(mycli.destructive_keywords, stdin_text):
try:
sys.stdin = open("/dev/tty")
warn_confirmed = confirm_destructive_query(stdin_text)
warn_confirmed = confirm_destructive_query(mycli.destructive_keywords, stdin_text)
except (IOError, OSError):
mycli.logger.warning("Unable to open TTY as stdin.")
if not warn_confirmed:
Expand Down
5 changes: 5 additions & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ multi_line = False
# or "shutdown".
destructive_warning = True

# Queries starting with these keywords will activate the destructive warning.
# UPDATE will not activate the warning if the statement includes a WHERE
# clause.
destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE

# interactive query history location.
history_file = ~/.mycli-history

Expand Down
18 changes: 11 additions & 7 deletions mycli/packages/parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,19 @@ def query_has_where_clause(query: str) -> bool:
return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list)


def is_destructive(queries: str) -> bool:
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
def is_destructive(keywords: list[str], queries: str) -> bool:
"""Returns True if any of the queries in *queries* is destructive."""
for query in sqlparse.split(queries):
if query:
if query_starts_with(query, list(keywords)) is True:
return True
elif query_starts_with(query, ["update"]) is True and not query_has_where_clause(query):
if not query:
continue
# subtle: if "UPDATE" is one of our keywords AND "query" starts with "UPDATE"
if query_starts_with(query, keywords) and query_starts_with(query, ["update"]):
if query_has_where_clause(query):
return False
else:
return True
if query_starts_with(query, keywords):
return True

return False

Expand Down
4 changes: 2 additions & 2 deletions mycli/packages/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __repr__(self):
BOOLEAN_TYPE = ConfirmBoolParamType()


def confirm_destructive_query(queries: str) -> bool | None:
def confirm_destructive_query(keywords: list[str], queries: str) -> bool | None:
"""Check if the query is destructive and prompts the user to confirm.

Returns:
Expand All @@ -35,7 +35,7 @@ def confirm_destructive_query(queries: str) -> bool | None:

"""
prompt_text = "You're about to run a destructive command.\nDo you want to proceed? (y/n)"
if is_destructive(queries) and sys.stdin.isatty():
if is_destructive(keywords, queries) and sys.stdin.isatty():
return prompt(prompt_text, type=BOOLEAN_TYPE)
else:
return None
Expand Down
2 changes: 2 additions & 0 deletions mycli/packages/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
is_timing_enabled,
open_external_editor,
set_delimiter,
set_destructive_keywords,
set_expanded_output,
set_favorite_queries,
set_forced_horizontal_output,
Expand Down Expand Up @@ -77,6 +78,7 @@
'parse_special_command',
'register_special_command',
'set_delimiter',
'set_destructive_keywords',
'set_expanded_output',
'set_favorite_queries',
'set_forced_horizontal_output',
Expand Down
8 changes: 7 additions & 1 deletion mycli/packages/special/iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
}
delimiter_command = DelimiterCommand()
favoritequeries = FavoriteQueries(ConfigObj())
DESTRUCTIVE_KEYWORDS: list[str] = []


def set_favorite_queries(config):
Expand Down Expand Up @@ -72,6 +73,11 @@ def is_show_favorite_query() -> bool:
return SHOW_FAVORITE_QUERY


def set_destructive_keywords(val: list[str]) -> None:
global DESTRUCTIVE_KEYWORDS
DESTRUCTIVE_KEYWORDS = val


@special_command(
"pager",
"\\P [command]",
Expand Down Expand Up @@ -562,7 +568,7 @@ def watch_query(arg: str, **kwargs) -> Generator[SQLResult, None, None]:
clear_screen = True
continue
statement = f"{left_arg} {arg}"
destructive_prompt = confirm_destructive_query(statement)
destructive_prompt = confirm_destructive_query(DESTRUCTIVE_KEYWORDS, statement)
if destructive_prompt is False:
click.secho("Wise choice!")
return
Expand Down
5 changes: 5 additions & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ multi_line = False
# or "shutdown".
destructive_warning = True

# Queries starting with these keywords will activate the destructive warning.
# UPDATE will not activate the warning if the statement includes a WHERE
# clause.
destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE

# interactive query history location.
history_file = ~/.mycli-history

Expand Down
6 changes: 3 additions & 3 deletions test/test_parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,17 @@ def test_queries_start_with():

def test_is_destructive():
sql = "use test;\nshow databases;\ndrop database foo;"
assert is_destructive(sql) is True
assert is_destructive(["drop"], sql) is True


def test_is_destructive_update_with_where_clause():
sql = "use test;\nshow databases;\nUPDATE test SET x = 1 WHERE id = 1;"
assert is_destructive(sql) is False
assert is_destructive(["update"], sql) is False


def test_is_destructive_update_without_where_clause():
sql = "use test;\nshow databases;\nUPDATE test SET x = 1;"
assert is_destructive(sql) is True
assert is_destructive(["update"], sql) is True


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/test_prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def test_confirm_destructive_query_notty() -> None:
assert stdin.isatty() is False

sql = "drop database foo;"
assert confirm_destructive_query(sql) is None
assert confirm_destructive_query(["drop"], sql) is None