test_util.capture_stdout redirects using file descriptors rather than mocking the python interface.

PiperOrigin-RevId: 640183718
This commit is contained in:
Christos Perivolaropoulos 2024-06-04 09:40:56 -07:00 committed by jax authors
parent c3ac4b55da
commit 9939cc9974
3 changed files with 26 additions and 29 deletions

View File

@ -19,10 +19,10 @@ import datetime
import functools
from functools import partial
import inspect
import io
import math
import os
import re
import sys
import tempfile
import textwrap
from typing import Any, Callable
@ -178,11 +178,26 @@ def check_eq(xs, ys, err_msg=''):
@contextmanager
def capture_stdout() -> Generator[Callable[[], str], None, None]:
with unittest.mock.patch('sys.stdout', new_callable=io.StringIO) as fp:
def _read() -> str:
return fp.getvalue()
yield _read
def capture_stdout() -> Generator[Callable[[], str | None], None, None]:
"""Context manager to capture all stdout output."""
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f:
original_stdout = os.dup(sys.stdout.fileno())
os.dup2(f.fileno(), sys.stdout.fileno())
# if get_stdout returns not it means we are not done capturing
# stdout. it should only be used after the context has exited.
captured = None
get_stdout: Callable[[], str | None] = lambda: captured
try:
yield get_stdout
finally:
# Python also has its own buffers, make sure everything is flushed.
sys.stdout.flush()
f.seek(0)
captured = f.read()
os.dup2(original_stdout, sys.stdout.fileno())
@contextmanager

View File

@ -807,7 +807,8 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
"7: 14", "Out: 7.0", ""]
jax.effects_barrier()
self._assertLinesEqual(output(), "\n".join(lines))
self._assertLinesEqual(output(), "\n".join(lines))
def test_unordered_print_with_xmap(self):
def f(x):

View File

@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import os
import sys
import tempfile
from absl.testing import absltest
from absl.testing import parameterized
import jax
@ -29,21 +25,6 @@ import numpy as np
jax.config.parse_flags_with_absl()
@contextlib.contextmanager
def capture_stdout():
"""Context manager to capture all stdout output and return it as a string."""
captured_output = [None]
with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as f:
original_stdout_fd = sys.stdout.fileno()
os.dup2(f.fileno(), original_stdout_fd)
try:
yield captured_output
finally:
os.dup2(original_stdout_fd, sys.stdout.fileno())
f.seek(0)
captured_output[0] = f.read()
class PallasTest(jtu.JaxTestCase):
def setUp(self):
@ -113,10 +94,10 @@ class PallasCallTest(PallasTest):
pl.debug_print("It works!")
x = jnp.arange(256).astype(jnp.float32)
with capture_stdout() as captured_output:
kernel(x)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertEqual(captured_output[0], "It works!\n")
self.assertEqual(output(), "It works!\n")
def test_print_with_values(self):
@functools.partial(