Refactored host_callback to use the C++ runtime. (#3644)

* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver when it finishes and infeed_test stop it when it starts,
  and (b) telling pytest-xdist to run all the tests from one file in
  a single worker.
This commit is contained in:
George Necula 2020-07-04 18:12:58 +03:00 committed by GitHub
parent 904f34a9ba
commit 4f3011f320
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 866 additions and 763 deletions

View File

@ -101,5 +101,5 @@ jobs:
pip install -r docs/requirements.txt
- name: Test documentation
run: |
pytest docs
pytest --doctest-modules jax/api.py
pytest -n 1 docs
pytest -n 1 --doctest-modules jax/api.py

View File

@ -1,6 +1,5 @@
flake8
jaxlib==0.1.51
msgpack
mypy==0.770
pytest-benchmark
pytest-xdist

View File

@ -5,12 +5,11 @@ ipykernel
nbsphinx
sphinx-autodoc-typehints
myst-parser[sphinx]
# For host_callback.py
msgpack
# The next packages are for notebooks
matplotlib
sklearn
# For CI tests.
pytest
pytest-xdist
# Must install jax itself for notebook execution to work
.

View File

@ -354,7 +354,7 @@ def xla_computation(fun: Callable,
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals,
instantiate=instantiate_const_outputs,
stage_out=True)
jaxpr, _ = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
c = xb.make_computation_builder('xla_computation_{}'.format(fun_name))
xla_consts = map(partial(xb.constant, c), consts)

File diff suppressed because it is too large Load Diff

View File

@ -709,7 +709,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True)
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
jaxpr, uses_outfeed = xla.apply_outfeed_rewriter(jaxpr)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_pvs, out_consts = unzip2(out_pvals)
@ -862,7 +862,7 @@ def parallel_callable(fun, backend, axis_name, axis_size, global_axis_size,
num_partitions, out_parts,
out_pvals, compiled.local_devices(),
backend)
return partial(execute_replicated, compiled, uses_outfeed, backend, handle_args,
return partial(execute_replicated, compiled, backend, handle_args,
handle_outs)
multi_host_supported_collectives: Set[core.Primitive] = set()
@ -1105,9 +1105,7 @@ def partitioned_sharding_spec(num_partitions: int,
replication_factors=[])
def execute_replicated(compiled,
uses_outfeed, backend, in_handler, out_handler, *args):
xla.check_before_outfeed_execution(uses_outfeed)
def execute_replicated(compiled, backend, in_handler, out_handler, *args):
input_bufs = in_handler(args)
out_bufs = compiled.execute_on_local_devices(list(input_bufs))
return out_handler(out_bufs)

View File

@ -174,12 +174,12 @@ pytype_aval_mappings.update(
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], Tuple[core.Jaxpr, bool]]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, bool]:
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
if outfeed_rewriter is not None:
return outfeed_rewriter(jaxpr)
else:
return jaxpr, False
return jaxpr
outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
@ -207,13 +207,6 @@ def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
return True
return False
# TODO(necula): remove this when we start the outfeed receiver automatically.
can_execute_outfeed_computations: bool = False # Set by outfeed_receiver
def check_before_outfeed_execution(uses_outfeed: bool):
if uses_outfeed and not can_execute_outfeed_computations:
raise ValueError("Attempting to execute compiled code using outfeed, "
"but outfeed_receiver is not started.")
### op-by-op execution
def arg_spec(x):
@ -607,7 +600,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
jaxpr, pvals, consts = pe.trace_to_jaxpr(
fun, pvals, instantiate=False, stage_out=True, bottom=True)
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
jaxpr, uses_outfeed = apply_outfeed_rewriter(jaxpr)
jaxpr = apply_outfeed_rewriter(jaxpr)
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
@ -667,9 +660,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
options.parameter_is_tupled_arguments = tuple_args
compiled = backend.compile(built, compile_options=options)
if nreps == 1:
return partial(_execute_compiled, compiled, uses_outfeed, result_handlers)
return partial(_execute_compiled, compiled, result_handlers)
else:
return partial(_execute_replicated, compiled, uses_outfeed, result_handlers)
return partial(_execute_replicated, compiled, result_handlers)
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
"""Configures input/output "must" aliasing based on `donated_args`."""
@ -767,18 +760,14 @@ def _pval_to_result_handler(device, pval):
else:
return aval_to_result_handler(device, pv)
def _execute_compiled(compiled: XlaExecutable, uses_outfeed: bool,
handlers, *args):
check_before_outfeed_execution(uses_outfeed)
def _execute_compiled(compiled: XlaExecutable, handlers, *args):
device, = compiled.local_devices()
input_bufs = [device_put(x, device) for x in args if x is not token]
out_bufs = compiled.execute(input_bufs)
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]
def _execute_replicated(compiled: XlaExecutable, uses_outfeed: bool,
handlers, *args):
check_before_outfeed_execution(uses_outfeed)
def _execute_replicated(compiled: XlaExecutable, handlers, *args):
input_bufs = [
[device_put(x, device) for x in args if x is not token]
for device in compiled.local_devices()]

View File

@ -4,9 +4,14 @@ filterwarnings =
ignore:No GPU/TPU found, falling back to CPU.:UserWarning
ignore:Explicitly requested dtype.*is not available.*:UserWarning
ignore:jax.experimental.vectorize is deprecated.*:FutureWarning
ignore:outfeed_receiver is unnecessary and deprecated:DeprecationWarning
# The rest are for experimental/jax_to_tf
ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
ignore:can't resolve package from __spec__ or __package__:ImportWarning
ignore:Using or importing the ABCs.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst"
addopts = --doctest-glob="*.rst" --dist=loadfile
# --dist=loadfile ensures that all the tests in one file are sent to the same runner. This is useful
# for host_callback_test, which starts the C++ outfeed receiver runtime and then stops it on
# teardown. If we do not stop the receiver, other tests that use outfeed are going to fail.

View File

@ -16,11 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import functools
import logging
import numpy as np
import os
import re
import threading
import time
from typing import Callable, Sequence
from unittest import SkipTest
@ -34,27 +35,30 @@ from jax import test_util as jtu
from jax.config import config
from jax.experimental import host_callback as hcb
from jax.lib import xla_bridge
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def skip_if_jit_not_enabled():
  """Skip the calling test unless JAX_ENABLE_JIT_PRINT is set to something other than "false".

  Raises:
    SkipTest: when the environment variable is unset or equal to "false".
  """
  jit_print_setting = os.environ.get("JAX_ENABLE_JIT_PRINT", "false")
  if jit_print_setting != "false":
    return
  raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.")
def supported_dtypes():
  """Return the dtypes from jtu.supported_dtypes(), sorted by their numpy dtype name."""
  def dtype_name(dtype):
    return np.dtype(dtype).name

  return sorted(jtu.supported_dtypes(), key=dtype_name)
class _TestingOutputStream(object):
"""Use as `output_stream` for tests."""
def __init__(self):
self._output = []
self.testMethodName = None
self.test_method_name = None
def write(self, what: str) -> None:
print(f"output_stream[{self.testMethodName}]: {what}", end="")
print(f"output_stream[{self.test_method_name}]: {what}", end="")
self._output.append(what)
@property
@ -80,14 +84,16 @@ def fun1(a):
def fun1_equiv(a):
  """Numerical equivalent of fun1: the square of twice the argument."""
  doubled = a * 2.
  return doubled ** 2
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str):
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase,
expected: str, what: str):
"""A variant that preprocesses the string to eliminate non-determinism in
floating point values, and several uninteresting id_tap primitive params."""
floating point values, and several uninteresting id_tap primitive params.
"""
# Sometimes we get floating points in the output; we round them
def repl_floats(match_group):
matched = match_group.group(0)
if matched == ".": return matched
# TODO: why can't we use here np.around?
x = np.around(float(matched), decimals=2)
return f"{x:.2f}"
what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
@ -98,23 +104,29 @@ def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase, expected: str, what: str)
def repl_func(match_group):
matched = match_group.group(0)
if "function _print_consumer" in matched:
return "func=_print"
return "tap_func_=_print"
else:
return "..."
what = re.sub(r"func=(.*)", repl_func, what)
what = re.sub(r"tap_func_=(.*)", repl_func, what)
tst.assertMultiLineStrippedEqual(expected, what)
class HostCallbackTest(jtu.JaxTestCase):
def setUp(self):
testing_stream.reset()
testing_stream.testMethodName = self._testMethodName
testing_stream.test_method_name = self._testMethodName
self.old_flags = os.getenv("XLA_FLAGS", "")
def tearDown(self) -> None:
"""Restore XLA_FLAGS if the test changed it, then call hcb.barrier_wait()."""
# self.old_flags holds the XLA_FLAGS value captured in setUp ("" when unset).
if os.getenv("XLA_FLAGS") != self.old_flags:
os.environ["XLA_FLAGS"] = self.old_flags
xla_bridge.get_backend.cache_clear()
# NOTE(review): presumably this waits for outstanding taps so that their
# output does not spill into the next test -- confirm against barrier_wait.
hcb.barrier_wait()
@classmethod
def tearDownClass(cls):
# Class-level teardown: stop the outfeed receiver once this class is done.
# The pytest.ini comment in this change notes that other tests that use
# outfeed fail if the receiver is left running.
hcb.stop_outfeed_receiver()
def helper_set_devices(self, nr_devices):
flags_str = os.getenv("XLA_FLAGS", "")
@ -135,8 +147,8 @@ class HostCallbackTest(jtu.JaxTestCase):
# TODO: re-enable jaxpr golden tests when changing host_callback
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.)))
with hcb.outfeed_receiver():
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
self.assertAllClose((5. * 2.) ** 2, fun1(5.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
@ -150,8 +162,9 @@ what: y * 3
return x1 + y1
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.)))
with hcb.outfeed_receiver():
self.assertEqual(3. * (2. + 3.), func2(3.))
self.assertEqual(3. * (2. + 3.), func2(3.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
@ -162,8 +175,8 @@ what: y * 3
res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
return res["a"] + res["b"]
with hcb.outfeed_receiver():
self.assertEqual(3. * (2. + 3.), func2(3.))
self.assertEqual(3. * (2. + 3.), func2(3.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
{ a=6.00
b=9.00 }""", testing_stream.output)
@ -175,8 +188,8 @@ what: y * 3
output_stream=testing_stream)
return x1
with hcb.outfeed_receiver():
self.assertEqual(3. * 4., func2(3.))
self.assertEqual(3. * 4., func2(3.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
@ -194,8 +207,8 @@ what: y * 3
return x3
with self.assertRaises(hcb.TapFunctionException):
with hcb.outfeed_receiver():
_ = func(0)
func(0)
hcb.barrier_wait()
# We should have received everything before the error
assertMultiLineStrippedEqual(self, """
@ -208,11 +221,8 @@ what: x3
def test_jit_simple(self):
jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
2. * x, what="here", output_stream=testing_stream))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = jit_fun1(5.)
self.assertAllClose(6. * 5., res)
self.assertAllClose(6. * 5., jit_fun1(5.))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
@ -224,8 +234,8 @@ what: here
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5)))
with hcb.outfeed_receiver():
self.assertAllClose(5, api.jit(func)(5))
self.assertAllClose(5, api.jit(func)(5))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
testing_stream.reset()
@ -239,9 +249,9 @@ what: here
api.make_jaxpr(func)(1))
logging.info("%s: %s", self._testMethodName,
api.xla_computation(func)(1).as_hlo_text())
self.assertEqual(2, api.jit(func)(1))
hcb.barrier_wait()
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(2, api.jit(func)(1))
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -256,10 +266,9 @@ where: 2
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
return x2
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(2, api.jit(func)(1))
self.assertEqual(11, api.jit(func)(10))
self.assertEqual(2, api.jit(func)(1))
self.assertEqual(11, api.jit(func)(10))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -280,8 +289,8 @@ where: 2
x3 = api.jit(func_nested)(x1)
return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(3, api.jit(func)(1))
self.assertEqual(3, api.jit(func)(1))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -300,9 +309,9 @@ where: 3
x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream)
return x2
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
for d in devices:
self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
for d in devices:
self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
hcb.barrier_wait()
logging.info(f"{self._testMethodName}: found output {testing_stream.output}")
self.assertEqual(len(devices), len(re.findall(r"111", testing_stream.output)))
self.assertEqual(len(devices), len(re.findall(r"112", testing_stream.output)))
@ -332,13 +341,12 @@ where: 3
self.assertEqual(func(5, what), a)
transform = api.jit if with_jit else lambda f: f
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
for what in ("pair_1_x", "pair_x_2x", "dict"):
self.assertEqual(func(10, what),
transform(lambda x: hcb.id_tap(tap_func, func(x, what),
result=func(x * 2, what),
what=what))(5))
# Wait for receivers to be done
for what in ("pair_1_x", "pair_x_2x", "dict"):
self.assertEqual(func(10, what),
transform(lambda x: hcb.id_tap(tap_func, func(x, what),
result=func(x * 2, what),
what=what))(5))
hcb.barrier_wait() # Wait for receivers to be done
self.assertEqual(3, tap_count)
@parameterized.named_parameters(
@ -354,15 +362,17 @@ where: 3
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
x4 = lax.cond(x % 2 == 0,
lambda x: hcb.id_print(x, where="cond_t", output_stream=testing_stream),
lambda x: hcb.id_print(-1, where="cond_f", result=x, output_stream=testing_stream),
lambda x: hcb.id_print(x, where="cond_t",
output_stream=testing_stream),
lambda x: hcb.id_print(-1, where="cond_f", result=x,
output_stream=testing_stream),
x2 + 1)
x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream)
return x5
transform = api.jit if with_jit else lambda f: f
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(4, transform(func)(1))
self.assertEqual(4, transform(func)(1))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -376,9 +386,8 @@ where: end
@parameterized.named_parameters(
jtu.cases_from_list(
dict(
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
dict(testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
def test_while_cond(self, with_jit=False):
def func(x):
@ -398,8 +407,8 @@ where: end
return res
transform = api.jit if with_jit else lambda f: f
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(4, transform(func)(1))
self.assertEqual(4, transform(func)(1))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -434,8 +443,8 @@ where: end
res = hcb.id_print(x10, where="3", output_stream=testing_stream)
return res
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertEqual(3, api.jit(func)(1))
self.assertEqual(3, api.jit(func)(1))
hcb.barrier_wait()
assertMultiLineStrippedEqual(self,
"""
where: w_p
@ -475,11 +484,11 @@ where: 3
res = hcb.id_print(x10, where="10", output_stream=testing_stream)
return res
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
if with_jit:
func = api.jit(func)
res = func(1)
self.assertAllClose(jnp.array([1, 2, 3]), res)
if with_jit:
func = api.jit(func)
res = func(1)
self.assertAllClose(jnp.array([1, 2, 3]), res)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
where: 1
1
@ -529,25 +538,23 @@ where: 10
xs,
a_new_test="************",
testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
_ = jit_fun1(args)
# self.assertAllClose(args, res)
res = jit_fun1(args)
self.assertAllClose(args, res)
def test_jit_large(self):
arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
api.jit(hcb.id_print)(arg)
api.jit(hcb.id_print)(arg)
def test_jit_several_together(self):
arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32))
api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(arg, jnp.ones(100, dtype=jnp.int32))
def test_jit_interleaving(self):
# Several jit's without data dependencies; they may interfere
count = 0 # Count tap invocations
nr_arrays = 5
def tap_func(arg, **kwargs):
def tap_func(arg, **_):
nonlocal count
assert len(arg) == nr_arrays
count += 1
@ -556,12 +563,13 @@ where: 10
for i in range(count):
x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)], i=i)[-1]
return x
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
x = jnp.array(1, dtype=np.int32)
res = 0
for i in range(10):
# No dependencies between the jit invocations
res += api.jit(lambda x: func(x, 10))(x)
x = jnp.array(1, dtype=np.int32)
res = 0
for _ in range(10):
# No dependencies between the jit invocations
res += api.jit(lambda x: func(x, 10))(x)
hcb.barrier_wait()
self.assertEqual(100, count)
def test_jit_tap_exception(self):
@ -574,9 +582,10 @@ where: 10
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3
res = api.jit(func)(0) # No error yet
with self.assertRaises(hcb.TapFunctionException):
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = api.jit(func)(0)
hcb.barrier_wait()
# Even though the receiver thread raised, the main thread should still
# return 3.
self.assertEqual(3, res)
@ -588,49 +597,20 @@ what: x3
3""", testing_stream.output)
testing_stream.reset()
def test_jit_unknown_tap(self):
# Simulate an unknown tap function
def func(x):
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3
with self.assertRaises(hcb.TapFunctionException):
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = api.jit(func)(0)
# Even though the receiver thread raised, the main thread should still
# return 3.
self.assertEqual(3, res)
# We should have received all others
assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
testing_stream.reset()
# On CPU and GPU the device code blocks
# On GPU it seems that there is a 5 min timeout?
# On TPU the client does not block, but messes up the rest somehow
@jtu.skip_on_devices("cpu", "gpu", "tpu")
def test_jit_receiver_ends_prematurely(self):
# Simulate an unknown tap function
def func(x):
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
x2 = hcb.id_tap(hcb._end_consumer, result=x1 + 1) # Will end the consumer loop
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
return x3
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
_ = api.jit(func)(0)
assert False # It seems that the previous jit blocks above
def test_jit_error_no_consumer(self):
# Check for errors if starting jit without a consumer active
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
api.jit(lambda x: hcb.id_print(x))(0)
def test_jit_nested_cond_no_print(self):
"""A nested conditional, without any prints"""
# Always skipped: nothing after this raise is executed.
raise SkipTest("skip this")
# A jitted function whose outer lax.cond has another lax.cond in one branch;
# no id_print/id_tap is involved.
@api.jit
def cfun(x):
return lax.cond(
lax.lt(x, 2),
lambda x: x,
lambda x: lax.cond(x < 5,
3, lambda x: x,
4, lambda y: y),
x)
# Print the HLO text for cfun, then run it once.
print(self._testMethodName, api.xla_computation(cfun)(1).as_hlo_text())
cfun(1)
def test_while(self):
"""Executing while, even without JIT uses compiled code"""
@ -641,8 +621,8 @@ what: x3
lambda c: c[1] < 5,
lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
(x, 1))
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
func(y)
func(y)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
1
2
@ -650,27 +630,13 @@ what: x3
4""", testing_stream.output)
testing_stream.reset()
def test_while_error_no_receiver(self):
"""Executing while needs the receiver"""
y = jnp.ones(5) # captured const
def func(x):
return lax.while_loop(
lambda c: c[1] < 5,
lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
(x, 1))
with self.assertRaisesRegex(ValueError, ".*outfeed_receiver.*not started"):
func(y).block_until_ready()
def test_jvp(self):
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
#assertMultiLineStrippedEqual(self, "",
# str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))
with hcb.outfeed_receiver():
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
#assertMultiLineStrippedEqual(self, "")
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
self.assertAllClose(100., res_primals, check_dtypes=False)
self.assertAllClose(4., res_tangents, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
@ -685,23 +651,24 @@ transforms: ({'name': 'jvp'},) what: y * 3
def test_grad_primal_unused(self):
# The output of id_print is not needed for backwards pass
def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
return 2. * hcb.id_print(x * 3., what="x * 3",
output_stream=testing_stream)
grad_func = api.grad(func)
with hcb.outfeed_receiver():
assertMultiLineStrippedEqual(self, """
jaxpr = str(api.make_jaxpr(grad_func)(5.))
# Just making the Jaxpr invokes the id_print once
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let
in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
# Just making the Jaxpr invokes the id_print once
in (6.00,) }""", jaxpr)
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
testing_stream.reset()
with hcb.outfeed_receiver():
res_grad = grad_func(jnp.float32(5.))
res_grad = grad_func(jnp.float32(5.))
hcb.barrier_wait()
self.assertAllClose(6., res_grad, check_dtypes=False)
assertMultiLineStrippedEqual(self, """
@ -714,13 +681,14 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
def test_grad_simple(self):
def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)
return x * hcb.id_print(y * 3., what="y * 3",
output_stream=testing_stream)
grad_func = api.grad(func)
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))
with hcb.outfeed_receiver():
res_grad = grad_func(jnp.float32(5.))
res_grad = grad_func(jnp.float32(5.))
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
@ -738,18 +706,19 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
return x * (y * 3.)
grad_func = api.grad(api.grad(func))
with hcb.outfeed_receiver():
_ = api.make_jaxpr(grad_func)(5.)
# Just making the Jaxpr invokes the id_print twiceonce
assertMultiLineStrippedEqual(self, """
# Just making the Jaxpr invokes the id_print twice
_ = api.make_jaxpr(grad_func)(5.)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
testing_stream.reset()
res_grad = grad_func(jnp.float32(5.))
testing_stream.reset()
res_grad = grad_func(jnp.float32(5.))
self.assertAllClose(12., res_grad, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
@ -765,8 +734,8 @@ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
vmap_fun1 = api.vmap(fun1)
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
with hcb.outfeed_receiver():
_ = vmap_fun1(vargs)
vmap_fun1(vargs)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
@ -784,8 +753,8 @@ transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
vmap_func = api.vmap(func)
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs)))
with hcb.outfeed_receiver():
_ = vmap_func(vargs)
_ = vmap_func(vargs)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
@ -804,8 +773,8 @@ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
xv = jnp.arange(5, dtype=np.int32)
yv = jnp.arange(3, dtype=np.int32)
#assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv)))
with hcb.outfeed_receiver():
_ = sum_all(xv, yv)
_ = sum_all(xv, yv)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
[[0 1 2 3 4]
@ -827,9 +796,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dim
return res
inputs = np.arange(5, dtype=np.int32)
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
check_dtypes=False)
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
@ -856,9 +825,9 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
return res
inputs = np.arange(5, dtype=np.int32)
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs),
check_dtypes=False)
res = api.jit(api.vmap(func))(inputs)
hcb.barrier_wait()
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
@ -880,21 +849,15 @@ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
pmap_fun1 = api.pmap(fun1, axis_name="i")
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
res = pmap_fun1(vargs)
res = pmap_fun1(vargs)
hcb.barrier_wait()
expected_res = jnp.stack([fun1_equiv(2. + a) for a in range(api.local_device_count())])
self.assertAllClose(expected_res, res, check_dtypes=False)
def test_pmap_error_no_receiver(self):
# Check for errors if starting jit without a consumer active
vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
api.pmap(lambda x: hcb.id_print(x))(vargs)
def test_mask(self):
# TODO(necula)
raise SkipTest("masking has regressed")
@partial(api.mask, in_shapes=['n'], out_shape='')
@functools.partial(api.mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
args = [jnp.arange(4)], dict(n=np.int64(2))
@ -916,21 +879,133 @@ logical_shapes: [(2,)] transforms: ('mask',) what: x
""", testing_stream.output)
testing_stream.reset()
def test_outfeed_receiver(self):
"""Test the deprecated outfeed_receiver"""
# Runs fun1 under the deprecated context manager and checks both the
# numerical result and the text written to testing_stream by the taps.
# NOTE(review): there is no explicit hcb.barrier_wait() here, unlike the
# other tests; presumably leaving the context waits for the taps -- confirm.
with hcb.outfeed_receiver():
self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)
assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
testing_stream.reset()
def test_callback_delay(self):
  """Run a jitted chain of five id_print taps with a 1-second callback_extra.

  The test makes no assertions; it only requires that the jitted call returns.

  NOTE(review): hcb.callback_extra is assigned here and is not restored in
  this method -- confirm whether that is intended.
  """
  hcb.callback_extra = lambda dev: time.sleep(1)

  def func(x):
    # The Python loop is unrolled into five consecutive taps.
    for factor in range(5):
      x = hcb.id_print(x * factor, what="x times i")
    return x

  operand = np.arange(6, dtype=np.float32).reshape((2, 3))
  jitted = api.jit(func)
  jitted(operand)
def test_callback_delay_barrier(self):
  """Check that barrier_wait collects all tap output with a 2-second callback_extra.

  The jitted function taps three times into testing_stream. It is run twice,
  each run followed by hcb.barrier_wait(), and both runs must produce the same
  text.
  """
  hcb.callback_extra = lambda dev: time.sleep(2)

  def func(x):
    for factor in range(1, 4):
      x = hcb.id_print(x * factor, what="x times i",
                       output_stream=testing_stream)
    return x

  def run_and_wait():
    api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
    # Wait for the results
    hcb.barrier_wait()

  expected = """
what: x times i
[[0. 1. 2.]
[3. 4. 5.]]
what: x times i
[[ 0. 2. 4.]
[ 6. 8. 10.]]
what: x times i
[[ 0. 6. 12.]
[18. 24. 30.]]"""

  run_and_wait()
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
  testing_stream.reset()

  # Call again
  run_and_wait()
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_multiple_barriers(self):
  """Call barrier_wait concurrently.

  A jitted computation whose tap sleeps for 2 seconds is launched first; then
  three threads call hcb.barrier_wait() at the same time. The test makes no
  assertions beyond every thread being joinable.
  """
  def pause_tap(*args, **kwargs):
    logging.info("pause_tap waiting")
    time.sleep(2)
    logging.info("pause_tap done")

  def long_run(x):
    return hcb.id_tap(pause_tap, x)

  api.jit(long_run)(5.)

  def try_barrier(idx):
    logging.info(f"Starting test barrier {idx}")
    hcb.barrier_wait()
    logging.info(f"Finished test barrier {idx}")

  threads = [
      threading.Thread(
          name=f"barrier_{idx}", target=try_barrier, args=(idx,))
      for idx in range(3)
  ]
  # start() and join() are called for their side effects only, so use plain
  # loops instead of building throwaway lists of None.
  for t in threads:
    t.start()
  for t in threads:
    t.join()
def test_error_bad_consumer_id(self):
  """Registering an outfeed with the reserved consumer ID 0 must fail.

  Expects add_outfeed to raise RuntimeError with a message that names the
  reserved value.
  """
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  expected_error = "Consumer ID cannot be a reserved value: 0"
  with self.assertRaisesRegex(RuntimeError, expected_error):
    operand = xla_bridge.constant(builder,
                                  np.zeros((2, 3), dtype=np.float32))
    hcb._outfeed_receiver.receiver.add_outfeed(builder, token, 0, [operand])
def test_error_different_shapes(self):
  """Reusing a consumer ID with a different operand shape or dtype must fail.

  Consumer ID 123 is first registered with a float32 (2, 3) operand. Adding it
  again with an int32 (2, 3) operand, and then with a float32 (2,) operand,
  must each raise RuntimeError.
  """
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  consumer_id = 123

  def register(array):
    hcb._outfeed_receiver.receiver.add_outfeed(
        builder, token, consumer_id,
        [xla_bridge.constant(builder, array)])

  register(np.zeros((2, 3), dtype=np.float32))

  mismatched_operands = (
      np.zeros((2, 3), dtype=np.int32),   # same shape, different dtype
      np.zeros((2,), dtype=np.float32),   # same dtype, different shape
  )
  for array in mismatched_operands:
    with self.assertRaisesRegex(
        RuntimeError, ".*does not match previous shape element_type.*"):
      register(array)
class OutfeedRewriterTest(jtu.JaxTestCase):
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                  has_input_token=True, has_output_token=True):
  """Check that the rewrite of func(*args) matches expected.

  The comparison against `expected` is currently disabled (see the TODO
  below), so at present this only traces `func` with `args` via
  `api.make_jaxpr`.

  Args:
    expected: the expected textual form of the rewritten jaxpr (unused
      while the check is disabled).
    func: the function to trace.
    args: positional arguments passed to the traced function.
    has_input_token: forwarded to the rewriter once the check is re-enabled.
    has_output_token: forwarded to the rewriter once the check is
      re-enabled.
  """
  # Trace once; a second, discarded trace adds nothing to the check.
  jaxpr = api.make_jaxpr(func)(*args)
  # TODO: re-enable when we change the host_callback rewriter
  #rewritten = hcb._rewrite_typed_jaxpr(jaxpr,
  #                                     has_input_token, has_output_token)
  #assertMultiLineStrippedEqual(self, expected, str(rewritten))
  del jaxpr
def test_no_outfeed(self):
self.assertRewrite("""
{ lambda ; a.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False)
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False,
has_output_token=False)
self.assertRewrite("""
{ lambda ; a d.
let b = mul a a
@ -946,9 +1021,11 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
self.assertRewrite("""
{ lambda ; a d.
let b = add a a
c e = id_tap[ arg_treedef=*
func=_print
] b d
c e = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] b d
in (c, e) }""", lambda x: hcb.id_print(x + x), [0])
def test_cond(self):
@ -962,8 +1039,10 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
d = convert_element_type[ new_dtype=int32
old_dtype=bool ] c
g h j = cond[ branches=( { lambda ; f_ e a b c g.
let d h = id_tap[ arg_treedef=*
func=_print
let d h = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] c g
in (d, e, h) }
{ lambda ; d g_ a b c h.
@ -980,15 +1059,13 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
(x, np.float32(1.)))
# TODO: we should not need to start a receiver here!!! I believe this is
# because of the partial evaluation of while, which calls impl, which
# uses JIT.
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertRewrite("""
self.assertRewrite("""
{ lambda b c ; a f.
let d e g = while[ body_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef=*
func=_print
let d g = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] b f
e = add d 1.00
in (c, e, g) }
@ -1011,43 +1088,47 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
lambda c: (ct_body, hcb.id_print(c[1]) + 1),
(x, 1))
# TODO: we should not need to start a receiver here!!! I believe this is
# because of the partial evaluation of while, which calls impl, which
# uses JIT.
with hcb.outfeed_receiver(receiver_name=self._testMethodName):
self.assertRewrite("""
{ lambda b c ; a f.
let h i = xla_call[ call_jaxpr={ lambda ; c a b g.
let d e h = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
] c b g
f = lt e 5
in (f, h) }
name=cond_before ] b a 1 f
y d e g = while[ body_jaxpr={ lambda ; n o p q r s.
let t u v = xla_call[ call_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef=*
func=_print
] b f
e = add d 1
in (c, e, g) }
name=body ] o q r s
w x = xla_call[ call_jaxpr={ lambda ; c a b g.
let d e h = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
] c b g
f = lt e 5
in (f, h) }
name=cond_body ] n t u v
in (w, t, u, x) }
body_nconsts=2
cond_jaxpr={ lambda ; j k l m.
let
in (j,) }
cond_nconsts=0 ] b c h a 1 i
in (d, 5, g) }""", func, [ct_body])
self.assertRewrite("""
{ lambda b c ; a e.
let g h = xla_call[ call_jaxpr={ lambda ; c a b f.
let _ d g = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] c b f
e = lt d 5
in (e, g) }
donated_invars=(False, False, False, False)
name=cond_before ] b a 1 e
x d _ f =
while[ body_jaxpr={ lambda ; m n o p q r.
let s t u = xla_call[ call_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] b f
e = add d 1
in (c, e, g) }
donated_invars=(False, False, False, False, False, False, False)
name=body ] n p q r
v w = xla_call[ call_jaxpr={ lambda ; c a b f.
let _ d g = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print
] c b f
e = lt d 5
in (e, g) }
donated_invars=(False, False, False, False, False, False)
name=cond_body ] m s t u
in (v, s, t, w) }
body_nconsts=2
cond_jaxpr={ lambda ; i j k l.
let
in (i,) }
cond_nconsts=0 ] b c g a 1 h
in (d, 5, f) }""", func, [ct_body])
def test_scan(self):
y = jnp.ones(5) # captured const
@ -1055,16 +1136,19 @@ class OutfeedRewriterTest(jtu.JaxTestCase):
return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
self.assertRewrite("""
{ lambda b ; a f.
let c d g e = scan[ jaxpr={ lambda ; f a b g c.
let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
] a b g
in (d, e, h, f) }
length=5
linear=(False, False, False, False, False)
num_carry=3
num_consts=1
reverse=False ] b 1 2 f a
let c d g e =
scan[ jaxpr={ lambda ; f a b g c.
let d e h = id_tap[ arg_treedef_=PyTreeDef(tuple, [*,*])
has_token_=True
nr_tapped_args_=2
tap_func_=_print
] a b g
in (d, e, h, f) }
length=5
linear=(False, False, False, False, False)
num_carry=3
num_consts=1
reverse=False ] b 1 2 f a
in (c, d, e, g) }""", func, [y])

View File

@ -19,6 +19,7 @@ from absl.testing import absltest
import jax
from jax import lax, numpy as np
from jax.config import config
from jax.experimental import host_callback as hcb
from jax.lib import xla_client
import jax.test_util as jtu
import numpy as onp
@ -47,6 +48,7 @@ class InfeedTest(jtu.JaxTestCase):
self.assertAllClose(f(x), x + y + z)
def testInfeedThenOutfeed(self):
hcb.stop_outfeed_receiver()
@jax.jit
def f(x):
token = lax.create_token(x)
@ -67,6 +69,7 @@ class InfeedTest(jtu.JaxTestCase):
self.assertAllClose(out, y + onp.float32(1))
def testInfeedThenOutfeedInALoop(self):
hcb.stop_outfeed_receiver()
def doubler(_, token):
y, token = lax.infeed(
token, shape=jax.ShapedArray((3, 4), np.float32))