Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Features
---------
* Prioritize common functions in the "value" position.
* Improve value-position keywords.
* Allow warning-count in status output to be styled.


1.59.0 (2026/03/03)
Expand Down
2 changes: 2 additions & 0 deletions mycli/clistyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@
Token.Output.EvenRow: "output.even-row",
Token.Output.Null: "output.null",
Token.Output.Status: "output.status",
Token.Output.Status.WarningCount: "output.status.warning-count",
Token.Output.Timing: "output.timing",
Token.Warnings.TableSeparator: "warnings.table-separator",
Token.Warnings.Header: "warnings.header",
Token.Warnings.OddRow: "warnings.odd-row",
Token.Warnings.EvenRow: "warnings.even-row",
Token.Warnings.Null: "warnings.null",
Token.Warnings.Status: "warnings.status",
Token.Warnings.Status.WarningCount: "warnings.status.warning-count",
Token.Warnings.Timing: "warnings.timing",
Token.Prompt: "prompt",
Token.Continuation: "continuation",
Expand Down
36 changes: 19 additions & 17 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
sys.exit(1)
else:
watch_count += 1
if is_select(result.status) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold:
if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold:
self.echo(
f"The result set has more than {threshold} rows.",
fg="red",
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
if result_count > 0:
self.echo("")
try:
self.output(formatted, result.status)
self.output(formatted, result)
except KeyboardInterrupt:
pass
if self.beep_after_seconds > 0 and t >= self.beep_after_seconds:
Expand All @@ -1031,7 +1031,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:

start = time()
result_count += 1
mutating = mutating or is_mutating(result.status)
mutating = mutating or is_mutating(result.status_plain)

# get and display warnings if enabled
if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0:
Expand All @@ -1051,7 +1051,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None:
is_warnings_style=True,
)
self.echo("")
self.output(formatted, warning.status, is_warnings_style=True)
self.output(formatted, warning, is_warnings_style=True)

if saw_warning and special.is_timing_enabled():
self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True)
Expand Down Expand Up @@ -1417,7 +1417,7 @@ def get_output_margin(self, status: str | None = None) -> int:
def output(
self,
output: itertools.chain[str],
status: str | None = None,
result: SQLResult,
is_warnings_style: bool = False,
) -> None:
"""Output text to stdout or a pager command.
Expand All @@ -1438,7 +1438,7 @@ def output(
size_columns = DEFAULT_WIDTH
size_rows = DEFAULT_HEIGHT

margin = self.get_output_margin(status)
margin = self.get_output_margin(result.status_plain)

fits = True
buf = []
Expand Down Expand Up @@ -1480,12 +1480,14 @@ def newlinewrapper(text: list[str]) -> Generator[str, None, None]:
for line in buf:
click.secho(line)

if status:
# todo allow status to be a FormattedText, but strip before logging
self.log_output(status)
if result.status:
self.log_output(result.status_plain)
add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status'
formatted_status = FormattedText([('', status)])
styled_status = to_formatted_text(formatted_status, style=add_style)
if isinstance(result.status, FormattedText):
status = result.status
else:
status = FormattedText([('', result.status_plain)])
styled_status = to_formatted_text(status, style=add_style)
print_formatted_text(styled_status, style=self.toolkit_style)

def configure_pager(self) -> None:
Expand Down Expand Up @@ -2466,20 +2468,20 @@ def need_completion_reset(queries: str) -> bool:
return False


def is_mutating(status: str | None) -> bool:
def is_mutating(status_plain: str | None) -> bool:
"""Determines if the statement is mutating based on the status."""
if not status:
if not status_plain:
return False

mutating = {"insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"}
return status.split(None, 1)[0].lower() in mutating
return status_plain.split(None, 1)[0].lower() in mutating


def is_select(status: str | None) -> bool:
def is_select(status_plain: str | None) -> bool:
"""Returns true if the first word in status is 'select'."""
if not status:
if not status_plain:
return False
return status.split(None, 1)[0].lower() == "select"
return status_plain.split(None, 1)[0].lower() == "select"


def thanks_picker() -> str:
Expand Down
1 change: 1 addition & 0 deletions mycli/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ output.odd-row = ""
output.even-row = ""
output.null = "#808080"
output.status = ""
output.status.warning-count = ""
output.timing = ""

# SQL syntax highlighting overrides
Expand Down
10 changes: 9 additions & 1 deletion mycli/packages/sqlresult.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from functools import cached_property

from prompt_toolkit.formatted_text import FormattedText, to_plain_text
from pymysql.cursors import Cursor


Expand All @@ -9,11 +11,17 @@ class SQLResult:
header: list[str] | str | None = None
rows: Cursor | list[tuple] | None = None
postamble: str | None = None
status: str | None = None
status: str | FormattedText | None = None
command: dict[str, str | float] | None = None

def __iter__(self):
return self

def __str__(self):
return f"{self.preamble}, {self.header}, {self.rows}, {self.postamble}, {self.status}, {self.command}"

@cached_property
def status_plain(self):
if self.status is None:
return None
return to_plain_text(self.status)
10 changes: 7 additions & 3 deletions mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ssl
from typing import Any, Generator, Iterable

from prompt_toolkit.formatted_text import FormattedText
import pymysql
from pymysql.connections import Connection
from pymysql.constants import FIELD_TYPE
Expand Down Expand Up @@ -393,14 +394,17 @@ def get_result(self, cursor: Cursor) -> SQLResult:
plural = '' if cursor.rowcount == 1 else 's'
if cursor.description:
header = [x[0] for x in cursor.description]
status = f'{cursor.rowcount} row{plural} in set'
status = FormattedText([('', f'{cursor.rowcount} row{plural} in set')])
else:
_logger.debug("No rows in result.")
status = f'Query OK, {cursor.rowcount} row{plural} affected'
status = FormattedText([('', f'Query OK, {cursor.rowcount} row{plural} affected')])

if cursor.warning_count > 0:
plural = '' if cursor.warning_count == 1 else 's'
status = f'{status}, {cursor.warning_count} warning{plural}'
comma = FormattedText([('', ', ')])
warning_count = FormattedText([('class:output.status.warning-count', f'{cursor.warning_count} warning{plural}')])
status.extend(comma)
status.extend(warning_count)

return SQLResult(preamble=preamble, header=header, rows=cursor, status=status)

Expand Down
1 change: 1 addition & 0 deletions test/myclirc
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ output.odd-row = ""
output.even-row = ""
output.null = "#808080"
output.status = ""
output.status.warning-count = ""
output.timing = ""

# SQL syntax highlighting overrides
Expand Down
7 changes: 4 additions & 3 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mycli.packages.parseutils import is_valid_connection_scheme
import mycli.packages.special
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.packages.sqlresult import SQLResult
from mycli.sqlexecute import ServerInfo, SQLExecute
from test.utils import DATABASE, HOST, PASSWORD, PORT, TEMPFILE_PREFIX, USER, dbtest, run

Expand Down Expand Up @@ -76,7 +77,7 @@ def test_binary_display_hex(executor):
)
f = io.StringIO()
with redirect_stdout(f):
m.output(formatted, sqlresult.status)
m.output(formatted, sqlresult)
expected = " 0x6a "
output = f.getvalue()
assert expected in output
Expand Down Expand Up @@ -115,7 +116,7 @@ def test_binary_display_utf8(executor):
)
f = io.StringIO()
with redirect_stdout(f):
m.output(formatted, sqlresult.status)
m.output(formatted, sqlresult)
expected = " j "
output = f.getvalue()
assert expected in output
Expand Down Expand Up @@ -651,7 +652,7 @@ def secho(s):

monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
monkeypatch.setattr(click, "secho", secho)
m.output(testdata)
m.output(testdata, SQLResult())
if clickoutput.endswith("\n"):
clickoutput = clickoutput[:-1]
assert clickoutput == "\n".join(testdata)
Expand Down
Loading