diff --git a/changelog.md b/changelog.md index f685262f..c2e8b335 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ TBD ============== +Features +-------- +* Add `--unbuffered` mode which fetches rows as needed, to save memory. + + Bug Fixes -------- * Fix CamelCase fuzzy matching. diff --git a/mycli/main.py b/mycli/main.py index fa1f9731..731489fd 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -477,6 +477,7 @@ def connect( ssh_password: str | None = "", ssh_key_filename: str | None = "", init_command: str | None = "", + unbuffered: bool | None = None, password_file: str | None = "", ) -> None: cnf = { @@ -563,6 +564,7 @@ def _connect() -> None: ssh_password, ssh_key_filename, init_command, + unbuffered, ) except pymysql.OperationalError as e1: if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": @@ -583,6 +585,7 @@ def _connect() -> None: ssh_password, ssh_key_filename, init_command, + unbuffered, ) except Exception as e2: raise e2 @@ -1521,6 +1524,9 @@ def get_last_query(self) -> str | None: @click.option("-g", "--login-path", type=str, help="Read this path from the login file.") @click.option("-e", "--execute", type=str, help="Execute command and quit.") @click.option("--init-command", type=str, help="SQL statement to execute after connecting.") +@click.option( + "--unbuffered", is_flag=True, help="Instead of copying every row of data into a buffer, fetch rows as needed, to save memory." +) @click.option("--charset", type=str, help="Character set for MySQL session.") @click.option( "--password-file", type=click.Path(), help="File or FIFO path containing the password to connect to the db if not specified otherwise." @@ -1570,6 +1576,7 @@ def cli( ssh_config_path: str, ssh_config_host: str | None, init_command: str | None, + unbuffered: bool | None, charset: str | None, password_file: str | None, ) -> None: @@ -1807,6 +1814,7 @@ def cli( ssh_password=ssh_password, ssh_key_filename=ssh_key_filename, init_command=combined_init_cmd, + unbuffered=unbuffered, charset=charset, password_file=password_file, ) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 800a5381..9448b5dc 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -162,6 +162,7 @@ def __init__( ssh_password: str | None, ssh_key_filename: str | None, init_command: str | None = None, + unbuffered: bool | None = None, ) -> None: self.dbname = database self.user = user @@ -180,6 +181,7 @@ def __init__( self.ssh_password = ssh_password self.ssh_key_filename = ssh_key_filename self.init_command = init_command + self.unbuffered = unbuffered self.conn: Connection | None = None self.connect() @@ -200,6 +202,7 @@ def connect( ssh_password: str | None = None, ssh_key_filename: str | None = None, init_command: str | None = None, + unbuffered: bool | None = None, ): db = database if database is not None else self.dbname user = user if user is not None else self.user @@ -216,6 +219,7 @@ def connect( ssh_password = ssh_password if ssh_password is not None else self.ssh_password ssh_key_filename = ssh_key_filename if ssh_key_filename is not None else self.ssh_key_filename init_command = init_command if init_command is not None else self.init_command + unbuffered = unbuffered if unbuffered is not None else self.unbuffered _logger.debug( "Connection DB Params: \n" "\tdatabase: %r" @@ -231,7 +235,8 @@ def connect( "\tssh_port: %r" "\tssh_password: %r" "\tssh_key_filename: %r" - "\tinit_command: %r", + "\tinit_command: %r" + "\tunbuffered: %r", db, user, host, @@ -246,6 +251,7 @@ def connect( ssh_password, ssh_key_filename, init_command, + unbuffered, ) conv = conversions.copy() conv.update({ @@ -285,6 +291,7 @@ def connect( program_name="mycli", defer_connect=defer_connect, init_command=init_command or None, + cursorclass=pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, ) # type: ignore[misc] if ssh_host: @@ -324,6 +331,7 @@ def connect( self.charset = charset self.ssl = ssl self.init_command = init_command + self.unbuffered = unbuffered # retrieve connection id self.reset_connection_id() self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined]