Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Features
Bug Fixes
---------
* Make toolbar widths consistent on toggle actions.
* Don't write ANSI prompt escapes to `tee` output.


Internal
Expand All @@ -25,6 +26,7 @@ Internal
* Add more URL constants.
* Set `$VISUAL` whenever `$EDITOR` is set.
* Fix tempfile leak in test suite.
* Avoid refreshing the prompt unless needed.


1.58.0 (2026/02/28)
Expand Down
34 changes: 25 additions & 9 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def __init__(
self.my_cnf['mysqld'] = {}
prompt_cnf = self.read_my_cnf(self.my_cnf, ["prompt"])["prompt"]
self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
self.prompt_lines = 0
self.multiline_continuation_char = c["main"]["prompt_continuation"]
self.toolbar_format = toolbar_format or c['main']['toolbar']
self.prompt_app = None
Expand Down Expand Up @@ -935,10 +936,13 @@ def run_cli(self) -> None:
def get_prompt_message(app) -> ANSI:
if app.current_buffer.text:
return self.last_prompt_message
prompt = self.get_prompt(self.prompt_format)
prompt = self.get_prompt(self.prompt_format, app.render_counter)
if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
prompt = self.get_prompt(self.default_prompt_splitln)
prompt = self.get_prompt(self.default_prompt_splitln, app.render_counter)
self.prompt_lines = prompt.count('\n') + 1
prompt = prompt.replace("\\x1b", "\x1b")
if not self.prompt_lines:
self.prompt_lines = prompt.count('\n') + 1
self.last_prompt_message = ANSI(prompt)
return self.last_prompt_message

Expand Down Expand Up @@ -1182,7 +1186,8 @@ def one_iteration(text: str | None = None) -> None:
try:
logger.debug("sql: %r", text)

special.write_tee(self.get_prompt(self.prompt_format) + text)
special.write_tee(self.last_prompt_message, nl=False)
special.write_tee(text)
self.log_query(text)

successful = False
Expand Down Expand Up @@ -1397,7 +1402,11 @@ def echo(self, s: str, **kwargs) -> None:
def get_output_margin(self, status: str | None = None) -> int:
"""Get the output margin (number of rows for the prompt, footer and
timing message."""
margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 1
if not self.prompt_lines:
# self.prompt_app.app.render_counter failed in the test suite
app = get_app()
self.prompt_lines = self.get_prompt(self.prompt_format, app.render_counter).count('\n') + 1
margin = self.get_reserved_space() + self.prompt_lines
if special.is_timing_enabled():
margin += 1
if status:
Expand Down Expand Up @@ -1534,13 +1543,18 @@ def get_completions(self, text: str, cursor_position: int) -> Iterable[Completio
def get_custom_toolbar(self, toolbar_format: str) -> ANSI:
if self.prompt_app and self.prompt_app.app.current_buffer.text:
return self.last_custom_toolbar_message
toolbar = self.get_prompt(toolbar_format)
app = get_app()
toolbar = self.get_prompt(toolbar_format, app.render_counter)
toolbar = toolbar.replace("\\x1b", "\x1b")
self.last_custom_toolbar_message = ANSI(toolbar)
return self.last_custom_toolbar_message

# todo: time/uptime update on every character typed, instead of after every return
def get_prompt(self, string: str) -> str:
# Memoizing a method leaks the instance, but we only expect one MyCli instance.
# Before memoizing, get_prompt() was called dozens of times per prompt.
# Even after memoizing, get_prompt's logic gets called twice per prompt, which
# should be addressed, because some format strings take a trip to the server.
@functools.lru_cache(maxsize=256) # noqa: B019
def get_prompt(self, string: str, _render_counter: int) -> str:
sqlexecute = self.sqlexecute
assert sqlexecute is not None
assert sqlexecute.server_info is not None
Expand Down Expand Up @@ -1569,6 +1583,8 @@ def get_prompt(self, string: str) -> str:
string = string.replace("\\k", os.path.basename(sqlexecute.socket or str(sqlexecute.port)))
string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port))
string = string.replace("\\A", self.dsn_alias or "(none)")
string = string.replace("\\_", " ")

# jump through hoops for the test environment, and for efficiency
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\y' in string:
Expand All @@ -1581,14 +1597,13 @@ def get_prompt(self, string: str) -> str:
string = string.replace('\\y', '(none)')
string = string.replace('\\Y', '(none)')

string = string.replace("\\_", " ")
# jump through hoops for the test environment and for efficiency
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\T' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\T', get_ssl_version(cur) or '(none)')
else:
string = string.replace('\\T', '(none)')

if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\w' in string:
with sqlexecute.conn.cursor() as cur:
Expand All @@ -1601,6 +1616,7 @@ def get_prompt(self, string: str) -> str:
string = string.replace('\\W', str(get_warning_count(cur) or ''))
else:
string = string.replace('\\W', '')

return string

def run_query(
Expand Down
13 changes: 8 additions & 5 deletions mycli/packages/special/iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import click
from configobj import ConfigObj
from prompt_toolkit.formatted_text import ANSI, FormattedText, to_plain_text
from pymysql.cursors import Cursor
import pyperclip
import sqlparse
Expand Down Expand Up @@ -432,12 +433,14 @@ def no_tee(arg: str, **_) -> list[SQLResult]:
return [SQLResult(status="")]


def write_tee(output: str) -> None:
def write_tee(output: str | ANSI | FormattedText, nl: bool = True) -> None:
global tee_file
if tee_file:
click.echo(output, file=tee_file, nl=False)
click.echo("\n", file=tee_file, nl=False)
tee_file.flush()
if not tee_file:
return
click.echo(to_plain_text(output), file=tee_file, nl=False)
if nl:
click.echo('\n', file=tee_file, nl=False)
tee_file.flush()


@special_command("\\once", "\\once [-o] <filename>", "Append next result to an output file (overwrite using -o).", aliases=["\\o"])
Expand Down
4 changes: 2 additions & 2 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def test_prompt_no_host_only_socket(executor):
mycli.sqlexecute.user = "root"
mycli.sqlexecute.dbname = "mysql"
mycli.sqlexecute.port = "3306"
prompt = mycli.get_prompt(mycli.prompt_format)
prompt = mycli.get_prompt(mycli.prompt_format, 0)
assert prompt == "MySQL root@localhost:mysql> "


Expand All @@ -350,7 +350,7 @@ def test_prompt_socket_overrides_port(executor):
mycli.sqlexecute.user = "root"
mycli.sqlexecute.dbname = "mysql"
mycli.sqlexecute.port = "3306"
prompt = mycli.get_prompt(mycli.prompt_format)
prompt = mycli.get_prompt(mycli.prompt_format, 0)
assert prompt == "MySQL root@localhost:mysqld.sock mysql> "


Expand Down