diff --git a/src/libtmux/neo.py b/src/libtmux/neo.py index 932f969e1..30ac475b9 100644 --- a/src/libtmux/neo.py +++ b/src/libtmux/neo.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import functools import logging import typing as t from collections.abc import Iterable @@ -177,13 +178,109 @@ def _refresh( setattr(self, k, v) +@functools.cache +def get_output_format() -> tuple[tuple[str, ...], str]: + """Return field names and tmux format string for all Obj fields. + + Excludes the ``server`` field, which is a Python object reference + rather than a tmux format variable. + + Returns + ------- + tuple[tuple[str, ...], str] + A tuple of (field_names, tmux_format_string). + + Examples + -------- + >>> from libtmux.neo import get_output_format + >>> fields, fmt = get_output_format() + >>> 'session_id' in fields + True + >>> 'server' in fields + False + """ + # Exclude 'server' - it's a Python object, not a tmux format variable + formats = tuple(f for f in Obj.__dataclass_fields__ if f != "server") + tmux_formats = [f"#{{{f}}}{FORMAT_SEPARATOR}" for f in formats] + return formats, "".join(tmux_formats) + + +def parse_output(output: str) -> OutputRaw: + """Parse tmux output formatted with get_output_format() into a dict. + + Parameters + ---------- + output : str + Raw tmux output produced with the format string from + :func:`get_output_format`. + + Returns + ------- + OutputRaw + A dict mapping field names to non-empty string values. + + Examples + -------- + >>> from libtmux.neo import get_output_format, parse_output + >>> from libtmux.formats import FORMAT_SEPARATOR + >>> fields, fmt = get_output_format() + >>> values = [''] * len(fields) + >>> values[fields.index('session_id')] = '$1' + >>> result = parse_output(FORMAT_SEPARATOR.join(values) + FORMAT_SEPARATOR) + >>> result['session_id'] + '$1' + >>> 'buffer_sample' in result + False + """ + formats, _ = get_output_format() + formatter = dict(zip(formats, output.split(FORMAT_SEPARATOR), strict=False)) + return {k: v for k, v in formatter.items() if v} + + def fetch_objs( server: Server, list_cmd: ListCmd, list_extra_args: ListExtraArgs = None, ) -> OutputsRaw: - """Fetch a listing of raw data from a tmux command.""" - formats = list(Obj.__dataclass_fields__.keys()) + """Fetch a listing of raw data from a tmux command. + + Runs a tmux list command (e.g. ``list-sessions``) with the format string + from :func:`get_output_format` and parses each line of output into a dict. + + Parameters + ---------- + server : :class:`~libtmux.server.Server` + The tmux server to query. + list_cmd : ListCmd + The tmux list command to run, e.g. ``"list-sessions"``, + ``"list-windows"``, or ``"list-panes"``. + list_extra_args : ListExtraArgs, optional + Extra arguments appended to the tmux command (e.g. ``("-a",)`` + for all windows/panes, or ``["-t", session_id]`` to filter). + + Returns + ------- + OutputsRaw + A list of dicts, each mapping tmux format field names to their + non-empty string values. + + Raises + ------ + :exc:`~libtmux.exc.LibTmuxException` + If the tmux command writes to stderr. + + Examples + -------- + >>> from libtmux.neo import fetch_objs + >>> objs = fetch_objs(server=server, list_cmd="list-sessions") + >>> isinstance(objs, list) + True + >>> isinstance(objs[0], dict) + True + >>> 'session_id' in objs[0] + True + """ + _fields, format_string = get_output_format() cmd_args: list[str | int] = [] @@ -191,7 +288,6 @@ def fetch_objs( cmd_args.insert(0, f"-L{server.socket_name}") if server.socket_path: cmd_args.insert(0, f"-S{server.socket_path}") - tmux_formats = [f"#{{{f}}}{FORMAT_SEPARATOR}" for f in formats] tmux_cmds = [ *cmd_args, @@ -201,22 +297,14 @@ def fetch_objs( if list_extra_args is not None and isinstance(list_extra_args, Iterable): tmux_cmds.extend(list(list_extra_args)) - tmux_cmds.append("-F{}".format("".join(tmux_formats))) + tmux_cmds.append(f"-F{format_string}") proc = tmux_cmd(*tmux_cmds) # output if proc.stderr: raise exc.LibTmuxException(proc.stderr) - obj_output = proc.stdout - - obj_formatters = [ - dict(zip(formats, formatter.split(FORMAT_SEPARATOR), strict=False)) - for formatter in obj_output - ] - - # Filter empty values - return [{k: v for k, v in formatter.items() if v} for formatter in obj_formatters] + return [parse_output(line) for line in proc.stdout] def fetch_obj( diff --git a/src/libtmux/server.py b/src/libtmux/server.py index 71f9f84a7..edb52067e 100644 --- a/src/libtmux/server.py +++ b/src/libtmux/server.py @@ -14,12 +14,12 @@ import subprocess import typing as t -from libtmux import exc, formats +from libtmux import exc from libtmux._internal.query_list import QueryList from libtmux.common import tmux_cmd from libtmux.constants import OptionScope from libtmux.hooks import HooksMixin -from libtmux.neo import fetch_objs +from libtmux.neo import fetch_objs, get_output_format, parse_output from libtmux.pane import Pane from libtmux.session import Session from libtmux.window import Window @@ -539,9 +539,11 @@ def new_session( if env: del os.environ["TMUX"] + _fields, format_string = get_output_format() + tmux_args: tuple[str | int, ...] = ( "-P", - "-F#{session_id}", # output + f"-F{format_string}", ) if session_name is not None: @@ -580,18 +582,9 @@ def new_session( if env: os.environ["TMUX"] = env - session_formatters = dict( - zip( - ["session_id"], - session_stdout.split(formats.FORMAT_SEPARATOR), - strict=False, - ), - ) + session_data = parse_output(session_stdout) - return Session.from_session_id( - server=self, - session_id=session_formatters["session_id"], - ) + return Session(server=self, **session_data) # # Relations diff --git a/tests/test_server.py b/tests/test_server.py index 9b85d279c..cb9d83a9c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -104,6 +104,13 @@ def test_new_session(server: Server) -> None: assert server.has_session("test_new_session") +def test_new_session_returns_populated_session(server: Server) -> None: + """Server.new_session returns Session populated from -P output.""" + session = server.new_session(session_name="test_populated") + assert session.session_id is not None + assert session.session_name == "test_populated" + + def test_new_session_no_name(server: Server) -> None: """Server.new_session works with no name.""" first_session = server.new_session()