mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
test_util.capture_stdout redirects using file descriptors rather than mocking the python interface.
PiperOrigin-RevId: 640183718
This commit is contained in:
parent
c3ac4b55da
commit
9939cc9974
@ -19,10 +19,10 @@ import datetime
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
from typing import Any, Callable
|
||||
@ -178,11 +178,26 @@ def check_eq(xs, ys, err_msg=''):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def capture_stdout() -> Generator[Callable[[], str], None, None]:
|
||||
with unittest.mock.patch('sys.stdout', new_callable=io.StringIO) as fp:
|
||||
def _read() -> str:
|
||||
return fp.getvalue()
|
||||
yield _read
|
||||
def capture_stdout() -> Generator[Callable[[], str | None], None, None]:
|
||||
"""Context manager to capture all stdout output."""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f:
|
||||
original_stdout = os.dup(sys.stdout.fileno())
|
||||
os.dup2(f.fileno(), sys.stdout.fileno())
|
||||
|
||||
# if get_stdout returns not it means we are not done capturing
|
||||
# stdout. it should only be used after the context has exited.
|
||||
captured = None
|
||||
get_stdout: Callable[[], str | None] = lambda: captured
|
||||
|
||||
try:
|
||||
yield get_stdout
|
||||
finally:
|
||||
# Python also has its own buffers, make sure everything is flushed.
|
||||
sys.stdout.flush()
|
||||
f.seek(0)
|
||||
captured = f.read()
|
||||
os.dup2(original_stdout, sys.stdout.fileno())
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -807,7 +807,8 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
|
||||
"7: 14", "Out: 7.0", ""]
|
||||
jax.effects_barrier()
|
||||
self._assertLinesEqual(output(), "\n".join(lines))
|
||||
|
||||
self._assertLinesEqual(output(), "\n".join(lines))
|
||||
|
||||
def test_unordered_print_with_xmap(self):
|
||||
def f(x):
|
||||
|
@ -12,11 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
@ -29,21 +25,6 @@ import numpy as np
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_stdout():
|
||||
"""Context manager to capture all stdout output and return it as a string."""
|
||||
captured_output = [None]
|
||||
with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as f:
|
||||
original_stdout_fd = sys.stdout.fileno()
|
||||
os.dup2(f.fileno(), original_stdout_fd)
|
||||
try:
|
||||
yield captured_output
|
||||
finally:
|
||||
os.dup2(original_stdout_fd, sys.stdout.fileno())
|
||||
f.seek(0)
|
||||
captured_output[0] = f.read()
|
||||
|
||||
|
||||
class PallasTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -113,10 +94,10 @@ class PallasCallTest(PallasTest):
|
||||
pl.debug_print("It works!")
|
||||
|
||||
x = jnp.arange(256).astype(jnp.float32)
|
||||
with capture_stdout() as captured_output:
|
||||
kernel(x)
|
||||
with jtu.capture_stdout() as output:
|
||||
jax.block_until_ready(kernel(x))
|
||||
|
||||
self.assertEqual(captured_output[0], "It works!\n")
|
||||
self.assertEqual(output(), "It works!\n")
|
||||
|
||||
def test_print_with_values(self):
|
||||
@functools.partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user