mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime. * The new runtime makes it unnecessary to start the outfeed_receiver in the user's code * We don't need msgpack anymore * There is an interaction between host_callback and using lax.outfeed. I am trying to solve this by (a) making host_callback_test stop the outfeed receiver on finish and infeed_test on start, and (b) telling pytest-xdist to run all the tests from one file into a single worker.
This commit is contained in:
parent
904f34a9ba
commit
4f3011f320
4
.github/workflows/ci-build.yaml
vendored
4
.github/workflows/ci-build.yaml
vendored
@ -101,5 +101,5 @@ jobs:
|
||||
pip install -r docs/requirements.txt
|
||||
- name: Test documentation
|
||||
run: |
|
||||
pytest docs
|
||||
pytest --doctest-modules jax/api.py
|
||||
pytest -n 1 docs
|
||||
pytest -n 1 --doctest-modules jax/api.py
|
||||
|
@ -1,6 +1,5 @@
|
||||
flake8
|
||||
jaxlib==0.1.51
|
||||
msgpack
|
||||
mypy==0.770
|
||||
pytest-benchmark
|
||||
pytest-xdist
|
||||
|
@ -5,12 +5,11 @@ ipykernel
|
||||
nbsphinx
|
||||
sphinx-autodoc-typehints
|
||||
myst-parser[sphinx]
|
||||
# For host_callback.py
|
||||
msgpack
|
||||
# The next packages are for notebooks
|
||||
matplotlib
|
||||
sklearn
|
||||
# For CI tests.
|
||||
pytest
|
||||
pytest-xdist
|
||||
# Must install jax itself for notebook execution to work
|
||||
.
|
||||
|
@ -354,7 +354,7 @@ def xla_computation(fun: Callable,
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
|
||||
instantiate=instantiate_const_outputs,
|
||||
stage_out=True)
|
||||
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
|
||||
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
|
||||
xla_consts = map(partial(xb.constant, c), consts)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -709,7 +709,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
|
||||
jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
out_pvs, out_consts = unzip2(out_pvals)
|
||||
|
||||
@ -862,7 +862,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
|
||||
num_partitions, out_parts,
|
||||
out_pvals, compiled.local_devices(),
|
||||
backend)
|
||||
return partial(execute_replicated, compiled, uses_outfeed, backend, handle_args,
|
||||
return partial(execute_replicated, compiled, backend, handle_args,
|
||||
handle_outs)
|
||||
|
||||
multi_host_supported_collectives: Set[core.Primitive] = set()
|
||||
@ -1105,9 +1105,7 @@ def partitioned_sharding_spec(num_partitions: int,
|
||||
replication_factors=[])
|
||||
|
||||
|
||||
def execute_replicated(compiled,
|
||||
uses_outfeed, backend, in_handler, out_handler, *args):
|
||||
xla.check_before_outfeed_execution(uses_outfeed)
|
||||
def execute_replicated(compiled, backend, in_handler, out_handler, *args):
|
||||
input_bufs = in_handler(args)
|
||||
out_bufs = compiled.execute_on_local_devices(list(input_bufs))
|
||||
return out_handler(out_bufs)
|
||||
|
@ -174,12 +174,12 @@ pytype_aval_mappings.update(
|
||||
# We can optionally set a Jaxpr rewriter that can be applied just before
|
||||
# compilation. This mechanism is used for compiling id_tap, we can
|
||||
# remove it once we bring the id_tap implementation into the core.
|
||||
outfeed_rewriter: Optional[Callable[[core.Jaxpr], Tuple[core.Jaxpr, bool]]] = None
|
||||
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]:
|
||||
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
|
||||
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
if outfeed_rewriter is not None:
|
||||
return outfeed_rewriter(jaxpr)
|
||||
else:
|
||||
return jaxpr, False
|
||||
return jaxpr
|
||||
|
||||
outfeed_primitives: Set[core.Primitive] = set()
|
||||
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
|
||||
@ -207,13 +207,6 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
|
||||
return True
|
||||
return False
|
||||
|
||||
# TODO(necula): remove this when we start the outfeed receiver automatically.
|
||||
can_execute_outfeed_computations: bool = False # Set by outfeed_receiver
|
||||
def check_before_outfeed_execution(uses_outfeed: bool):
|
||||
if uses_outfeed and not can_execute_outfeed_computations:
|
||||
raise ValueError("Attempting to execute compiled code using outfeed, "
|
||||
"but outfeed_receiver is not started.")
|
||||
|
||||
### op-by-op execution
|
||||
|
||||
def arg_spec(x):
|
||||
@ -607,7 +600,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr(
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True)
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
|
||||
jaxpr = apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
nreps = jaxpr_replicas(jaxpr)
|
||||
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
||||
@ -667,9 +660,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
compiled = backend.compile(built, compile_options=options)
|
||||
if nreps == 1:
|
||||
return partial(_execute_compiled, compiled, uses_outfeed, result_handlers)
|
||||
return partial(_execute_compiled, compiled, result_handlers)
|
||||
else:
|
||||
return partial(_execute_replicated, compiled, uses_outfeed, result_handlers)
|
||||
return partial(_execute_replicated, compiled, result_handlers)
|
||||
|
||||
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
|
||||
"""Configures input/output "must" aliasing based on `donated_args`."""
|
||||
@ -767,18 +760,14 @@ def _pval_to_result_handler(device, pval):
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
handlers, *args):
|
||||
check_before_outfeed_execution(uses_outfeed)
|
||||
def _execute_compiled(compiled: XlaExecutable, handlers, *args):
|
||||
device, = compiled.local_devices()
|
||||
input_bufs = [device_put(x, device) for x in args if x is not token]
|
||||
out_bufs = compiled.execute(input_bufs)
|
||||
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
|
||||
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
|
||||
|
||||
def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
|
||||
handlers, *args):
|
||||
check_before_outfeed_execution(uses_outfeed)
|
||||
def _execute_replicated(compiled: XlaExecutable, handlers, *args):
|
||||
input_bufs = [
|
||||
[device_put(x, device) for x in args if x is not token]
|
||||
for device in compiled.local_devices()]
|
||||
|
@ -4,9 +4,14 @@ filterwarnings =
|
||||
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
|
||||
ignore:Explicitly requested dtype.*is not available.*:UserWarning
|
||||
ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
|
||||
ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
|
||||
# The rest are for experimental/jax_to_tf
|
||||
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
|
||||
ignore:can't resolve package from __spec__ or __package__:ImportWarning
|
||||
ignore:Using or importing the ABCs.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst"
|
||||
addopts = --doctest-glob="*.rst" --dist=loadfile
|
||||
# --dist=loadfile ensure that all the tests in one file are sent to the same runner. This is useful
|
||||
# for host_callback_test which start and then stop on teardown the C++ outfeed receiver
|
||||
# runtime. If we do not stop the receiver, other tests that use outfeed are going to fail.
|
||||
|
||||
|
@ -16,11 +16,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from functools import partial
|
||||
import functools
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Sequence
|
||||
from unittest import SkipTest
|
||||
|
||||
@ -34,27 +35,30 @@ from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.lib import xla_bridge
|
||||
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
def skip_if_jit_not_enabled():
|
||||
if os.getenv("JAX_ENABLE_JIT_PRINT", "false") == "false":
|
||||
raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.")
|
||||
|
||||
|
||||
def supported_dtypes():
|
||||
return sorted(jtu.supported_dtypes(), key=lambda x: np.dtype(x).name)
|
||||
|
||||
|
||||
class _TestingOutputStream(object):
|
||||
"""Use as `output_stream` for tests."""
|
||||
|
||||
def __init__(self):
|
||||
self._output = []
|
||||
self.testMethodName = None
|
||||
self.test_method_name = None
|
||||
|
||||
def write(self, what: str) -> None:
|
||||
print(f"output_stream[{self.testMethodName}]: {what}", end="")
|
||||
print(f"output_stream[{self.test_method_name}]: {what}", end="")
|
||||
self._output.append(what)
|
||||
|
||||
@property
|
||||
@ -80,14 +84,16 @@ def fun1(a):
|
||||
def fun1_equiv(a): # Numerical equivalent of fun`
|
||||
return (a * 2.)**2
|
||||
|
||||
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str):
|
||||
|
||||
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase,
|
||||
expected: str, what: str):
|
||||
"""A variant that preprocesses the string to eliminate non-determinism in
|
||||
floating point values, and several uninteresting id_tap primitive params."""
|
||||
floating point values, and several uninteresting id_tap primitive params.
|
||||
"""
|
||||
# Sometimes we get floating points in the output; we round them
|
||||
def repl_floats(match_group):
|
||||
matched = match_group.group(0)
|
||||
if matched == ".": return matched
|
||||
# TODO: why can't we use here np.around?
|
||||
x = np.around(float(matched), decimals=2)
|
||||
return f"{x:.2f}"
|
||||
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
|
||||
@ -98,23 +104,29 @@ def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str)
|
||||
def repl_func(match_group):
|
||||
matched = match_group.group(0)
|
||||
if "function _print_consumer" in matched:
|
||||
return "func=_print"
|
||||
return "tap_func_=_print"
|
||||
else:
|
||||
return "..."
|
||||
what = re.sub(r"func=(.*)", repl_func, what)
|
||||
what = re.sub(r"tap_func_=(.*)", repl_func, what)
|
||||
tst.assertMultiLineStrippedEqual(expected, what)
|
||||
|
||||
|
||||
class HostCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
testing_stream.reset()
|
||||
testing_stream.testMethodName = self._testMethodName
|
||||
testing_stream.test_method_name = self._testMethodName
|
||||
self.old_flags = os.getenv("XLA_FLAGS", "")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if os.getenv("XLA_FLAGS") != self.old_flags:
|
||||
os.environ["XLA_FLAGS"] = self.old_flags
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
hcb.barrier_wait()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
hcb.stop_outfeed_receiver()
|
||||
|
||||
def helper_set_devices(self, nr_devices):
|
||||
flags_str = os.getenv("XLA_FLAGS", "")
|
||||
@ -135,8 +147,8 @@ class HostCallbackTest(jtu.JaxTestCase):
|
||||
# TODO: renable jaxpr golden tests when changing host_callback
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.)))
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
@ -150,8 +162,9 @@ what: y * 3
|
||||
return x1 + y1
|
||||
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.)))
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
hcb.barrier_wait()
|
||||
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
[ 6.00
|
||||
9.00 ]""", testing_stream.output)
|
||||
@ -162,8 +175,8 @@ what: y * 3
|
||||
res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
|
||||
return res["a"] + res["b"]
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
self.assertEqual(3. * (2. + 3.), func2(3.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ a=6.00
|
||||
b=9.00 }""", testing_stream.output)
|
||||
@ -175,8 +188,8 @@ what: y * 3
|
||||
output_stream=testing_stream)
|
||||
return x1
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertEqual(3. * 4., func2(3.))
|
||||
self.assertEqual(3. * 4., func2(3.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
[ 6.00
|
||||
9.00 ]""", testing_stream.output)
|
||||
@ -194,8 +207,8 @@ what: y * 3
|
||||
return x3
|
||||
|
||||
with self.assertRaises(hcb.TapFunctionException):
|
||||
with hcb.outfeed_receiver():
|
||||
_ = func(0)
|
||||
func(0)
|
||||
hcb.barrier_wait()
|
||||
|
||||
# We should have received everything before the error
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
@ -208,11 +221,8 @@ what: x3
|
||||
def test_jit_simple(self):
|
||||
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
|
||||
2. * x, what="here", output_stream=testing_stream))
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = jit_fun1(5.)
|
||||
|
||||
self.assertAllClose(6. * 5., res)
|
||||
self.assertAllClose(6. * 5., jit_fun1(5.))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: here
|
||||
10.00""", testing_stream.output)
|
||||
@ -224,8 +234,8 @@ what: here
|
||||
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5)))
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertAllClose(5, api.jit(func)(5))
|
||||
self.assertAllClose(5, api.jit(func)(5))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
42""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
@ -239,9 +249,9 @@ what: here
|
||||
api.make_jaxpr(func)(1))
|
||||
logging.info("%s: %s", self._testMethodName,
|
||||
api.xla_computation(func)(1).as_hlo_text())
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
hcb.barrier_wait()
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -256,10 +266,9 @@ where: 2
|
||||
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
||||
return x2
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertEqual(11, api.jit(func)(10))
|
||||
|
||||
self.assertEqual(2, api.jit(func)(1))
|
||||
self.assertEqual(11, api.jit(func)(10))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -280,8 +289,8 @@ where: 2
|
||||
x3 = api.jit(func_nested)(x1)
|
||||
return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(3, api.jit(func)(1))
|
||||
self.assertEqual(3, api.jit(func)(1))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -300,9 +309,9 @@ where: 3
|
||||
x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream)
|
||||
return x2
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
for d in devices:
|
||||
self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
|
||||
for d in devices:
|
||||
self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
|
||||
hcb.barrier_wait()
|
||||
logging.info(f"{self._testMethodName}: found output {testing_stream.output}")
|
||||
self.assertEqual(len(devices), len(re.findall(r"111", testing_stream.output)))
|
||||
self.assertEqual(len(devices), len(re.findall(r"112", testing_stream.output)))
|
||||
@ -332,13 +341,12 @@ where: 3
|
||||
self.assertEqual(func(5, what), a)
|
||||
|
||||
transform = api.jit if with_jit else lambda f: f
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
for what in ("pair_1_x", "pair_x_2x", "dict"):
|
||||
self.assertEqual(func(10, what),
|
||||
transform(lambda x: hcb.id_tap(tap_func, func(x, what),
|
||||
result=func(x * 2, what),
|
||||
what=what))(5))
|
||||
# Wait for receivers to be done
|
||||
for what in ("pair_1_x", "pair_x_2x", "dict"):
|
||||
self.assertEqual(func(10, what),
|
||||
transform(lambda x: hcb.id_tap(tap_func, func(x, what),
|
||||
result=func(x * 2, what),
|
||||
what=what))(5))
|
||||
hcb.barrier_wait() # Wait for receivers to be done
|
||||
self.assertEqual(3, tap_count)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -354,15 +362,17 @@ where: 3
|
||||
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
||||
|
||||
x4 = lax.cond(x % 2 == 0,
|
||||
lambda x: hcb.id_print(x, where="cond_t", output_stream=testing_stream),
|
||||
lambda x: hcb.id_print(-1, where="cond_f", result=x, output_stream=testing_stream),
|
||||
lambda x: hcb.id_print(x, where="cond_t",
|
||||
output_stream=testing_stream),
|
||||
lambda x: hcb.id_print(-1, where="cond_f", result=x,
|
||||
output_stream=testing_stream),
|
||||
x2 + 1)
|
||||
x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream)
|
||||
return x5
|
||||
|
||||
transform = api.jit if with_jit else lambda f: f
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(4, transform(func)(1))
|
||||
self.assertEqual(4, transform(func)(1))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -376,9 +386,8 @@ where: end
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
dict(
|
||||
testcase_name=f"_with_jit_{with_jit}",
|
||||
with_jit=with_jit)
|
||||
dict(testcase_name=f"_with_jit_{with_jit}",
|
||||
with_jit=with_jit)
|
||||
for with_jit in [True, False]))
|
||||
def test_while_cond(self, with_jit=False):
|
||||
def func(x):
|
||||
@ -398,8 +407,8 @@ where: end
|
||||
return res
|
||||
|
||||
transform = api.jit if with_jit else lambda f: f
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(4, transform(func)(1))
|
||||
self.assertEqual(4, transform(func)(1))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -434,8 +443,8 @@ where: end
|
||||
res = hcb.id_print(x10, where="3", output_stream=testing_stream)
|
||||
return res
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertEqual(3, api.jit(func)(1))
|
||||
self.assertEqual(3, api.jit(func)(1))
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self,
|
||||
"""
|
||||
where: w_p
|
||||
@ -475,11 +484,11 @@ where: 3
|
||||
res = hcb.id_print(x10, where="10", output_stream=testing_stream)
|
||||
return res
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
if with_jit:
|
||||
func = api.jit(func)
|
||||
res = func(1)
|
||||
self.assertAllClose(jnp.array([1, 2, 3]), res)
|
||||
if with_jit:
|
||||
func = api.jit(func)
|
||||
res = func(1)
|
||||
self.assertAllClose(jnp.array([1, 2, 3]), res)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
where: 1
|
||||
1
|
||||
@ -529,25 +538,23 @@ where: 10
|
||||
xs,
|
||||
a_new_test="************",
|
||||
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
_ = jit_fun1(args)
|
||||
# self.assertAllClose(args, res)
|
||||
|
||||
res = jit_fun1(args)
|
||||
self.assertAllClose(args, res)
|
||||
|
||||
def test_jit_large(self):
|
||||
arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
api.jit(hcb.id_print)(arg)
|
||||
api.jit(hcb.id_print)(arg)
|
||||
|
||||
def test_jit_several_together(self):
|
||||
arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32))
|
||||
api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32))
|
||||
|
||||
def test_jit_interleaving(self):
|
||||
# Several jit's without data dependencies; they may interfere
|
||||
count = 0 # Count tap invocations
|
||||
nr_arrays = 5
|
||||
def tap_func(arg, **kwargs):
|
||||
def tap_func(arg, **_):
|
||||
nonlocal count
|
||||
assert len(arg) == nr_arrays
|
||||
count += 1
|
||||
@ -556,12 +563,13 @@ where: 10
|
||||
for i in range(count):
|
||||
x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)], i=i)[-1]
|
||||
return x
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
x = jnp.array(1, dtype=np.int32)
|
||||
res = 0
|
||||
for i in range(10):
|
||||
# No dependencies between the jit invocations
|
||||
res += api.jit(lambda x: func(x, 10))(x)
|
||||
|
||||
x = jnp.array(1, dtype=np.int32)
|
||||
res = 0
|
||||
for _ in range(10):
|
||||
# No dependencies between the jit invocations
|
||||
res += api.jit(lambda x: func(x, 10))(x)
|
||||
hcb.barrier_wait()
|
||||
self.assertEqual(100, count)
|
||||
|
||||
def test_jit_tap_exception(self):
|
||||
@ -574,9 +582,10 @@ where: 10
|
||||
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
||||
return x3
|
||||
|
||||
res = api.jit(func)(0) # No error yet
|
||||
with self.assertRaises(hcb.TapFunctionException):
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = api.jit(func)(0)
|
||||
hcb.barrier_wait()
|
||||
|
||||
# Even though the receiver thread raised, the main thread should still
|
||||
# return 3.
|
||||
self.assertEqual(3, res)
|
||||
@ -588,49 +597,20 @@ what: x3
|
||||
3""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_jit_unknown_tap(self):
|
||||
# Simulate an unknown tap function
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
|
||||
x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
|
||||
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
||||
return x3
|
||||
|
||||
with self.assertRaises(hcb.TapFunctionException):
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = api.jit(func)(0)
|
||||
# Even though the receiver thread raised, the main thread should still
|
||||
# return 3.
|
||||
self.assertEqual(3, res)
|
||||
# We should have received all others
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x1
|
||||
1
|
||||
what: x3
|
||||
3""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
# On CPU and GPU the device code blocks
|
||||
# On GPU it seems that there is a 5 min timeout?
|
||||
# On TPU the client does not block, but messes up the rest somehow
|
||||
@jtu.skip_on_devices("cpu", "gpu", "tpu")
|
||||
def test_jit_receiver_ends_prematurely(self):
|
||||
# Simulate an unknown tap function
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
|
||||
x2 = hcb.id_tap(hcb._end_consumer, result=x1 + 1) # Will end the consumer loop
|
||||
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
||||
return x3
|
||||
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
_ = api.jit(func)(0)
|
||||
|
||||
assert False # It seems that the previous jit blocks above
|
||||
|
||||
def test_jit_error_no_consumer(self):
|
||||
# Check for errors if starting jit without a consumer active
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
|
||||
api.jit(lambda x: hcb.id_print(x))(0)
|
||||
def test_jit_nested_cond_no_print(self):
|
||||
"""A nested conditional, without any prints"""
|
||||
raise SkipTest("skip this")
|
||||
@api.jit
|
||||
def cfun(x):
|
||||
return lax.cond(
|
||||
lax.lt(x, 2),
|
||||
lambda x: x,
|
||||
lambda x: lax.cond(x < 5,
|
||||
3, lambda x: x,
|
||||
4, lambda y: y),
|
||||
x)
|
||||
print(self._testMethodName, api.xla_computation(cfun)(1).as_hlo_text())
|
||||
cfun(1)
|
||||
|
||||
def test_while(self):
|
||||
"""Executing while, even without JIT uses compiled code"""
|
||||
@ -641,8 +621,8 @@ what: x3
|
||||
lambda c: c[1] < 5,
|
||||
lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
|
||||
(x, 1))
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
func(y)
|
||||
func(y)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
1
|
||||
2
|
||||
@ -650,27 +630,13 @@ what: x3
|
||||
4""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_while_error_no_receiver(self):
|
||||
"""Executing while needs the receiver"""
|
||||
y = jnp.ones(5) # captured const
|
||||
def func(x):
|
||||
return lax.while_loop(
|
||||
lambda c: c[1] < 5,
|
||||
lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
|
||||
(x, 1))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, ".*outfeed_receiver.*not started"):
|
||||
func(y).block_until_ready()
|
||||
|
||||
|
||||
def test_jvp(self):
|
||||
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
|
||||
#assertMultiLineStrippedEqual(self, "",
|
||||
# str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))
|
||||
with hcb.outfeed_receiver():
|
||||
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
|
||||
#assertMultiLineStrippedEqual(self, "")
|
||||
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
|
||||
self.assertAllClose(100., res_primals, check_dtypes=False)
|
||||
self.assertAllClose(4., res_tangents, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
@ -685,23 +651,24 @@ transforms: ({'name': 'jvp'},) what: y * 3
|
||||
def test_grad_primal_unused(self):
|
||||
# The output of id_print is not needed for backwards pass
|
||||
def func(x):
|
||||
return 2. * hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
|
||||
return 2. * hcb.id_print(x * 3., what="x * 3",
|
||||
output_stream=testing_stream)
|
||||
|
||||
grad_func = api.grad(func)
|
||||
with hcb.outfeed_receiver():
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
jaxpr = str(api.make_jaxpr(grad_func)(5.))
|
||||
# Just making the Jaxpr invokes the id_print once
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
{ lambda ; a.
|
||||
let
|
||||
in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
|
||||
|
||||
# Just making the Jaxpr invokes the id_print once
|
||||
in (6.00,) }""", jaxpr)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
|
||||
2.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
hcb.barrier_wait()
|
||||
|
||||
self.assertAllClose(6., res_grad, check_dtypes=False)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
@ -714,13 +681,14 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
|
||||
def test_grad_simple(self):
|
||||
def func(x):
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
||||
return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)
|
||||
return x * hcb.id_print(y * 3., what="y * 3",
|
||||
output_stream=testing_stream)
|
||||
grad_func = api.grad(func)
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))
|
||||
|
||||
with hcb.outfeed_receiver():
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
@ -738,18 +706,19 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
|
||||
return x * (y * 3.)
|
||||
|
||||
grad_func = api.grad(api.grad(func))
|
||||
with hcb.outfeed_receiver():
|
||||
_ = api.make_jaxpr(grad_func)(5.)
|
||||
# Just making the Jaxpr invokes the id_print twiceonce
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
# Just making the Jaxpr invokes the id_print twice
|
||||
_ = api.make_jaxpr(grad_func)(5.)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
|
||||
3.00
|
||||
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
|
||||
2.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
testing_stream.reset()
|
||||
res_grad = grad_func(jnp.float32(5.))
|
||||
|
||||
self.assertAllClose(12., res_grad, check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: x * 2
|
||||
10.00
|
||||
@ -765,8 +734,8 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
|
||||
vmap_fun1 = api.vmap(fun1)
|
||||
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
|
||||
with hcb.outfeed_receiver():
|
||||
_ = vmap_fun1(vargs)
|
||||
vmap_fun1(vargs)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
|
||||
[ 8.00 10.00]
|
||||
@ -784,8 +753,8 @@ transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
|
||||
vmap_func = api.vmap(func)
|
||||
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs)))
|
||||
with hcb.outfeed_receiver():
|
||||
_ = vmap_func(vargs)
|
||||
_ = vmap_func(vargs)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
|
||||
[ 3.00
|
||||
@ -804,8 +773,8 @@ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
|
||||
xv = jnp.arange(5, dtype=np.int32)
|
||||
yv = jnp.arange(3, dtype=np.int32)
|
||||
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv)))
|
||||
with hcb.outfeed_receiver():
|
||||
_ = sum_all(xv, yv)
|
||||
_ = sum_all(xv, yv)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
|
||||
[[0 1 2 3 4]
|
||||
@ -827,9 +796,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dim
|
||||
return res
|
||||
|
||||
inputs = np.arange(5, dtype=np.int32)
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
|
||||
check_dtypes=False)
|
||||
hcb.barrier_wait()
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
|
||||
[0 1 2 3 4]
|
||||
@ -856,9 +825,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
|
||||
return res
|
||||
|
||||
inputs = np.arange(5, dtype=np.int32)
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
|
||||
check_dtypes=False)
|
||||
res = api.jit(api.vmap(func))(inputs)
|
||||
hcb.barrier_wait()
|
||||
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
|
||||
[0 1 2 3 4]
|
||||
@ -880,21 +849,15 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
|
||||
vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
|
||||
|
||||
pmap_fun1 = api.pmap(fun1, axis_name="i")
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
res = pmap_fun1(vargs)
|
||||
res = pmap_fun1(vargs)
|
||||
hcb.barrier_wait()
|
||||
expected_res = jnp.stack([fun1_equiv(2. + a) for a in range(api.local_device_count())])
|
||||
self.assertAllClose(expected_res, res, check_dtypes=False)
|
||||
|
||||
def test_pmap_error_no_receiver(self):
|
||||
# Check for errors if starting jit without a consumer active
|
||||
vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
|
||||
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
|
||||
api.pmap(lambda x: hcb.id_print(x))(vargs)
|
||||
|
||||
def test_mask(self):
|
||||
# TODO(necula)
|
||||
raise SkipTest("masking has regressed")
|
||||
@partial(api.mask, in_shapes=['n'], out_shape='')
|
||||
@functools.partial(api.mask, in_shapes=['n'], out_shape='')
|
||||
def padded_sum(x):
|
||||
return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
|
||||
args = [jnp.arange(4)], dict(n=np.int64(2))
|
||||
@ -916,21 +879,133 @@ logical_shapes: [(2,)] transforms: ('mask',) what: x
|
||||
""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
def test_outfeed_receiver(self):
|
||||
"""Test the deprecated outfeed_receiver"""
|
||||
with hcb.outfeed_receiver():
|
||||
self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)
|
||||
assertMultiLineStrippedEqual(self, """
|
||||
what: a * 2
|
||||
10.00
|
||||
what: y * 3
|
||||
30.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
|
||||
def test_callback_delay(self):
|
||||
hcb.callback_extra = lambda dev: time.sleep(1)
|
||||
|
||||
def func(x):
|
||||
for i in range(5):
|
||||
x = hcb.id_print(x * i, what="x times i")
|
||||
return x
|
||||
|
||||
api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
||||
|
||||
def test_callback_delay_barrier(self):
|
||||
hcb.callback_extra = lambda dev: time.sleep(2)
|
||||
|
||||
def func(x):
|
||||
for i in range(1, 4):
|
||||
x = hcb.id_print(x * i, what="x times i", output_stream=testing_stream)
|
||||
return x
|
||||
|
||||
api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
||||
# Wait for the results
|
||||
hcb.barrier_wait()
|
||||
expected = """
|
||||
what: x times i
|
||||
[[0. 1. 2.]
|
||||
[3. 4. 5.]]
|
||||
what: x times i
|
||||
[[ 0. 2. 4.]
|
||||
[ 6. 8. 10.]]
|
||||
what: x times i
|
||||
[[ 0. 6. 12.]
|
||||
[18. 24. 30.]]"""
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
testing_stream.reset()
|
||||
# Call again
|
||||
api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
||||
hcb.barrier_wait()
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
|
||||
def test_multiple_barriers(self):
|
||||
"""Call barrier_wait concurrently."""
|
||||
|
||||
def pause_tap(*args, **kwargs):
|
||||
logging.info("pause_tap waiting")
|
||||
time.sleep(2)
|
||||
logging.info("pause_tap done")
|
||||
|
||||
def long_run(x):
|
||||
return hcb.id_tap(pause_tap, x)
|
||||
|
||||
api.jit(long_run)(5.)
|
||||
|
||||
def try_barrier(idx):
|
||||
logging.info(f"Starting test barrier {idx}")
|
||||
hcb.barrier_wait()
|
||||
logging.info(f"Finished test barrier {idx}")
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
name=f"barrier_{idx}", target=try_barrier, args=(idx,))
|
||||
for idx in range(3)
|
||||
]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
def test_error_bad_consumer_id(self):
|
||||
"""Try to use reserved consumer ID 0.
|
||||
|
||||
Check that we get the proper error from the runtime."""
|
||||
comp = xla_bridge.make_computation_builder(self._testMethodName)
|
||||
token = hcb.xops.CreateToken(comp)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Consumer ID cannot be a reserved value: 0"):
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
comp, token, 0,
|
||||
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
|
||||
|
||||
def test_error_different_shapes(self):
|
||||
"""Try to register different shapes for the same consumer ID."""
|
||||
comp = xla_bridge.make_computation_builder(self._testMethodName)
|
||||
token = hcb.xops.CreateToken(comp)
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
comp, token, 123,
|
||||
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, ".*does not match previous shape element_type.*"):
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
comp, token, 123,
|
||||
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))])
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, ".*does not match previous shape element_type.*"):
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
comp, token, 123,
|
||||
[xla_bridge.constant(comp, np.zeros((2,), dtype=np.float32))])
|
||||
|
||||
|
||||
class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
|
||||
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
|
||||
has_input_token=True, has_output_token=True):
|
||||
"""Check that the rewrite of func(*args) matches expected."""
|
||||
_ = api.make_jaxpr(func)(*args)
|
||||
jaxpr = api.make_jaxpr(func)(*args)
|
||||
# TODO: re-enable when we change the host_callback rewriter
|
||||
#assertMultiLineStrippedEqual(self, expected,
|
||||
# str(hcb._rewrite_typed_jaxpr(jaxpr, has_input_token, has_output_token)[0]))
|
||||
#rewritten = hcb._rewrite_typed_jaxpr(jaxpr,
|
||||
# has_input_token, has_output_token)
|
||||
#assertMultiLineStrippedEqual(self, expected, str(rewritten))
|
||||
del jaxpr
|
||||
|
||||
def test_no_outfeed(self):
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a.
|
||||
let b = mul a a
|
||||
c = add a b
|
||||
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False)
|
||||
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False,
|
||||
has_output_token=False)
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a d.
|
||||
let b = mul a a
|
||||
@ -946,9 +1021,11 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
self.assertRewrite("""
|
||||
{ lambda ; a d.
|
||||
let b = add a a
|
||||
c e = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
] b d
|
||||
c e = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] b d
|
||||
in (c, e) }""", lambda x: hcb.id_print(x + x), [0])
|
||||
|
||||
def test_cond(self):
|
||||
@ -962,8 +1039,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
d = convert_element_type[ new_dtype=int32
|
||||
old_dtype=bool ] c
|
||||
g h j = cond[ branches=( { lambda ; f_ e a b c g.
|
||||
let d h = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
let d h = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] c g
|
||||
in (d, e, h) }
|
||||
{ lambda ; d g_ a b c h.
|
||||
@ -980,15 +1059,13 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
|
||||
lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
|
||||
(x, np.float32(1.)))
|
||||
# TODO: we should not need to start a receiver here!!! I believe this is
|
||||
# because of the partial evaluation of while, which calls impl, which
|
||||
# uses JIT.
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertRewrite("""
|
||||
self.assertRewrite("""
|
||||
{ lambda b c ; a f.
|
||||
let d e g = while[ body_jaxpr={ lambda ; c a b f.
|
||||
let d g = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
let d g = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] b f
|
||||
e = add d 1.00
|
||||
in (c, e, g) }
|
||||
@ -1011,43 +1088,47 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
lambda c: (ct_body, hcb.id_print(c[1]) + 1),
|
||||
(x, 1))
|
||||
|
||||
# TODO: we should not need to start a receiver here!!! I believe this is
|
||||
# because of the partial evaluation of while, which calls impl, which
|
||||
# uses JIT.
|
||||
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
|
||||
self.assertRewrite("""
|
||||
{ lambda b c ; a f.
|
||||
let h i = xla_call[ call_jaxpr={ lambda ; c a b g.
|
||||
let d e h = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
] c b g
|
||||
f = lt e 5
|
||||
in (f, h) }
|
||||
name=cond_before ] b a 1 f
|
||||
y d e g = while[ body_jaxpr={ lambda ; n o p q r s.
|
||||
let t u v = xla_call[ call_jaxpr={ lambda ; c a b f.
|
||||
let d g = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
] b f
|
||||
e = add d 1
|
||||
in (c, e, g) }
|
||||
name=body ] o q r s
|
||||
w x = xla_call[ call_jaxpr={ lambda ; c a b g.
|
||||
let d e h = id_tap[ arg_treedef=*
|
||||
func=_print
|
||||
nr_untapped=1
|
||||
] c b g
|
||||
f = lt e 5
|
||||
in (f, h) }
|
||||
name=cond_body ] n t u v
|
||||
in (w, t, u, x) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; j k l m.
|
||||
let
|
||||
in (j,) }
|
||||
cond_nconsts=0 ] b c h a 1 i
|
||||
in (d, 5, g) }""", func, [ct_body])
|
||||
self.assertRewrite("""
|
||||
{ lambda b c ; a e.
|
||||
let g h = xla_call[ call_jaxpr={ lambda ; c a b f.
|
||||
let _ d g = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] c b f
|
||||
e = lt d 5
|
||||
in (e, g) }
|
||||
donated_invars=(False, False, False, False)
|
||||
name=cond_before ] b a 1 e
|
||||
x d _ f =
|
||||
while[ body_jaxpr={ lambda ; m n o p q r.
|
||||
let s t u = xla_call[ call_jaxpr={ lambda ; c a b f.
|
||||
let d g = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] b f
|
||||
e = add d 1
|
||||
in (c, e, g) }
|
||||
donated_invars=(False, False, False, False, False, False, False)
|
||||
name=body ] n p q r
|
||||
v w = xla_call[ call_jaxpr={ lambda ; c a b f.
|
||||
let _ d g = id_tap[ arg_treedef_=*
|
||||
has_token_=True
|
||||
nr_tapped_args_=1
|
||||
tap_func_=_print
|
||||
] c b f
|
||||
e = lt d 5
|
||||
in (e, g) }
|
||||
donated_invars=(False, False, False, False, False, False)
|
||||
name=cond_body ] m s t u
|
||||
in (v, s, t, w) }
|
||||
body_nconsts=2
|
||||
cond_jaxpr={ lambda ; i j k l.
|
||||
let
|
||||
in (i,) }
|
||||
cond_nconsts=0 ] b c g a 1 h
|
||||
in (d, 5, f) }""", func, [ct_body])
|
||||
|
||||
def test_scan(self):
|
||||
y = jnp.ones(5) # captured const
|
||||
@ -1055,16 +1136,19 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
|
||||
return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
|
||||
self.assertRewrite("""
|
||||
{ lambda b ; a f.
|
||||
let c d g e = scan[ jaxpr={ lambda ; f a b g c.
|
||||
let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
|
||||
func=_print
|
||||
] a b g
|
||||
in (d, e, h, f) }
|
||||
length=5
|
||||
linear=(False, False, False, False, False)
|
||||
num_carry=3
|
||||
num_consts=1
|
||||
reverse=False ] b 1 2 f a
|
||||
let c d g e =
|
||||
scan[ jaxpr={ lambda ; f a b g c.
|
||||
let d e h = id_tap[ arg_treedef_=PyTreeDef(tuple, [*,*])
|
||||
has_token_=True
|
||||
nr_tapped_args_=2
|
||||
tap_func_=_print
|
||||
] a b g
|
||||
in (d, e, h, f) }
|
||||
length=5
|
||||
linear=(False, False, False, False, False)
|
||||
num_carry=3
|
||||
num_consts=1
|
||||
reverse=False ] b 1 2 f a
|
||||
in (c, d, e, g) }""", func, [y])
|
||||
|
||||
|
||||
|
@ -19,6 +19,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
from jax import lax, numpy as np
|
||||
from jax.config import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.lib import xla_client
|
||||
import jax.test_util as jtu
|
||||
import numpy as onp
|
||||
@ -47,6 +48,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(x), x + y + z)
|
||||
|
||||
def testInfeedThenOutfeed(self):
|
||||
hcb.stop_outfeed_receiver()
|
||||
@jax.jit
|
||||
def f(x):
|
||||
token = lax.create_token(x)
|
||||
@ -67,6 +69,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(out, y + onp.float32(1))
|
||||
|
||||
def testInfeedThenOutfeedInALoop(self):
|
||||
hcb.stop_outfeed_receiver()
|
||||
def doubler(_, token):
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.ShapedArray((3, 4), np.float32))
|
||||
|
Loading…
x
Reference in New Issue
Block a user