diff --git a/changelog.md b/changelog.md index 30ac12cb..a41ee47b 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ Features --------- * Prioritize common functions in the "value" position. * Improve value-position keywords. +* Allow warning-count in status output to be styled. 1.59.0 (2026/03/03) diff --git a/mycli/clistyle.py b/mycli/clistyle.py index 45772986..6398ff8e 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -38,6 +38,7 @@ Token.Output.EvenRow: "output.even-row", Token.Output.Null: "output.null", Token.Output.Status: "output.status", + Token.Output.Status.WarningCount: "output.status.warning-count", Token.Output.Timing: "output.timing", Token.Warnings.TableSeparator: "warnings.table-separator", Token.Warnings.Header: "warnings.header", @@ -45,6 +46,7 @@ Token.Warnings.EvenRow: "warnings.even-row", Token.Warnings.Null: "warnings.null", Token.Warnings.Status: "warnings.status", + Token.Warnings.Status.WarningCount: "warnings.status.warning-count", Token.Warnings.Timing: "warnings.timing", Token.Prompt: "prompt", Token.Continuation: "continuation", diff --git a/mycli/main.py b/mycli/main.py index 31ae82b0..a6ccdc3e 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -986,7 +986,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: sys.exit(1) else: watch_count += 1 - if is_select(result.status) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: + if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: self.echo( f"The result set has more than {threshold} rows.", fg="red", @@ -1018,7 +1018,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: if result_count > 0: self.echo("") try: - self.output(formatted, result.status) + self.output(formatted, result) except KeyboardInterrupt: pass if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: @@ -1031,7 +1031,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: start = time() result_count += 1 - mutating = mutating or is_mutating(result.status) + mutating = mutating or is_mutating(result.status_plain) # get and display warnings if enabled if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: @@ -1051,7 +1051,7 @@ def output_res(results: Generator[SQLResult], start: float) -> None: is_warnings_style=True, ) self.echo("") - self.output(formatted, warning.status, is_warnings_style=True) + self.output(formatted, warning, is_warnings_style=True) if saw_warning and special.is_timing_enabled(): self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True) @@ -1417,7 +1417,7 @@ def get_output_margin(self, status: str | None = None) -> int: def output( self, output: itertools.chain[str], - status: str | None = None, + result: SQLResult, is_warnings_style: bool = False, ) -> None: """Output text to stdout or a pager command. @@ -1438,7 +1438,7 @@ def output( size_columns = DEFAULT_WIDTH size_rows = DEFAULT_HEIGHT - margin = self.get_output_margin(status) + margin = self.get_output_margin(result.status_plain) fits = True buf = [] @@ -1480,12 +1480,14 @@ def newlinewrapper(text: list[str]) -> Generator[str, None, None]: for line in buf: click.secho(line) - if status: - # todo allow status to be a FormattedText, but strip before logging - self.log_output(status) + if result.status: + self.log_output(result.status_plain) add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' - formatted_status = FormattedText([('', status)]) - styled_status = to_formatted_text(formatted_status, style=add_style) + if isinstance(result.status, FormattedText): + status = result.status + else: + status = FormattedText([('', result.status_plain)]) + styled_status = to_formatted_text(status, style=add_style) print_formatted_text(styled_status, style=self.toolkit_style) def configure_pager(self) -> None: @@ -2466,20 +2468,20 @@ def need_completion_reset(queries: str) -> bool: return False -def is_mutating(status: str | None) -> bool: +def is_mutating(status_plain: str | None) -> bool: """Determines if the statement is mutating based on the status.""" - if not status: + if not status_plain: return False mutating = {"insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"} - return status.split(None, 1)[0].lower() in mutating + return status_plain.split(None, 1)[0].lower() in mutating -def is_select(status: str | None) -> bool: +def is_select(status_plain: str | None) -> bool: """Returns true if the first word in status is 'select'.""" - if not status: + if not status_plain: return False - return status.split(None, 1)[0].lower() == "select" + return status_plain.split(None, 1)[0].lower() == "select" def thanks_picker() -> str: diff --git a/mycli/myclirc b/mycli/myclirc index f6f3e819..057f6c30 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -270,6 +270,7 @@ output.odd-row = "" output.even-row = "" output.null = "#808080" output.status = "" +output.status.warning-count = "" output.timing = "" # SQL syntax highlighting overrides diff --git a/mycli/packages/sqlresult.py b/mycli/packages/sqlresult.py index 4ff3eebc..1edbebab 100644 --- a/mycli/packages/sqlresult.py +++ b/mycli/packages/sqlresult.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +from functools import cached_property +from prompt_toolkit.formatted_text import FormattedText, to_plain_text from pymysql.cursors import Cursor @@ -9,7 +11,7 @@ class SQLResult: header: list[str] | str | None = None rows: Cursor | list[tuple] | None = None postamble: str | None = None - status: str | None = None + status: str | FormattedText | None = None command: dict[str, str | float] | None = None def __iter__(self): @@ -17,3 +19,9 @@ def __iter__(self): def __str__(self): return f"{self.preamble}, {self.header}, {self.rows}, {self.postamble}, {self.status}, {self.command}" + + @cached_property + def status_plain(self): + if self.status is None: + return None + return to_plain_text(self.status) diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index e4343f7f..2b70957e 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -7,6 +7,7 @@ import ssl from typing import Any, Generator, Iterable +from prompt_toolkit.formatted_text import FormattedText import pymysql from pymysql.connections import Connection from pymysql.constants import FIELD_TYPE @@ -393,14 +394,17 @@ def get_result(self, cursor: Cursor) -> SQLResult: plural = '' if cursor.rowcount == 1 else 's' if cursor.description: header = [x[0] for x in cursor.description] - status = f'{cursor.rowcount} row{plural} in set' + status = FormattedText([('', f'{cursor.rowcount} row{plural} in set')]) else: _logger.debug("No rows in result.") - status = f'Query OK, {cursor.rowcount} row{plural} affected' + status = FormattedText([('', f'Query OK, {cursor.rowcount} row{plural} affected')]) if cursor.warning_count > 0: plural = '' if cursor.warning_count == 1 else 's' - status = f'{status}, {cursor.warning_count} warning{plural}' + comma = FormattedText([('', ', ')]) + warning_count = FormattedText([('class:output.status.warning-count', f'{cursor.warning_count} warning{plural}')]) + status.extend(comma) + status.extend(warning_count) return SQLResult(preamble=preamble, header=header, rows=cursor, status=status) diff --git a/test/myclirc b/test/myclirc index 0b8f094c..4b37d012 100644 --- a/test/myclirc +++ b/test/myclirc @@ -268,6 +268,7 @@ output.odd-row = "" output.even-row = "" output.null = "#808080" output.status = "" +output.status.warning-count = "" output.timing = "" # SQL syntax highlighting overrides diff --git a/test/test_main.py b/test/test_main.py index 34ac1aaf..25c95e6e 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -17,6 +17,7 @@ from mycli.packages.parseutils import is_valid_connection_scheme import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.sqlresult import SQLResult from mycli.sqlexecute import ServerInfo, SQLExecute from test.utils import DATABASE, HOST, PASSWORD, PORT, TEMPFILE_PREFIX, USER, dbtest, run @@ -76,7 +77,7 @@ def test_binary_display_hex(executor): ) f = io.StringIO() with redirect_stdout(f): - m.output(formatted, sqlresult.status) + m.output(formatted, sqlresult) expected = " 0x6a " output = f.getvalue() assert expected in output @@ -115,7 +116,7 @@ def test_binary_display_utf8(executor): ) f = io.StringIO() with redirect_stdout(f): - m.output(formatted, sqlresult.status) + m.output(formatted, sqlresult) expected = " j " output = f.getvalue() assert expected in output @@ -651,7 +652,7 @@ def secho(s): monkeypatch.setattr(click, "echo_via_pager", echo_via_pager) monkeypatch.setattr(click, "secho", secho) - m.output(testdata) + m.output(testdata, SQLResult()) if clickoutput.endswith("\n"): clickoutput = clickoutput[:-1] assert clickoutput == "\n".join(testdata) diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index c1d40fe3..c57541f8 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -3,6 +3,7 @@ from datetime import time import os +from prompt_toolkit.formatted_text import FormattedText import pymysql import pytest @@ -16,19 +17,22 @@ def assert_result_equal( header=None, rows=None, status=None, + status_plain=None, postamble=None, auto_status=True, assert_contains=False, ): """Assert that an sqlexecute.run() result matches the expected values.""" - if status is None and auto_status and rows: - status = f"{len(rows)} row{'s' if len(rows) > 1 else ''} in set" + if status_plain is None and auto_status and rows: + status_plain = f"{len(rows)} row{'s' if len(rows) > 1 else ''} in set" + status = FormattedText([('', status_plain)]) fields = { "preamble": preamble, "header": header, "rows": rows, "postamble": postamble, "status": status, + "status_plain": status_plain, } if assert_contains: @@ -61,14 +65,19 @@ def test_timediff_positive_value(executor): def test_get_result_status_without_warning(executor): sql = "select 1" result = run(executor, sql) - assert result[0]["status"] == "1 row in set" + assert result[0]["status_plain"] == "1 row in set" @dbtest def test_get_result_status_with_warning(executor): sql = "SELECT 1 + '0 foo'" result = run(executor, sql) - assert result[0]["status"] == "1 row in set, 1 warning" + assert result[0]["status"] == FormattedText([ + ('', '1 row in set'), + ('', ', '), + ('class:output.status.warning-count', '1 warning'), + ]) + assert result[0]["status_plain"] == "1 row in set, 1 warning" @dbtest @@ -148,8 +157,22 @@ def test_multiple_queries_same_line(executor): results = run(executor, "select 'foo'; select 'bar'") expected = [ - {"preamble": None, "header": ["foo"], "rows": [("foo",)], "postamble": None, "status": "1 row in set"}, - {"preamble": None, "header": ["bar"], "rows": [("bar",)], "postamble": None, "status": "1 row in set"}, + { + "preamble": None, + "header": ["foo"], + "rows": [("foo",)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + { + "preamble": None, + "header": ["bar"], + "rows": [("bar",)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, ] assert expected == results @@ -170,13 +193,13 @@ def test_favorite_query(executor): run(executor, "insert into test values('def')") results = run(executor, "\\fs test-a select * from test where a like 'a%'") - assert_result_equal(results, status="Saved.") + assert_result_equal(results, status="Saved.", status_plain="Saved.") results = run(executor, "\\f test-a") assert_result_equal(results, preamble="> select * from test where a like 'a%'", header=["a"], rows=[("abc",)], auto_status=False) results = run(executor, "\\fd test-a") - assert_result_equal(results, status="test-a: Deleted.") + assert_result_equal(results, status="test-a: Deleted.", status_plain="test-a: Deleted.") @dbtest @@ -188,17 +211,31 @@ def test_favorite_query_multiple_statement(executor): run(executor, "insert into test values('def')") results = run(executor, "\\fs test-ad select * from test where a like 'a%'; select * from test where a like 'd%'") - assert_result_equal(results, status="Saved.") + assert_result_equal(results, status="Saved.", status_plain="Saved.") results = run(executor, "\\f test-ad") expected = [ - {"preamble": "> select * from test where a like 'a%'", "header": ["a"], "rows": [("abc",)], "postamble": None, "status": None}, - {"preamble": "> select * from test where a like 'd%'", "header": ["a"], "rows": [("def",)], "postamble": None, "status": None}, + { + "preamble": "> select * from test where a like 'a%'", + "header": ["a"], + "rows": [("abc",)], + "postamble": None, + "status": None, + "status_plain": None, + }, + { + "preamble": "> select * from test where a like 'd%'", + "header": ["a"], + "rows": [("def",)], + "postamble": None, + "status": None, + "status_plain": None, + }, ] assert expected == results results = run(executor, "\\fd test-ad") - assert_result_equal(results, status="test-ad: Deleted.") + assert_result_equal(results, status="test-ad: Deleted.", status_plain="test-ad: Deleted.") @dbtest @@ -209,7 +246,7 @@ def test_favorite_query_expanded_output(executor): run(executor, """insert into test values('abc')""") results = run(executor, "\\fs test-ae select * from test") - assert_result_equal(results, status="Saved.") + assert_result_equal(results, status="Saved.", status_plain="Saved.") results = run(executor, "\\f test-ae \\G") assert is_expanded_output() is True @@ -218,7 +255,7 @@ def test_favorite_query_expanded_output(executor): set_expanded_output(False) results = run(executor, "\\fd test-ae") - assert_result_equal(results, status="test-ae: Deleted.") + assert_result_equal(results, status="test-ae: Deleted.", status_plain="test-ae: Deleted.") @dbtest @@ -237,41 +274,45 @@ def test_special_command(executor): @dbtest def test_cd_command_without_a_folder_name(executor): results = run(executor, "system cd") - assert_result_equal(results, status="Exactly one directory name must be provided.") + assert_result_equal( + results, status="Exactly one directory name must be provided.", status_plain="Exactly one directory name must be provided." + ) @dbtest def test_cd_command_with_one_nonexistent_folder_name(executor): results = run(executor, 'system cd nonexistent_folder_name') - assert_result_equal(results, status='No such file or directory') + assert_result_equal(results, status='No such file or directory', status_plain='No such file or directory') @dbtest def test_cd_command_with_one_real_folder_name(executor): results = run(executor, 'system cd screenshots') # todo would be better to capture stderr but there was a problem with capsys - assert results[0]['status'] == '' + assert results[0]['status_plain'] == '' @dbtest def test_cd_command_with_two_folder_names(executor): results = run(executor, "system cd one two") - assert_result_equal(results, status='Exactly one directory name must be provided.') + assert_result_equal( + results, status='Exactly one directory name must be provided.', status_plain='Exactly one directory name must be provided.' + ) @dbtest def test_cd_command_unbalanced(executor): results = run(executor, "system cd 'one") - assert_result_equal(results, status='Cannot parse cd command.') + assert_result_equal(results, status='Cannot parse cd command.', status_plain='Cannot parse cd command.') @dbtest def test_system_command_not_found(executor): results = run(executor, "system xyz") if os.name == "nt": - assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True) + assert_result_equal(results, status_plain="OSError: The system cannot find the file specified", assert_contains=True) else: - assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True) + assert_result_equal(results, status_plain="OSError: No such file or directory", assert_contains=True) @dbtest @@ -280,7 +321,7 @@ def test_system_command_output(executor): test_dir = os.path.abspath(os.path.dirname(__file__)) test_file_path = os.path.join(test_dir, "test.txt") results = run(executor, f"system cat {test_file_path}") - assert_result_equal(results, status=f"mycli rocks!{eol}") + assert_result_equal(results, status=f"mycli rocks!{eol}", status_plain=f"mycli rocks!{eol}") @dbtest @@ -339,8 +380,22 @@ def test_multiple_results(executor): results = run(executor, "call dmtest;") expected = [ - {"preamble": None, "header": ["1"], "rows": [(1,)], "postamble": None, "status": "1 row in set"}, - {"preamble": None, "header": ["2"], "rows": [(2,)], "postamble": None, "status": "1 row in set"}, + { + "preamble": None, + "header": ["1"], + "rows": [(1,)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + { + "preamble": None, + "header": ["2"], + "rows": [(2,)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, ] assert results == expected diff --git a/test/utils.py b/test/utils.py index 72e8b833..d30472e1 100644 --- a/test/utils.py +++ b/test/utils.py @@ -59,6 +59,7 @@ def run(executor, sql, rows_as_list=True): "rows": rows, "postamble": result.postamble, "status": result.status, + "status_plain": result.status_plain, }) return results