Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from mycli.packages import special
from mycli.packages.filepaths import dir_path_exists, guess_socket_location
from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command
from mycli.packages.parseutils import is_destructive, is_dropping_database
from mycli.packages.parseutils import is_destructive, is_dropping_database, is_valid_connection_scheme
from mycli.packages.prompt_utils import confirm, confirm_destructive_query
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import ArgType
Expand Down Expand Up @@ -1584,7 +1584,14 @@ def cli(
password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True)
# if the password value looks like a DSN, treat it as such and
# prompt for password
elif database is None and password is not None and password.startswith("mysql://"):
elif database is None and password is not None and "://" in password:
# check if the scheme is valid. We do not actually have any logic for these, but
# it will most usefully catch the case where we erroneously catch someone's
# password, and give them an easy error message to follow / report
is_valid_scheme, scheme = is_valid_connection_scheme(password)
if not is_valid_scheme:
click.secho(f"Error: Unknown connection scheme provided for DSN URI ({scheme}://)", err=True, fg="red")
sys.exit(1)
database = password
password = click.prompt("Enter password", hide_input=True, show_default=False, default='', type=str, err=True)
# getting the envvar ourselves because the envvar from a click
Expand Down
11 changes: 11 additions & 0 deletions mycli/packages/parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
}


def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]:
# exit early if the text does not resemble a DSN URI
if "://" not in text:
return False, None
scheme = text.split("://")[0]
if scheme not in ("mysql", "mysqlx", "tcp", "socket", "ssh"):
return False, scheme
else:
return True, None


def last_word(text: str, include: str = "alphanum_underscore") -> str:
r"""
Find the last word in a sentence.
Expand Down
12 changes: 11 additions & 1 deletion test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from click.testing import CliRunner
from pymysql.err import OperationalError

from mycli.main import MyCli, cli, thanks_picker
from mycli.main import MyCli, cli, is_valid_connection_scheme, thanks_picker
import mycli.packages.special
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.sqlexecute import ServerInfo, SQLExecute
Expand Down Expand Up @@ -40,6 +40,16 @@
]


def test_is_valid_connection_scheme_valid(executor, capsys):
is_valid, scheme = is_valid_connection_scheme("mysql://test@localhost:3306/dev")
assert is_valid


def test_is_valid_connection_scheme_invalid(executor, capsys):
is_valid, scheme = is_valid_connection_scheme("nope://test@localhost:3306/dev")
assert not is_valid


@dbtest
def test_ssl_mode_on(executor, capsys):
runner = CliRunner()
Expand Down