diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 5b646fabca0225..e6f01577826221 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -2,6 +2,7 @@ import abc import os import sys +import threading import _collections_abc from collections import deque from functools import wraps @@ -390,22 +391,76 @@ async def __aexit__(self, *exc_info): await self.thing.aclose() +class _PerThreadStream: + def __init__(self, default_stream): + self.default_stream = default_stream + # each stack entry is (stream, thread_id). thread_id is None if + # per_thread=False. + self._stack = [] + self._lock = threading.Lock() + + @property + def _current_stream(self): + thread_id = threading.get_ident() + # look for the most recent redirect which was either: + # * per_thread=False + # * per_thread=True, and in our thread + # + # If none match, fall back to the default stream. + with self._lock: + for stream, entry_thread_id in reversed(self._stack): + if entry_thread_id is None or entry_thread_id == thread_id: + return stream + return self.default_stream + + def add_entry(self, entry): + with self._lock: + self._stack.append(entry) + + def remove_entry(self, entry): + # remove by identity, not equality, in case two streams compare equal + with self._lock: + for i, e in enumerate(self._stack): + if e is entry: + del self._stack[i] + return + + def __getattr__(self, name): + return getattr(self._current_stream, name) + + class _RedirectStream(AbstractContextManager): _stream = None + _lock = None + _stream_ref = None - def __init__(self, new_target): + def __init__(self, new_target, *, per_thread=False): self._new_target = new_target - # We use a list of old targets to make this CM re-entrant - self._old_targets = [] + self._per_thread = per_thread + self._entries = [] # stack for reentrant usage def __enter__(self): - self._old_targets.append(getattr(sys, self._stream)) - setattr(sys, self._stream, self._new_target) + with self._lock: + if self._stream_ref is None: + type(self)._stream_ref = _PerThreadStream(getattr(sys, self._stream)) + setattr(sys, self._stream, self._stream_ref) + entry = ( + self._new_target, + threading.get_ident() if self._per_thread else None, + ) + self._entries.append(entry) + self._stream_ref.add_entry(entry) + return self._new_target def __exit__(self, exctype, excinst, exctb): - setattr(sys, self._stream, self._old_targets.pop()) + with self._lock: + entry = self._entries.pop() + self._stream_ref.remove_entry(entry) + if len(self._stream_ref._stack) == 0: + setattr(sys, self._stream, self._stream_ref.default_stream) + type(self)._stream_ref = None class redirect_stdout(_RedirectStream): @@ -422,12 +477,16 @@ class redirect_stdout(_RedirectStream): """ _stream = "stdout" + _lock = threading.Lock() + _stream_ref = None class redirect_stderr(_RedirectStream): """Context manager for temporarily redirecting stderr to another file.""" _stream = "stderr" + _lock = threading.Lock() + _stream_ref = None class suppress(AbstractContextManager): diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 6a3329fa5aaace..d72078e7fd993b 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1362,5 +1362,379 @@ def test_exception(self): self.assertEqual(os.getcwd(), old_cwd) +class TestRedirectStreamPerThread(unittest.TestCase): + def test_basic_redirect_stdout(self): + buf = io.StringIO() + with redirect_stdout(buf, per_thread=True): + print("foo") + self.assertEqual(buf.getvalue(), "foo\n") + + original = sys.stdout + self.assertIs(sys.stdout, original) + + def test_basic_redirect_stderr(self): + buf = io.StringIO() + with redirect_stderr(buf, per_thread=True): + print("foo", file=sys.stderr) + self.assertEqual(buf.getvalue(), "foo\n") + + original = sys.stderr + self.assertIs(sys.stderr, original) + + def test_stdout_and_stderr_are_independent(self): + stdout_buf = io.StringIO() + stderr_buf = io.StringIO() + + with ( + redirect_stdout(stdout_buf, per_thread=True), + redirect_stderr(stderr_buf, per_thread=True), + ): + print("to stdout") + print("to stderr", file=sys.stderr) + + self.assertEqual(stdout_buf.getvalue(), "to stdout\n") + self.assertEqual(stderr_buf.getvalue(), "to stderr\n") + + def test_single_thread(self): + buf = io.StringIO() + with redirect_stdout(buf, per_thread=True): + print("foo") + self.assertEqual(buf.getvalue(), "foo\n") + + def test_delegates_explicit_writes(self): + buf = io.StringIO() + with redirect_stdout(buf, per_thread=True): + sys.stdout.write("line1\n") + sys.stdout.writelines(["line2\n", "line3\n"]) + self.assertEqual(buf.getvalue(), "line1\nline2\nline3\n") + + def test_redirect_to_file(self): + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + with redirect_stdout(f, per_thread=True): + print("foo") + try: + with open(f.name) as rf: + self.assertEqual(rf.read(), "foo\n") + finally: + os.unlink(f.name) + + def test_no_per_thread(self): + buf = io.StringIO() + with redirect_stdout(buf): + print("foo") + self.assertEqual(buf.getvalue(), "foo\n") + + def test_restores_sys_stdout(self): + original_stdout = sys.stdout + with redirect_stdout(io.StringIO(), per_thread=True): + print("foo") + self.assertIs(sys.stdout, original_stdout) + + def test_per_thread_doesnt_affect_main_thread(self): + main_buf = io.StringIO() + thread_buf = io.StringIO() + entered = threading.Event() + done = threading.Event() + + def thread1(): + with redirect_stdout(thread_buf, per_thread=True): + entered.set() + done.wait() + + original_stdout = sys.stdout + sys.stdout = main_buf + try: + t = threading.Thread(target=thread1) + t.start() + entered.wait() + # the thread's redirect is set, but it shouldn't affect us. + print("from main") + done.set() + t.join() + finally: + sys.stdout = original_stdout + + self.assertEqual(thread_buf.getvalue(), "") + self.assertEqual(main_buf.getvalue(), "from main\n") + + def test_cleans_up_on_exception(self): + original_stdout = sys.stdout + buf = io.StringIO() + + with self.assertRaises(ValueError): + with redirect_stdout(buf, per_thread=True): + print("foo") + raise ValueError("bar") + + self.assertIn("foo", buf.getvalue()) + self.assertIs(sys.stdout, original_stdout) + + def test_stress_test(self): + original_stdout = sys.stdout + + for i in range(50): + buf = io.StringIO() + with redirect_stdout(buf, per_thread=True): + print(f"iteration {i}") + self.assertEqual(buf.getvalue(), f"iteration {i}\n") + + self.assertIs(sys.stdout, original_stdout) + + def test_mixing_per_thread_true_and_false(self): + stream1 = io.StringIO() + stream2 = io.StringIO() + + with redirect_stdout(stream1, per_thread=True): + print("per_thread true") + with redirect_stdout(stream2, per_thread=False): + print("per_thread false") + print("back to true") + + self.assertEqual(stream1.getvalue(), "per_thread true\nback to true\n") + self.assertEqual(stream2.getvalue(), "per_thread false\n") + + def test_nested_single_thread(self): + s1 = io.StringIO() + s2 = io.StringIO() + s3 = io.StringIO() + + with redirect_stdout(s1, per_thread=True): + print("start1") + with redirect_stdout(s2, per_thread=True): + print("start2") + with redirect_stdout(s3, per_thread=True): + print("start3") + print("end3") + print("end2") + print("end1") + + self.assertEqual(s1.getvalue(), "start1\nend1\n") + self.assertEqual(s2.getvalue(), "start2\nend2\n") + self.assertEqual(s3.getvalue(), "start3\nend3\n") + + def test_simultaneous_threads(self): + n_threads = 10 + barrier1 = threading.Barrier(n_threads) + barrier2 = threading.Barrier(n_threads) + bufs = [io.StringIO() for _ in range(n_threads)] + + def f(n, stream): + barrier1.wait() + with redirect_stdout(stream, per_thread=True): + print(f"thread {n}") + barrier2.wait() + + threads = [threading.Thread(target=f, args=(i, bufs[i])) for i in range(n_threads)] + + for t in threads: + t.start() + for t in threads: + t.join() + + for i in range(n_threads): + self.assertEqual(bufs[i].getvalue(), f"thread {i}\n") + + def test_per_thread_true_then_false(self): + # this test sets up two threads: + # + # - A enters per_thread=True + # - B enters per_thread=False + # - A prints + # - A exits (restore to actual sys.stdout) + # - B exits (restore to A's PerThreadState) + # + # When B enters, the stdlib behavior is to overwrite sys.stdout completely, + # overwriting the per-thread setup of A. + # + # We have a design decision to make here. If you request a per_thread=False + # overwrite while a per_thread=True overwrite is active, we can either: + # + # (1) respect this global overwrite and redirect all per-thread streams to the + # new context manager + # (2) treat this per_thread=False request as if it were per_thread=True + # + # We choose (1). + + original_stdout = sys.stdout + + b_start = threading.Event() + a_start = threading.Event() + a_end = threading.Event() + stream_a = io.StringIO() + stream_b = io.StringIO() + + def thread_a(): + with redirect_stdout(stream_a, per_thread=True): + a_start.set() + b_start.wait() + print("from_a") + a_end.set() + + def thread_b(): + a_start.wait() + with redirect_stdout(stream_b): + b_start.set() + a_end.wait() + + t_a = threading.Thread(target=thread_a) + t_b = threading.Thread(target=thread_b) + + t_a.start() + t_b.start() + t_a.join() + t_b.join() + + self.assertEqual(stream_a.getvalue(), "") + self.assertEqual(stream_b.getvalue(), "from_a\n") + self.assertIs(sys.stdout, original_stdout) + + def test_per_thread_false_then_true(self): + # this test sets up two threads: + # + # - A enters per_thread=False + # - B enters per_thread=True + # - A exits (restore to actual sys.stdout) + # - B prints + # - B exits + + original_stdout = sys.stdout + + b_start = threading.Event() + a_start = threading.Event() + a_end = threading.Event() + stream_a = io.StringIO() + stream_b = io.StringIO() + + def thread_a(): + with redirect_stdout(stream_a): + a_start.set() + b_start.wait() + a_end.set() + + def thread_b(): + a_start.wait() + with redirect_stdout(stream_b, per_thread=True): + b_start.set() + a_end.wait() + print("from_b") + + t_a = threading.Thread(target=thread_a) + t_b = threading.Thread(target=thread_b) + + t_a.start() + t_b.start() + t_a.join() + t_b.join() + + self.assertEqual(stream_a.getvalue(), "") + self.assertEqual(stream_b.getvalue(), "from_b\n") + self.assertIs(sys.stdout, original_stdout) + + def test_stacked_globals_resurface(self): + # Thread timeline: + # A enters per_thread=True + # B enters per_thread=False + # C enters per_thread=False + # A prints (captured by C global) + # C exits + # A prints (captured by B global) + # B exits + # A prints (captured by A per-thread) + # A exits + + original_stdout = sys.stdout + stream_a = io.StringIO() + stream_b = io.StringIO() + stream_c = io.StringIO() + + a_entered = threading.Event() + b_entered = threading.Event() + c_entered = threading.Event() + print1_done = threading.Event() + c_exited = threading.Event() + print2_done = threading.Event() + b_exited = threading.Event() + + def thread_a(): + with redirect_stdout(stream_a, per_thread=True): + a_entered.set() + c_entered.wait() + print("to_c") + print1_done.set() + c_exited.wait() + print("to_b") + print2_done.set() + b_exited.wait() + print("to_a") + + def thread_b(): + a_entered.wait() + with redirect_stdout(stream_b, per_thread=False): + b_entered.set() + print2_done.wait() + b_exited.set() + + def thread_c(): + b_entered.wait() + with redirect_stdout(stream_c, per_thread=False): + c_entered.set() + print1_done.wait() + c_exited.set() + + threads = [threading.Thread(target=t) for t in [thread_a, thread_b, thread_c]] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(stream_c.getvalue(), "to_c\n") + self.assertEqual(stream_b.getvalue(), "to_b\n") + self.assertEqual(stream_a.getvalue(), "to_a\n") + self.assertIs(sys.stdout, original_stdout) + + def test_equal_streams_pops_correctly(self): + class NamedStream(io.StringIO): + def __init__(self, name): + super().__init__() + self.name = name + + def __eq__(self, other): + return isinstance(other, NamedStream) and self.name == other.name + + original = sys.stdout + + stream_a = NamedStream("a") + stream_b = NamedStream("a") + self.assertEqual(stream_a, stream_b) + + a_start = threading.Event() + b_start = threading.Event() + b_end = threading.Event() + + def thread_a(): + with redirect_stdout(stream_a, per_thread=False): + a_start.set() + b_start.wait() + b_end.wait() + print("from_a") + + def thread_b(): + a_start.wait() + with redirect_stdout(stream_b, per_thread=False): + b_start.set() + b_end.set() + + t_a = threading.Thread(target=thread_a) + t_b = threading.Thread(target=thread_b) + t_a.start() + t_b.start() + t_a.join() + t_b.join() + + self.assertEqual(stream_a.getvalue(), "from_a\n") + self.assertEqual(stream_b.getvalue(), "") + self.assertIs(sys.stdout, original) + + if __name__ == "__main__": unittest.main()