rocm_jax/tests/host_callback_test.py
George Necula 3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
2021-12-15 10:32:15 +02:00

3086 lines
109 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools
import logging
import os
import re
import threading
import time
from typing import Callable, Optional, Sequence
import unittest
from unittest import skip, SkipTest
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import ad_checkpoint
from jax import core
from jax.config import config
from jax import dtypes
from jax.experimental import host_callback as hcb
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge
xops = xla_client.ops
import numpy as np
# Make JAX config options settable from absl command-line flags; the tests
# read them through FLAGS (e.g. FLAGS.jax_host_callback_outfeed).
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class _TestingOutputStream:
  """Use as `output_stream` for tests.

  Every write is echoed to stdout, tagged with the current test method name,
  and also recorded so that a test can later compare what was printed.
  """

  def __init__(self):
    self._output = []
    self._test_method_name = None

  def write(self, what: str) -> None:
    print(f"output_stream[{self._test_method_name}]: {what}", end="")
    self._output.append(what)

  @property
  def output(self):
    """All pieces written since the last `reset`, concatenated."""
    return "".join(self._output)

  @property
  def output_sorted_by_device(self):
    """The recorded pieces grouped per device and ordered by device name.

    A piece containing `device: <name>` starts a new group; the following
    pieces (up to the next such marker) belong to that group. Groups are
    stably sorted by device name and all pieces are joined with newlines.
    """
    groups = []  # list of (device_name, pieces)
    for piece in self._output:
      found = re.match(r".*device: (\S+)", piece)
      if found:
        groups.append((found.group(1), []))
      assert groups, f"output does not include 'device:': {self._output}"
      groups[-1][1].append(piece)
    ordered = sorted(groups, key=lambda group: group[0])
    return "\n".join(piece for _, pieces in ordered for piece in pieces)

  def __str__(self):
    return "TestingOutputStream"

  def reset(self):
    self._output = []
# Shared stream passed as `output_stream=` by the tests; setUp resets it and
# records the current test name before each test.
testing_stream = _TestingOutputStream()
def fun1(a):
  """Function used for several `id_tap` tests.

  Prints `a * 2` (labelled "a * 2") and `(a * 2) * 3` (labelled "y * 3"),
  passing `a * 2` through as the result of the second print, and returns
  `(a * 2) ** 2`.
  """
  doubled = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
  doubled = hcb.id_print(doubled * 3., what="y * 3",
                         output_stream=testing_stream, result=doubled)
  # Some computation to make the gradient interesting.
  return doubled ** 2
def fun1_equiv(a):
  """Numerical equivalent of `fun1`, without any printing."""
  doubled = a * 2.
  return doubled ** 2
def maybe_print(do_print: bool, arg, what: str, tap_with_device: Optional[bool] = False):
  """Conditionally print `arg` on `testing_stream`.

  Args:
    do_print: if true, `arg` goes through `hcb.id_print`; otherwise it is
      returned unchanged and nothing is printed.
    arg: the value to (maybe) print.
    what: label passed to `hcb.id_print` as `what=`.
    tap_with_device: forwarded to `hcb.id_print`.

  Returns:
    The result of `hcb.id_print` when printing, else `arg` itself.
  """
  if not do_print:
    return arg
  return hcb.id_print(arg, what=what,
                      output_stream=testing_stream,
                      tap_with_device=tap_with_device)
def local_devices():
  """Returns at most the first two local devices (tests use no more than 2)."""
  all_devices = jax.local_devices()
  return all_devices[:2]
# `jtu.ignore_warning` pre-bound to ignore warnings whose message matches
# ".*jit-of-pmap.*".
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning, message=".*jit-of-pmap.*")
def assertMultiLineStrippedEqual(tst: "jtu.JaxTestCase",
                                 expected: str, what: str):
  """A variant that preprocesses the string to eliminate non-determinism in
  floating point values, and several uninteresting id_tap primitive params.

  In `what`:
    * floating point literals are rounded to 2 decimals;
    * the `output_stream`, `threshold`, `bwd`, `out_trees`, `fwd_jaxpr_thunk`
      and `jvp_jaxpr_thunk` params are removed;
    * empty lines are dropped;
    * the value of a `callback=` or `tap_func_=` param is replaced by
      `_print` if it mentions `function _print_consumer`, else by `...`.
  The result is compared with `expected` via
  `tst.assertMultiLineStrippedEqual`.
  """
  # Sometimes we get floating points in the output; we round them
  def repl_floats(match_group):
    matched = match_group.group(0)
    if matched == ".": return matched
    x = np.around(float(matched), decimals=2)
    return f"{x:.2f}"
  what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
  what = re.sub(r"output_stream=[^\]\n,]*,?", "", what)
  what = re.sub(r"threshold=[^\]\n,]*,?", "", what)
  what = re.sub(r"bwd=[^\]\n]*", "", what)
  what = re.sub(r"out_trees=[^\]\n]*", "", what)
  what = re.sub(r"fwd_jaxpr_thunk=[^\]\n]*", "", what)
  what = re.sub(r"jvp_jaxpr_thunk=[^\]\n]*", "", what)
  # Empty lines
  what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)
  def repl_func(match_group):
    # Group 1 is the param name; group 4 is its value. (Group 3 is only the
    # literal "callback" alternative, or None when `tap_func_` matched.)
    matched = match_group.group(4)
    if "function _print_consumer" in matched:
      return match_group.group(1) + "=_print"
    else:
      return match_group.group(1) + "=..."
  what = re.sub(r"((tap_func_)|(callback))=([^\]\n,]*),?", repl_func, what)
  tst.assertMultiLineStrippedEqual(expected, what)
def helper_set_hlo_dump():
  """Appends `--xla_dump_to=/tmp/xla_dump` to XLA_FLAGS, clearing old dumps."""
  import shutil
  dump_dir = "/tmp/xla_dump"
  existing_flags = os.getenv("XLA_FLAGS", "")
  os.environ["XLA_FLAGS"] = f"{existing_flags} --xla_dump_to={dump_dir}"
  if os.path.isdir(dump_dir):
    logging.warning("Deleting old XLA dump directory %s", dump_dir)
    shutil.rmtree(dump_dir)
  logging.warning("Setting XLA dump directory %s", dump_dir)
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
def helper_print_optimized_hlo(fun, *args):
  """Prints the optimized HLO for `fun(*args)` with metadata stripped."""
  backend = xla_bridge.get_backend()
  computation = jax.xla_computation(fun, backend='cpu')(*args)
  optimized = backend.compile(computation).hlo_modules()[0].to_string()
  print(re.sub(r", metadata.*", "", optimized))
def helper_log_ir(name,
                  f_jax,
                  *args,
                  num_partitions=None,
                  strip_metadata=False):
  """Prints the jaxpr, the HLO and the optimized HLO of `f_jax(*args)`.

  Args:
    name: label included in each printed header.
    f_jax: the function to inspect.
    *args: example arguments used for tracing.
    num_partitions: if given, compile with one replica and this many
      partitions (SPMD partitioning enabled when more than one).
    strip_metadata: if true, remove `, metadata...` from the optimized HLO.
  """
  print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
  jax_comp = jax.xla_computation(f_jax, backend='cpu')(*args)
  print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")
  backend = xla_bridge.get_backend()

  compile_options = None
  if num_partitions is not None:
    num_replicas = 1
    device_assignment = np.reshape(
        np.arange(num_partitions * num_replicas), (-1, num_partitions))
    compile_options = xla_bridge.get_compile_options(
        num_replicas=num_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=num_partitions > 1,
    )

  executable = backend.compile(jax_comp, compile_options)
  jax_optimized_hlo = executable.hlo_modules()[0].to_string()
  if strip_metadata:
    jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo)
  print(f"Optimized HLO[{name}] for "
        f"platform {backend.platform}: {jax_optimized_hlo}")
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
# This will control the CPU devices. On TPU we always have 2 devices
prev_xla_flags = jtu.set_host_platform_device_count(2)
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
prev_xla_flags()
def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
                                 expected_2CPUs: str):
  """Check that the multi-device output is equal to the expected.

  The tests run with 2 devices if available, otherwise 1 device.
  We adjust the expected output here for 1 device.

  Args:
    tst: the test case used for the comparison.
    expected_2CPUs: the expected output for 2 CPUs. If there is only
      one device, this is trimmed to the first device. Every `cpu:N` in it
      is then replaced by `str()` of the N-th local device, so the same
      expectation works when the device under test is not a CPU.
  """
  devices = local_devices()
  expected = expected_2CPUs
  if len(devices) == 1:
    # Drop everything expected from the second device.
    start_device_1 = expected.find('device: cpu:1')
    if start_device_1 >= 0:
      expected = expected[0:start_device_1]

  def replace_device_name(m) -> str:
    return str(devices[int(m.group(1))])

  expected = re.sub(r'cpu:(\d+)', replace_device_name, expected)
  what = testing_stream.output_sorted_by_device
  return assertMultiLineStrippedEqual(tst, expected, what)
class HostCallbackTapTest(jtu.JaxTestCase):
def setUp(self):
  """Skips multi-GPU runs and prepares the shared testing stream."""
  super().setUp()
  on_multi_gpu = jtu.device_under_test() == "gpu" and jax.device_count() > 1
  if on_multi_gpu:
    raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
  testing_stream.reset()
  testing_stream._test_method_name = self._testMethodName
  # Remembered so that tearDown can restore XLA_FLAGS if a test changed it.
  self.old_flags = os.getenv("XLA_FLAGS", "")
def tearDown(self) -> None:
  """Restores XLA_FLAGS if needed and waits for outstanding callbacks."""
  current_flags = os.getenv("XLA_FLAGS")
  if current_flags != self.old_flags:
    os.environ["XLA_FLAGS"] = self.old_flags
    # Drop cached backends so that the restored flags are picked up.
    xla_bridge.get_backend.cache_clear()
  hcb.barrier_wait("HostCallbackTapTest.tearDown")
  super().tearDown()
def test_tap_eval(self):
  """id_print in op-by-op (non-jit) evaluation."""
  expected_value = (5. * 2.) ** 2
  self.assertAllClose(expected_value, fun1(5.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
def test_tap_with_tuple_results(self):
  """id_print of a tuple returns the tuple."""
  def func2(x):
    first, second = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
    return first + second
  self.assertEqual(3. * (2. + 3.), func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
( 6.00 9.00 )""", testing_stream.output)
def test_tap_with_dict_results(self):
  """id_print of a dict returns the dict."""
  def func2(x):
    printed = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
    return printed["a"] + printed["b"]
  self.assertEqual(3. * (2. + 3.), func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
{ a=6.00 b=9.00 }""", testing_stream.output)
def test_tap_with_result(self):
  """With `result=`, id_print prints its argument but returns `result`."""
  def func2(x):
    return hcb.id_print((x * 2., x * 3.), result=x * 4.,
                        output_stream=testing_stream)
  self.assertEqual(3. * 4., func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
( 6.00 9.00 )""", testing_stream.output)
def test_tap_with_result_no_arg(self):
  """id_tap with a None argument still invokes the tap function."""
  def tap_func(arg, transforms):
    testing_stream.write(f"called tap_func with {arg}")

  def func2(x):
    return hcb.id_tap(tap_func, None, result=x)

  self.assertEqual(3., func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, "called tap_func with None",
                               testing_stream.output)
def test_tap_result_unused(self):
  """The tap runs even when the id_tap result is discarded."""
  def tap_func(arg, transforms):
    testing_stream.write(f"called tap_func with {arg}")

  def func2(x):
    hcb.id_tap(tap_func, None)  # result deliberately ignored
    return x

  self.assertEqual(3., func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, "called tap_func with None",
                               testing_stream.output)
def test_tap_with_device(self):
  """With tap_with_device=True the printout names the device."""
  def func2(x):
    return hcb.id_print((x * 2., x * 3.), result=x * 4.,
                        output_stream=testing_stream,
                        tap_with_device=True)
  self.assertEqual(3. * 4., func2(3.))
  hcb.barrier_wait()
  assertMultiDeviceOutputEqual(self, """
device: cpu:0
( 6.00 9.00 )""")
def test_tap_eval_exception(self):
  """A failing tap surfaces as hcb.CallbackException; other taps still run."""
  if not FLAGS.jax_host_callback_outfeed:
    raise SkipTest("TODO: implement error handling for customcall")

  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise ValueError("Some user message")

  def func(x):
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(tap_err, x1 + 1)
    return hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)

  expected_error = re.compile(
      "There were exceptions during callback processing. Last one was:.*"
      "ValueError: Some user message", re.DOTALL)
  with self.assertRaisesRegex(hcb.CallbackException, expected_error):
    func(0)
    hcb.barrier_wait()

  # We should have received everything before the error
  assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
def test_tap_empty(self):
  """Tap empty arrays."""
  empty_tuple = ()
  scalar_and_empty = (1., np.ones((2, 0)))
  hcb.id_print(empty_tuple, output_stream=testing_stream)
  hcb.id_print(scalar_and_empty, what="second", output_stream=testing_stream)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
( )
what: second
( 1.00 [] )""", testing_stream.output)
def test_tap_jit_simple(self):
  """id_print inside a jitted function."""
  def scaled(x):
    return 3. * hcb.id_print(2. * x, what="here", output_stream=testing_stream)

  jit_fun1 = jax.jit(scaled)
  self.assertAllClose(6. * 5., jit_fun1(5.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
def test_tap_jit_no_invars(self):
  """A jitted function without arguments can still print."""
  def func():  # jitted function does not take arguments
    return hcb.id_print(42, output_stream=testing_stream)

  jitted = jax.jit(func)
  self.assertAllClose(42, jitted())
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
def test_tap_jit_multiple_invars(self):
  """A jitted function with several arguments prints their sum."""
  def func(x1, x2):
    total = x1 + x2
    return hcb.id_print(total, output_stream=testing_stream)

  jitted = jax.jit(func)
  self.assertAllClose(42, jitted(40, 2))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
def test_tap_jit_constant(self):
  """Printing a constant under jit, passing the argument through."""
  def func(x):
    return hcb.id_print(42, result=x, output_stream=testing_stream)

  jitted = jax.jit(func)
  self.assertAllClose(5, jitted(5))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
def test_tap_jit_sequence1(self):
  """Two chained id_prints under jit, with the IR logged."""
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

  test_name = self._testMethodName
  logging.info("%s: %s", test_name, jax.make_jaxpr(func)(1))
  hlo_text = jax.xla_computation(func, backend='cpu')(1).as_hlo_text()
  logging.info("%s: %s", test_name, hlo_text)
  self.assertEqual(2, jax.jit(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
def test_tap_jit2(self):
  """A sequence of JIT."""
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

  for arg, expected in ((1, 2), (10, 11)):
    self.assertEqual(expected, jax.jit(func)(arg))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
def test_tap_jit_result_unused(self):
  """We can id_print even if we don't use the result."""
  def func(x):
    # Both results are deliberately discarded.
    hcb.id_print(x, where="1", output_stream=testing_stream)
    hcb.id_print(x + 1, where="2", output_stream=testing_stream)
    return x + 1

  for arg, expected in ((1, 2), (10, 11)):
    self.assertEqual(expected, jax.jit(func)(arg))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
def test_tap_jit_nested(self):
  """id_print inside a jit nested in another jit."""
  def func_nested(x):
    return hcb.id_print(x + 1, where="nested", output_stream=testing_stream)

  def func(x):
    before = hcb.id_print(x, where="1", output_stream=testing_stream)
    inner = jax.jit(func_nested)(before)
    return hcb.id_print(inner + 1, where="3", output_stream=testing_stream)

  self.assertEqual(3, jax.jit(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
def test_tap_jit_devices(self):
  """Running on multiple devices."""
  devices = local_devices()
  logging.info("%s: has devices %s", self._testMethodName, devices)

  def func(x, device_id):
    tag = str(device_id)
    first = hcb.id_print(x, dev=tag, output_stream=testing_stream)
    return hcb.id_print(first + 1, dev=tag, output_stream=testing_stream)

  for d in devices:
    jitted = jax.jit(func, device=d, static_argnums=1)
    self.assertEqual(112, jitted(111, d.id))
  hcb.barrier_wait()
  output = testing_stream.output
  logging.info("%s: found output %s", self._testMethodName, output)
  # Each device prints the input 111 and the incremented 112 once.
  self.assertEqual(len(devices), len(re.findall(r"111", output)))
  self.assertEqual(len(devices), len(re.findall(r"112", output)))
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_tap_pytree(self, with_jit=False):
  """id_tap of several pytree shapes, with and without jit."""
  def func(x, what=""):
    """Returns some pytrees depending on x"""
    if what == "pair_1_x":
      return (1, x)
    if what == "pair_x_2x":
      return (x, 2 * x)
    if what == "dict":
      return dict(a=2 * x, b=3 * x)
    assert False

  tap_count = 0

  def tap_func(a, _, *, what=""):
    nonlocal tap_count
    tap_count += 1
    self.assertEqual(func(5, what), a)

  if with_jit:
    transform = jax.jit
  else:
    transform = lambda f: f

  for what in ("pair_1_x", "pair_x_2x", "dict"):
    transformed = transform(
        lambda x: hcb.id_tap(
            partial(tap_func, what=what),
            func(x, what),
            result=func(x * 2, what))
    )(5)
    self.assertEqual(func(10, what), transformed)
  hcb.barrier_wait()  # Wait for receivers to be done
  self.assertEqual(3, tap_count)
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_concurrent_{concurrent}",
            concurrent=concurrent)
        for concurrent in [True, False]))
def test_tap_multiple(self, concurrent=False):
  """Call id_tap multiple times, concurrently or in sequence. """
  if concurrent and jtu.device_under_test() in ["cpu", "gpu"]:
    # TODO(necula): if there is device side concurrency, outfeeds from
    # different computations can be interleaved. For example, it seems that
    # on GPU if multiple host threads run a jit computation, the multiple
    # computations are interleaved on the GPU. This can result in the outfeed
    # trains being interleaved, which will trigger an error.
    # The solution is to fix on GPU the receiving logic so that we can outfeed
    # the train as one tuple, and receive it one piece as a time. Then the
    # trains should be atomic.
    # See also b/160692602.
    raise SkipTest("concurrent id_tap not supported on CPU, GPU")

  received = set()
  count = 5

  def pause_tap(idx, _):
    received.add(int(idx))
    # The message matches the actual sleep duration below.
    logging.info("Starting do_tap %s. Sleeping 0.3sec ...", idx)
    time.sleep(0.3)
    logging.info("Finish do_tap %s", idx)

  def do_tap(idx):
    jax.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)

  if concurrent:
    threads = [
        threading.Thread(
            name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,))
        for idx in range(count)
    ]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
  else:
    for idx in range(count):
      do_tap(idx)

  hcb.barrier_wait()
  self.assertEqual(received, set(range(count)))
# TODO(necula): see the comment in test_tap_multiple. Here we also disable the
# test on TPU, because barrier_wait runs on all devices, including the CPU,
# where it would run into concurrency problems.
@skip("Concurrency not supported")
def test_tap_multiple_barriers(self):
  """Call barrier_wait concurrently."""
  def pause_tap(*args, **kwargs):
    logging.info("pause_tap waiting")
    time.sleep(0.3)
    logging.info("pause_tap done")

  def long_run(x):
    return hcb.id_tap(pause_tap, x)

  jax.jit(long_run)(5.)

  def try_barrier(idx):
    logging.info("Starting test barrier %s", idx)
    hcb.barrier_wait()
    logging.info("Finished test barrier %s", idx)

  threads = [
      threading.Thread(
          name=f"barrier_{idx}", target=try_barrier, args=(idx,))
      for idx in range(3)
  ]
  for thread in threads:
    thread.start()
  for thread in threads:
    thread.join()
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_tap_cond(self, with_jit=False):
  """A conditional"""
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def true_branch(v):
      return hcb.id_print(v, where="cond_t", output_stream=testing_stream)

    def false_branch(v):
      return hcb.id_print(-1, where="cond_f", result=v,
                          output_stream=testing_stream)

    chosen = lax.cond(x % 2 == 0, true_branch, false_branch, x2 + 1)
    return hcb.id_print(chosen + 1, where="end", output_stream=testing_stream)

  transform = jax.jit if with_jit else (lambda f: f)
  self.assertEqual(4, transform(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(testcase_name=f"_with_jit_{with_jit}",
             with_jit=with_jit)
        for with_jit in [True, False]))
def test_tap_while_cond(self, with_jit=False):
  """A while loop whose body contains a conditional."""
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def on_even(v):
      return hcb.id_print(v, where="w_b_t", output_stream=testing_stream)

    def on_odd(v):
      return hcb.id_print(-1, where="w_b_f", result=v,
                          output_stream=testing_stream)

    def body(v):
      entered = hcb.id_print(v, where="w_b_1", output_stream=testing_stream)
      branched = lax.cond(v % 2 == 0, on_even, on_odd, entered + 1)
      return hcb.id_print(branched, where="w_b_2", output_stream=testing_stream)

    final = lax.while_loop(lambda v: v <= 3, body, x2)
    return hcb.id_print(final, where="end", output_stream=testing_stream)

  transform = jax.jit if with_jit else (lambda f: f)
  self.assertEqual(4, transform(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
def test_tap_jit_while_pred_tap(self):
  """While with printing in the conditional."""
  def func(x):
    # No output_stream here: this print does not go to testing_stream.
    start = hcb.id_print(x, where="1")

    def cond_fun(v):
      return hcb.id_print(v < 3, where="w_p", output_stream=testing_stream)

    def body_fun(v):
      return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

    final = lax.while_loop(cond_fun, body_fun, start)
    return hcb.id_print(final, where="3", output_stream=testing_stream)

  self.assertEqual(3, jax.jit(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: w_p
True
where: w_b
2
where: w_p
True
where: w_b
3
where: w_p
False
where: 3
3""", testing_stream.output)
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_tap_scan_cond(self, with_jit=True):
  """A scan whose body contains a conditional."""
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def on_even(v):
      return hcb.id_print(v, where="s_t", output_stream=testing_stream)

    def on_odd(v):
      return hcb.id_print(-1, where="s_f", result=v, output_stream=testing_stream)

    def body(carry, elem):
      entered = hcb.id_print(elem, where="s_1", output_stream=testing_stream)
      branched = lax.cond(elem % 2 == 0, on_even, on_odd, entered + 1)
      return (carry, hcb.id_print(branched, where="s_2", output_stream=testing_stream))

    _, stacked = lax.scan(body, x2, jnp.arange(3))
    return hcb.id_print(stacked, where="10", output_stream=testing_stream)

  run = jax.jit(func) if with_jit else func
  res = run(1)
  self.assertAllClose(jnp.array([1, 2, 3]), res)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
  testing_stream.reset()
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_shape_{shape}_dtype_{np.dtype(dtype).name}_nr_args={nr_args}",
            shape=shape,
            dtype=dtype,
            nr_args=nr_args) for nr_args in [1, 2]
        for shape in [(), (2,), (2, 3), (2, 3, 4)]
        for dtype in jtu.dtypes.all))
def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
  """id_print under jit of `nr_args` copies of an array of `shape`/`dtype`."""
  if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
    raise SkipTest(f"host_callback not implemented for {dtype}.")
  # Boolean dtypes are skipped above, so every dtype reaching this point is
  # numeric and can be built with `jnp.arange`.
  args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
  if nr_args > 1:
    args = args * nr_args
  jit_fun1 = jax.jit(lambda xs: hcb.id_print(
      xs,
      a_new_test="************",
      testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
  res = jit_fun1(args)
  self.assertAllClose(args, res, check_dtypes=True)
def test_tap_jit_large(self):
  """id_print of a 10000-element int32 array under jit."""
  big = jnp.arange(10000, dtype=jnp.int32)
  arg = big.reshape((10, 10, 5, -1))
  jax.jit(hcb.id_print)(arg)
def test_tap_jit_several_together(self):
  """id_print of a tuple of several arrays under jit."""
  arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
  def print_three(x, y):
    return hcb.id_print((x, y, x * 2.))
  jax.jit(print_three)(arg, jnp.ones(100, dtype=jnp.int32))
def test_tap_jit_interleaving(self):
  """Several jit's without data dependencies; they may interfere."""
  count = 0  # Count tap invocations
  nr_arrays = 5

  def tap_func(arg, _):
    nonlocal count
    assert len(arg) == nr_arrays
    count += 1

  # This is the function that we'll run multiple times. Each of the
  # `nr_taps` taps sends `nr_arrays` values and continues with the last one
  # (x + nr_arrays - 1).
  def func(x, nr_taps):
    for _ in range(nr_taps):
      x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)])[-1]
    return x

  x = jnp.array(1, dtype=np.int32)
  res = 0
  for _ in range(10):
    # No dependencies between the jit invocations
    res += jax.jit(lambda x: func(x, 10))(x)
  hcb.barrier_wait()
  self.assertEqual(100, count)
  # Each run returns 1 + 10 * (nr_arrays - 1) = 41; ten runs sum to 410.
  self.assertAllClose(10 * (1 + 10 * (nr_arrays - 1)), res, check_dtypes=False)
def test_tap_jit_tap_exception(self):
  """A failing tap under jit is reported at the next barrier_wait."""
  if not FLAGS.jax_host_callback_outfeed:
    raise SkipTest("TODO: implement error handling for customcall")

  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise NotImplementedError

  def func(x):
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(tap_err, x1 + 1)
    return hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)

  res = jax.jit(func)(0)  # No error yet
  with self.assertRaises(hcb.CallbackException):
    hcb.barrier_wait()

  # Even though the receiver thread raised, the main thread should still
  # return 3.
  self.assertEqual(3, res)
  # We should have received all others
  assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
def test_tap_while(self):
  """Executing while, even without JIT uses compiled code"""
  y = jnp.ones(5)  # captured const

  def func(x):
    def cond_fun(carry):
      return carry[1] < 5

    def body_fun(carry):
      return (y, hcb.id_print(carry[1], output_stream=testing_stream) + 1)

    return lax.while_loop(cond_fun, body_fun, (x, 1))

  func(y)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
1
2
3
4""", testing_stream.output)
def test_tap_jvp(self):
  """jvp of fun1; with AD transforms enabled the tangents are printed too."""
  primal_in = jnp.float32(5.)
  tangent_in = jnp.float32(0.1)
  res_primals, res_tangents = jax.jvp(fun1, (primal_in,), (tangent_in,))
  self.assertAllClose(100., res_primals, check_dtypes=False)
  self.assertAllClose(4., res_tangents, check_dtypes=False)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
transforms: ['jvp'] what: a * 2
( 10.00 0.20 )
transforms: ['jvp'] what: y * 3
( 30.00 0.60 )"""
  else:
    expected = """
what: a * 2
10.00
what: y * 3
30.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_grad_primal_unused(self):
  """Gradient through id_print whose output is not needed backwards."""
  # The output of id_print is not needed for backwards pass
  def func(x):
    return 2. * hcb.id_print(x * 3., what="x * 3",
                             output_stream=testing_stream)

  grad_func = jax.grad(func)
  arg = jnp.float32(5.)
  jaxpr = str(jax.make_jaxpr(grad_func)(arg))
  # making the Jaxpr does not print anything
  hcb.barrier_wait()

  treedef = tree_util.tree_structure(arg)
  if FLAGS.jax_host_callback_ad_transforms:
    expected_jaxpr = f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=()
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
transforms=(('jvp',), ('transpose',))
] d
f:f32[] = mul e 3.00
in (f,) }}"""
  else:
    expected_jaxpr = f"""
{{ lambda ; a:f32[]. let
b:f32[] = mul a 3.00
c:f32[] = outside_call[
arg_treedef={treedef}
callback=...
identity=True
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
e:f32[] = mul d 3.00
in (e,) }}"""
  assertMultiLineStrippedEqual(self, expected_jaxpr, jaxpr)
  assertMultiLineStrippedEqual(self, "", testing_stream.output)
  testing_stream.reset()

  res_grad = grad_func(arg)
  hcb.barrier_wait()
  self.assertAllClose(6., res_grad, check_dtypes=False)
  if FLAGS.jax_host_callback_ad_transforms:
    expected_output = """
what: x * 3
15.00
transforms: ['jvp', 'transpose'] what: x * 3
2.00"""
  else:
    expected_output = """
what: x * 3
15.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
def test_tap_grad_simple(self):
  """Gradient of a function with two chained id_prints."""
  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    return x * hcb.id_print(doubled * 3., what="y * 3",
                            output_stream=testing_stream)

  res_grad = jax.grad(func)(jnp.float32(5.))
  self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
what: x * 2
10.00
what: y * 3
30.00
transforms: ['jvp', 'transpose'] what: y * 3
5.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00"""
  else:
    expected = """
what: x * 2
10.00
what: y * 3
30.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_grad_grad(self):
  """Second derivative through id_print."""
  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    return x * (doubled * 3.)

  grad_func = jax.grad(jax.grad(func))
  # making the Jaxpr does not print anything
  _ = jax.make_jaxpr(grad_func)(5.)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, "", testing_stream.output)

  res_grad = grad_func(jnp.float32(5.))
  self.assertAllClose(12., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
what: x * 2
10.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
2.00
transforms: ['jvp', 'transpose'] what: x * 2
3.00"""
  else:
    expected = """
what: x * 2
10.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_grad_pytree(self):
  """Gradient through id_print of a pair with a separate pair result."""
  def func(x):
    x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
                          result=(x * 4., x * 5.),
                          output_stream=testing_stream)
    return x4 + 2. * x5

  x = jnp.float32(5.)
  grad_func = jax.grad(func)
  print(jax.make_jaxpr(grad_func)(x))
  res_grad = grad_func(x)
  self.assertAllClose(14., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
what: pair
( 10.00 15.00 )
transforms: ['jvp', 'transpose'] what: pair
( 0.00 0.00 )"""
  else:
    expected = """
what: pair
( 10.00 15.00 )"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_jvp_float0(self):
  """jvp through id_tap with an integer input (float0 tangent)."""
  def f(x, yint):
    x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
    return x * yint

  primals = (2., 3)
  tangents = (0.2, np.zeros((), dtypes.float0))
  res = jax.jvp(f, primals, tangents)
  self.assertAllClose((6., 0.6), res)
def test_tap_grad_float0(self):
  """Gradient through id_print of a (float, int) pair."""
  def func(x, yint):
    x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
    return x * yint

  res_grad = jax.grad(func)(jnp.float32(5.), jnp.int32(2))
  self.assertAllClose(2., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
what: pair
( 5.00 2 )
transforms: ['jvp', 'transpose'] what: pair
( 2.00 False )"""
  else:
    expected = """
what: pair
( 5.00 2 )"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_grad_float0_result(self):
  # https://github.com/google/jax/issues/7340
  # x is a Tuple[f32[2], s32[3]]
  x = (np.array([.7, .8], dtype=np.float32),
       np.array([11, 12, 13], dtype=np.int32))

  def f_jax(x):
    x = hcb.id_print(x, result=x, output_stream=testing_stream)  # result= is important
    return (3. * x[0], x[1])

  def f_jax_vjp(x):
    _, pullback = jax.vjp(f_jax, x)
    cotangents = (np.ones(x[0].shape, dtype=x[0].dtype),
                  np.zeros(x[1].shape, dtype=dtypes.float0))
    (g,) = pullback(cotangents)
    return g

  g = f_jax_vjp(x)
  self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0])
  self.assertEqual(dtypes.float0, g[1].dtype)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )"""
  else:
    expected = """
( [0.70 0.80] [11 12 13] )"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_higher_order_grad_float0_result(self):
  # https://github.com/google/jax/issues/7340
  # x is a Tuple[f32[2], s32[3]]
  x = (np.array([.7, .8], dtype=np.float32),
       np.array([11, 12, 13], dtype=np.int32))

  def f_jax(x):
    x = hcb.id_print(x, result=x, output_stream=testing_stream)  # result= is important
    return (jnp.sin(x[0]), x[1])

  def wrap_vjp(f, args, res_f_of_args):
    # Given a function "f" and "args" return the f_vjp and args_vjp
    def make_ct(res):
      res_dtype = np.result_type(res)
      if res_dtype == dtypes.float0:
        # float0 results already are their own (trivial) cotangent.
        return res
      ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype)
      return np.ones(np.shape(res), dtype=ct_dtype)

    cts = tree_util.tree_map(make_ct, res_f_of_args)

    def f_vjp(args, cts):
      _, pullback = jax.vjp(f, *args)
      return pullback(cts)

    return (f_vjp, (args, cts))

  res = f_jax(x)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
  testing_stream.reset()

  # 1st order
  f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res)
  res_vjp1 = f_jax_vjp1(*args_vjp1)
  hcb.barrier_wait()
  if FLAGS.jax_host_callback_ad_transforms:
    expected = """
( [0.70 0.80] [11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00] [False False False] )"""
  else:
    expected = """
( [0.70 0.80] [11 12 13] )"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()

  # 2nd order
  f_jax_vjp2, args_vjp2 = wrap_vjp(f_jax_vjp1, args_vjp1, res_vjp1)
  res_vjp2 = f_jax_vjp2(*args_vjp2)

  # 3rd order
  f_jax_vjp3, args_vjp3 = wrap_vjp(f_jax_vjp2, args_vjp2, res_vjp2)
  _ = f_jax_vjp3(*args_vjp3)
def test_tap_vmap(self):
  """Checks the taps emitted by `fun1` when it is mapped with `jax.vmap`."""
  batched_args = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  jax.vmap(fun1)(batched_args)
  hcb.barrier_wait()
  expected = """
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
[ 8.00 10.00]
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
[24.00 30.00]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_vmap_not_batched(self):
  """Tap a tuple whose first element is unbatched and whose second is batched."""
  x = 3.

  def add_unmapped(y):
    # `x` is closed over (not mapped); `y` is the mapped argument.
    _, y = hcb.id_print((x, y), output_stream=testing_stream)
    return x + y

  batched = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  _ = jax.vmap(add_unmapped)(batched)
  hcb.barrier_wait()
  expected = """
transforms: [('batch', {'batch_dims': (None, 0)})]
( 3.00 [4.00 5.00] )"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
def test_tap_vmap_vmap(self):
  """A tap under two nested vmaps records two batch transforms."""
  # A 2D tensor with x[i, j] = i + j built with 2 nested vmaps.
  def add_and_print(x, y):
    # Named so that it does not shadow the builtin `sum`.
    return hcb.id_print(x + y, output_stream=testing_stream)

  def add_rows(xv, y):
    return jax.vmap(add_and_print, in_axes=(0, None))(xv, y)

  def add_all(xv, yv):
    return jax.vmap(add_rows, in_axes=(None, 0))(xv, yv)

  xv = jnp.arange(5, dtype=np.int32)
  yv = jnp.arange(3, dtype=np.int32)
  _ = add_all(xv, yv)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
def test_tap_vmap_while(self):
"""Vmap of while."""
# Taps before the loop, in the loop body, and after the loop, all under
# jit(vmap(...)).
def func(x):
# like max(x, 2)
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream)
x2 = lax.while_loop(
lambda x: x < 2, lambda x: hcb.id_print(
x + 1, where="body:x+1", output_stream=testing_stream), x1)
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream)
return res
inputs = np.arange(5, dtype=np.int32)
self.assertAllClose(
np.array([2, 2, 2, 3, 4]),
jax.jit(jax.vmap(func))(inputs),
check_dtypes=False)
hcb.barrier_wait()
# The batched loop keeps iterating while any element satisfies x < 2. The
# body tap therefore shows x + 1 for every element of the carried batch,
# including elements that are already done (e.g. [2 3 3 4 5] on the second
# iteration).
assertMultiLineStrippedEqual(
self, """
transforms: [('batch', {'batch_dims': (0,)})] where: before:x
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: after:x
[2 2 2 3 4]""", testing_stream.output)
def test_tap_vmap_while_tap_cond(self):
"""Vmap of while, with a tap in the conditional."""
def func(x):
# like max(x, 2)
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c",
output_stream=testing_stream),
lambda x: hcb.id_print(x + 1, where="w_b",
output_stream=testing_stream),
x1)
res = hcb.id_print(x2, where="3", output_stream=testing_stream)
return res
inputs = np.arange(5, dtype=np.int32)
res = jax.jit(jax.vmap(func))(inputs)
hcb.barrier_wait()
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
# The condition tap ("w_c") fires once per iteration, plus once more for the
# final check where every element is False; the body tap ("w_b") fires once
# per iteration.
assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)})] where: 1
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True True False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[False False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]""", testing_stream.output)
def test_tap_transforms_old_doc(self):
# Documentation examples under the legacy behavior, where
# jax_host_callback_ad_transforms is set and taps record the AD transforms
# ('jvp', 'transpose') applied to them.
if not FLAGS.jax_host_callback_ad_transforms:
raise unittest.SkipTest("disabled for new behavior")
# Examples from the documentation
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
_, y = hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
return y * x
print(f"impl = {power3(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap = {jax.vmap(power3)(np.arange(3.))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [0. 1. 2.] [0. 1. 4.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# Under jvp the tap sees both primals and tangents.
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
transforms: ['jvp'] what: x,x^2
( ( 3. 9. ) ( 0.1 0.6 ) )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# Under grad the tap fires on the forward pass and again, transposed, on the
# backward pass.
print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
transforms: ['jvp', 'transpose'] what: x,x^2
( 0. 3. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )
transforms: ['jvp', 'transpose', ('batch', {'batch_dims': (None, 0)})] what: x,x^2
( 0. [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_transforms_doc(self):
# Examples from the documentation
# New behavior (jax_host_callback_ad_transforms unset): taps see only
# primal values under AD. Tangents and cotangents are printed explicitly via
# custom_jvp / custom_vjp helpers.
if FLAGS.jax_host_callback_ad_transforms:
raise unittest.SkipTest("disabled for old behavior")
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
return y * x
print(f"impl = {power3(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# Under jvp only the primals are tapped.
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# A custom_jvp helper whose jvp rule taps the tangents.
@jax.custom_jvp
def print_tangents(arg):
return None
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
return primals, tangents
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
print_tangents((x, y))
return y * x
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: tangents
( 0.1 0.6 )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"grad = {jax.grad(power3)(3.)}")
hcb.barrier_wait()
# Only the primals by default
expected = """
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# A custom_vjp identity whose backward rule taps the cotangents.
@jax.custom_vjp
def print_cotangents(arg):
# Must return the argument for which we want the cotangent.
return arg
# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
# Must use the output of print_cotangents
(x1, y1) = print_cotangents((x, y))
return y1 * x1
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: cotangents
( 9. 3. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# TODO: grad of grad
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
hcb.barrier_wait()
expected = """
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
( [2. 3.] [4. 9.] )
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
( [4. 9.] [2. 3.] )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
# Under ad_checkpoint.checkpoint, the tap in the checkpointed (inner) power3
# fires twice: once on the forward pass ((3, 9)) and again when its
# residuals are rematerialized for the backward pass. The outer power3 tap
# ((27, 729)) fires once.
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
hcb.barrier_wait()
expected = """
what: x,x^2
( 3. 9. )
what: x,x^2
( 27. 729. )
what: x,x^2
( 3. 9. )"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
testing_stream.reset()
def test_tap_pmap(self):
  """pmap of a tap with tap_with_device=True; each device reports its shard."""
  if len(local_devices()) < 2:
    raise SkipTest("test requires at least 2 devices")

  def cube(x):
    x_sq = x * x
    # Print both 'x' and 'x^2'. Must pack as a tuple.
    _, x_sq = hcb.id_print((x, x_sq),
                           what="x,x^2",
                           output_stream=testing_stream,
                           tap_with_device=True)
    return x_sq * x

  xv = np.array([3, 4], dtype=np.int32)
  res = jax.pmap(cube, devices=local_devices())(xv)
  hcb.barrier_wait()
  self.assertAllClose(xv * xv * xv, res, check_dtypes=False)
  # Assertion text is for 2 devices (also works for 1 device)
  expected = """
device: cpu:0 what: x,x^2
( 3 9 )
device: cpu:1 what: x,x^2
( 4 16 )"""
  assertMultiDeviceOutputEqual(self, expected)
def test_tap_pmap_vmap(self):
# pmap over devices of a vmap over each row; every device taps its own row.
# A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices())
shape = (nr_devices, 3)
# NOTE(review): dtype=np.int32 sets the dtype of the index arrays passed to
# the lambda; `10. * i` then yields floating-point entries (the expected
# output below prints floats).
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
dtype=np.int32)
def fun1(x, do_print=False): # x: floating point (see note above), not i32
return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)
pmap_vmap_fun1 = jax.pmap(
jax.vmap(partial(fun1, do_print=True)), devices=local_devices())
res = pmap_vmap_fun1(matrix)
hcb.barrier_wait()
expected_res = jax.pmap(
jax.vmap(partial(fun1, do_print=False)), devices=local_devices())(
matrix)
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[0.00 2.00 4.00]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[20.00 22.00 24.00]""")
def test_tap_pmap_pmap_vmap(self):
  """Nested pmap of a vmap, tapping with the device."""
  # A matrix M[ijk] = i * 100 + j * 10 + k
  nr_devices = len(local_devices())
  if nr_devices % 2 != 0:
    raise SkipTest("test works only on even number of devices")
  shape = (2, nr_devices // 2, 3)
  matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
                           dtype=np.float32)

  def fun1(x, do_print=False):  # x: f32
    y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)
    return y ** 2

  def make_nested(do_print):
    # pmap(pmap(vmap(fun1))) spread over all local devices.
    return jax.pmap(jax.pmap(jax.vmap(partial(fun1, do_print=do_print))),
                    devices=local_devices())

  res = make_nested(True)(matrix)
  hcb.barrier_wait()
  expected_res = make_nested(False)(matrix)
  self.assertAllClose(expected_res, res, check_dtypes=False)
  # Assertion text is for 2 devices (also works for 1 device)
  expected = """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[0.00 2.00 4.00]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[200.00 202.00 204.00]"""
  assertMultiDeviceOutputEqual(self, expected)
@ignore_jit_of_pmap_warning()
def test_tap_pmap_pmap_extra(self):
"""pmap of a pmap surrounded by extra code."""
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(local_devices())
if nr_devices != 2:
raise SkipTest("test works only on 2 devices")
shape = (2, 1, 3)
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
dtype=np.float32)
def fun(xv, do_print=False):
# This will be printed on all devices, with shape [1, 3]
xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
# This will be printed on all devices, with shape [1, 3]
return maybe_print(do_print, res + 1., "after", tap_with_device=True)
res = jax.pmap(partial(fun, do_print=True))(matrix)
self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
hcb.barrier_wait()
# Assertion text is for 2 devices (also works for 1 device)
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[1.00 2.00 3.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[3.00 5.00 7.00]]
device: cpu:1 what: before
[[101.00 102.00 103.00]]
device: cpu:1 what: inside
[202.00 204.00 206.00]
device: cpu:1 what: after
[[203.00 205.00 207.00]]""")
def test_tap_jvp_pmap_vmap(self):
# jvp of pmap of vmap. Only the legacy behavior
# (jax_host_callback_ad_transforms) shows the 'jvp' transform and the tangents.
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(local_devices())
shape = (nr_devices, 2, 3)
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
dtype=np.float32)
def fun(xv, do_print=False):
# x: f32[3]
return jax.jvp(jax.pmap(jax.vmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True))),
(xv,), (.1 * jnp.ones_like(xv),))
res = fun(matrix, do_print=True)
hcb.barrier_wait()
expected_res = fun(matrix, do_print=False)
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
# Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :]
if FLAGS.jax_host_callback_ad_transforms:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[ 0.00 2.00 4.00]
[20.00 22.00 24.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
( [[200.00 202.00 204.00]
[220.00 222.00 224.00]] [[0.20 0.20 0.20]
[0.20 0.20 0.20]] )""")
else:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 0.00 2.00 4.00]
[20.00 22.00 24.00]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[200.00 202.00 204.00]
[220.00 222.00 224.00]]""")
def test_tap_vmap_pmap(self):
# vmap over the leading axis of a pmap over devices.
# A matrix M[ijk] = i * 100 + j * 10 + k
nr_devices = len(local_devices())
shape = (2, nr_devices, 3)
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
dtype=np.float32)
def fun(xv, do_print=False):
# x: f32[3]
return jax.vmap(jax.pmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)))(xv)
res = fun(matrix, do_print=True)
hcb.barrier_wait()
expected_res = fun(matrix, do_print=False)
self.assertAllClose(expected_res, res, check_dtypes=False)
# Assertion text is for 2 devices (also works for 1 device)
# Device 0 executes the vmap-batched pmap body for matrix[:, 0, :] (no jvp
# is involved in this test).
assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 0.00 2.00 4.00]
[200.00 202.00 204.00]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 20.00 22.00 24.00]
[220.00 222.00 224.00]]""")
@ignore_jit_of_pmap_warning()
def test_tap_jit_pmap_extra(self):
"""jit of a pmap surrounded by extra code."""
# A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices())
assert nr_devices in (1, 2)
shape = (nr_devices, 3)
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
dtype=np.float32)
def fun(xv, do_print=False):
# This will be printed on all devices with shape (nr_devices, 3)
xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
# This will be printed on all devices with shape (nr_devices, 3)
return maybe_print(do_print, res + 1., "after", tap_with_device=True)
res = jax.jit(partial(fun, do_print=True))(matrix)
self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
hcb.barrier_wait()
# The code outside the pmap runs on every device with the full matrix; the
# "inside" tap sees only that device's row.
if len(local_devices()) == 2:
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[ 1.00 2.00 3.00]
[11.00 12.00 13.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[ 3.00 5.00 7.00]
[23.00 25.00 27.00]]
device: cpu:1 what: before
[[ 1.00 2.00 3.00]
[11.00 12.00 13.00]]
device: cpu:1 what: inside
[22.00 24.00 26.00]
device: cpu:1 what: after
[[ 3.00 5.00 7.00]
[23.00 25.00 27.00]]""")
else:
assert len(local_devices()) == 1
assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[1.00 2.00 3.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[3.00 5.00 7.00]]""")
@unittest.skip("cond of pmap does not work in JAX. Issue #5178.")
def test_tap_cond_pmap(self):
# Placeholder for a pmap inside lax.cond; the expected output is not yet
# known ("TBD") because the test is skipped.
# A matrix M[ij] = i * 10 + j
nr_devices = len(local_devices())
shape = (nr_devices, 3)
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
dtype=np.float32)
def fun1(x, do_print=False):
return maybe_print(do_print, x * 2., "x * 2")
def fun2(cond, xv, do_print=False):
return lax.cond(cond, jax.pmap(partial(fun1, do_print=do_print)),
lambda xv: xv, xv)
res = fun2(True, matrix)
self.assertAllClose(fun2(True, matrix, do_print=False), res, check_dtypes=False)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
TBD""", testing_stream.output)
@jtu.skip_on_devices("cpu", "gpu")
# TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall
def test_tap_pjit(self):
# A tap inside a pjit-ed function whose input is sharded along axis "d" of a
# device mesh.
devices = np.array(local_devices())
nr_devices = len(devices)
if nr_devices < 2:
raise SkipTest("test requires at least 2 devices")
print(f"test_tap_pjit is running on devices {devices}.")
# x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
# y: i32[3, 4]
x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
y = jnp.ones((3, 4), np.int32)
@partial(jax.named_call, name="fun1") # for xprof debugging
def fun1(x, do_print=False):
z = jnp.dot(x, y)
return maybe_print(do_print, z, "z", tap_with_device=True)
res0 = fun1(x, do_print=False)
pjit_fun1 = pjit.pjit(
partial(fun1, do_print=True),
in_axis_resources=(P("d"),),
out_axis_resources=P("d"))
with maps.mesh(devices, ["d"]):
# Print the internal IR
helper_log_ir(
f"{self._testMethodName}.pjit",
pjit_fun1,
x,
num_partitions=nr_devices)
res = pjit_fun1(x)
self.assertAllClose(res0, res)
hcb.barrier_wait("before check")
# Assertion text is for 2 devices (also works for 1 device)
# Note that a single call is made.
assertMultiDeviceOutputEqual(
self, """
device: cpu:0 what: z
[[ 3 3 3 3]
[33 33 33 33]]""")
def test_tap_scan_custom_jvp(self):
"""custom JVP, inside scan.
This exercises the custom_jvp_call_jaxpr primitives."""
@jax.custom_jvp
def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
primal_out = f(x)
tangent_out = 3. * x * hcb.id_print(x_dot, output_stream=testing_stream, what="x_dot")
return primal_out, tangent_out
def g(x):
# Sum f(x_i)
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
np.full(x.shape[1:], 0.), # Like x w/o leading dim
x)[0]
arg = np.full((2,), 0.7)
self.assertAllClose(0.7 * 0.7 * 2, g(arg))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7""", testing_stream.output)
testing_stream.reset()
self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False)
hcb.barrier_wait()
# The x_dot tap sits on the tangent path, so under grad it is transposed
# and reports the cotangent 3. * 0.7 * 1. = 2.1 once per scan step.
self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7
transforms: ['transpose'] what: x_dot
2.1
transforms: ['transpose'] what: x_dot
2.1""", testing_stream.output)
def test_tap_scan_custom_vjp(self):
"""custom VJP, inside scan.
This exercises the custom_vjp_call_jaxpr primitives."""
@jax.custom_vjp
def f(x):
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
# f_fwd: a -> (b, residual)
def f_fwd(x):
return f(x), 3. * x
# f_bwd: (residual, CT b) -> [CT a]
def f_bwd(residual, ct_b):
return residual * hcb.id_print(ct_b, output_stream=testing_stream, what="ct_b"),
f.defvjp(f_fwd, f_bwd)
def g(x):
# Sum f(x_i)
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
np.full(x.shape[1:], 0.), # Like x w/o leading dim
x)[0]
arg = np.full((2,), 0.7)
self.assertAllClose(0.7 * 0.7 * 2, g(arg))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7""", testing_stream.output)
testing_stream.reset()
self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False)
hcb.barrier_wait()
# The ct_b tap is inside the user-defined bwd rule, so it reports the raw
# cotangent (1.) with no transforms, once per scan step.
self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7
what: ct_b
1.
what: ct_b
1.""", testing_stream.output)
def test_tap_mask(self):
# A tap under jax.mask, alone and combined with vmap, jvp and jit. The tap
# output includes the logical shape values alongside the padded arrays.
@partial(jax.mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
three_x = hcb.id_print((x, 2 * x), result=3 * x, what="x",
output_stream=testing_stream)
return jnp.sum(three_x)
x = np.arange(5.)
self.assertAllClose(9., padded_sum([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# With VMAP
xv = np.arange(10.).reshape((2, 5)) # logical_shape = 5
self.assertAllClose(
np.array([9., 78.]),
# batch_size = 2, n=3 and 4 for the two elements
jax.vmap(padded_sum)([xv],
dict(n=np.array([3., 4.]))))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), ('batch', {'batch_dims': (0, 0, 0, 0)})] what: x
( ( [[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]]
[[ 0. 2. 4. 6. 8.]
[10. 12. 14. 16. 18.]] )
( ( [3. 4.] ) ( [3. 4.] ) ) )""", testing_stream.output)
testing_stream.reset()
# With JVP
self.assertAllClose((9., 0.9),
jax.jvp(lambda arg: padded_sum([arg], dict(n=3)),
(x,), (x * 0.1,)))
hcb.barrier_wait()
# Only the legacy behavior (jax_host_callback_ad_transforms) shows the 'jvp'
# transform and the tangents.
if FLAGS.jax_host_callback_ad_transforms:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5}), 'jvp'] what: x
( ( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )
( ( [0. 0.1 0.2 0.3 0.4] [0. 0.2 0.4 0.6 0.8] ) ( ( False ) ( False ) ) ) )""",
testing_stream.output)
else:
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
testing_stream.reset()
# Now with JIT
self.assertAllClose(9., jax.jit(padded_sum)([x], dict(n=3)))
hcb.barrier_wait()
self.assertMultiLineStrippedEqual("""
transforms: [('mask', {'logical_shapes': 5})] what: x
( ( [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.] ) ( ( 3 ) ( 3 ) ) )""",
testing_stream.output)
def test_tap_callback_delay(self):
  """Run a jitted chain of taps while each callback is delayed by 1 second."""
  hcb.callback_extra = lambda dev: time.sleep(1)

  def func(x):
    for factor in range(5):
      x = hcb.id_print(x * factor, what="x times i")
    return x

  arg = np.arange(6, dtype=np.float32).reshape((2, 3))
  jax.jit(func)(arg)
def test_tap_callback_delay_barrier(self):
  """barrier_wait returns only after slow (2s) callbacks have written output."""
  hcb.callback_extra = lambda dev: time.sleep(2)

  def func(x):
    for i in range(1, 4):
      x = hcb.id_print(x * i, what=f"x times {i}", output_stream=testing_stream)
    return x

  def run_jitted():
    jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))

  expected = """
what: x times 1
[[0. 1. 2.]
[3. 4. 5.]]
what: x times 2
[[ 0. 2. 4.]
[ 6. 8. 10.]]
what: x times 3
[[ 0. 6. 12.]
[18. 24. 30.]]"""

  run_jitted()
  # Wait for the results
  hcb.barrier_wait("first")
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
  testing_stream.reset()

  # Call again
  run_jitted()
  hcb.barrier_wait("second")
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_error_bad_consumer_id(self):
  """Registering an outfeed with the reserved consumer ID 0 is rejected.

  The error must come from the outfeed receiver runtime."""
  if not hcb._use_outfeed(jtu.device_under_test()):
    raise SkipTest("test works only for outfeed")
  builder = xla_client.XlaBuilder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
  receiver = hcb._callback_handler_data.receiver
  with self.assertRaisesRegex(RuntimeError,
                              "Consumer ID cannot be a reserved value: 0"):
    operands = [xops.Constant(builder, np.zeros((2, 3), dtype=np.float32))]
    receiver.add_outfeed(builder, token, 0, operands)
def test_tap_error_different_shapes(self):
  """Registering a different shape for an existing consumer ID fails."""
  if not hcb._use_outfeed(jtu.device_under_test()):
    raise SkipTest("test works only for outfeed")
  builder = xla_client.XlaBuilder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  hcb._initialize_outfeed_receiver()  # Needed if this is the sole test

  def add_zeros_outfeed(shape, dtype):
    # Register an outfeed of zeros with the given shape/dtype for consumer 123.
    hcb._callback_handler_data.receiver.add_outfeed(
        builder, token, 123,
        [xops.Constant(builder, np.zeros(shape, dtype=dtype))])

  add_zeros_outfeed((2, 3), np.float32)
  # Both a dtype change and a shape change must be rejected.
  for shape, dtype in [((2, 3), np.int32), ((2,), np.float32)]:
    with self.assertRaisesRegex(
        RuntimeError, ".*does not match previous shape element_type.*"):
      add_zeros_outfeed(shape, dtype)
def test_tap_id_tap_removed_kwargs(self):
  """Passing extra keyword arguments to id_tap raises a TypeError."""
  def tap_func(x, transforms, y):
    del x, transforms, y  # unused

  expected_msg = r"Support for \*\*kwargs in ``id_tap``"
  with self.assertRaisesRegex(TypeError, expected_msg):
    hcb.id_tap(tap_func, 1, y=2)
def test_tap_odeint(self):
  """grad through odeint with a tap in the ODE right-hand side must not fail."""
  # TODO: find a smaller repro for bug #4015
  # Seems to be xla_call(scan(xla_call)), all under grad.
  from jax.experimental.ode import odeint

  def rhs(x, t, k):
    tapped = hcb.id_print(x)
    return -k * tapped

  def loss(k=1.0):
    times = jnp.linspace(0, 0.001, num=2)
    return odeint(rhs, 1.0, times, k)[-1]

  jax.grad(loss)(1.0)  # should not fail
def test_tap_remat_0(self):
  """A tap inside jax.remat within a fori_loop fires once per iteration."""
  def body(i, k):
    printed = hcb.id_print(k + i, output_stream=testing_stream)
    return k * printed

  result = lax.fori_loop(0, 2, jax.remat(body), 3)
  print(result)
  hcb.barrier_wait()
  expected = """
3
10"""
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_use_remat={use_remat}_{grad_func}_use_result={use_result}",
use_result=use_result, use_remat=use_remat, grad_func=grad_func)
for use_result in [True, False]
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "new", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
# Taps under grad / value_and_grad, combined with no remat, the old jax.remat
# ("old") or ad_checkpoint.checkpoint ("new"). use_result controls whether
# the tap's output feeds the computation.
def f(x):
id_print_result = hcb.id_print(x, output_stream=testing_stream)
if use_result:
x = id_print_result
return 3. * x
grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
if use_remat == "old":
trans_f = jax.remat(f)
elif use_remat == "new":
trans_f = ad_checkpoint.checkpoint(f)
else:
assert use_remat == "none"
trans_f = f
print(jax.make_jaxpr(grad_f(trans_f))(2.))
grad_f(trans_f)(2.)
hcb.barrier_wait()
# Expected tap output for each combination of use_remat, use_result, and the
# jax_host_callback_ad_transforms flag. With remat, the forward computation
# runs twice (forward pass and rematerialization), so "2." appears twice.
if use_remat == "none":
if use_result:
if FLAGS.jax_host_callback_ad_transforms:
expected = """
2.
transforms: ['jvp', 'transpose']
3."""
else:
# GOOD: whether or not we use_result, in absence of
# jax_host_callback_ad_transforms we get the same callback.
expected = "2."
else:
expected = "2."
else: # use_remat
if use_result:
if FLAGS.jax_host_callback_ad_transforms:
expected = """
2.
2.
transforms: ['jvp', 'transpose']
3."""
else:
expected = """
2.
2."""
else:
if use_remat == "old":
# TODO: we should see two callbacks
expected = ""
else:
# Good: we see two callbacks, whether or not we use the result.
expected = """
2.
2."""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_named_call(self):
  """Taps inside a jax.named_call-wrapped scan body."""
  def tap_scalar(init, do_print=False):
    @partial(jax.named_call, name="step")
    def step(acc, step_nr):
      new_acc = acc + step_nr
      maybe_print(do_print, step_nr, what="step_nr")
      return new_acc, None

    return lax.scan(step, init, np.arange(2))

  without_print = tap_scalar(3., do_print=False)
  with_print = tap_scalar(3., do_print=True)
  self.assertAllClose(without_print, with_print)
  hcb.barrier_wait()
  expected = """
what: step_nr
0
what: step_nr
1"""
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
class HostCallbackCallTest(jtu.JaxTestCase):
"""Tests for hcb.call"""
def setUp(self):
  """Skip on multi-GPU and give each test a fresh, named testing stream."""
  super().setUp()
  on_multi_gpu = jtu.device_under_test() == "gpu" and jax.device_count() > 1
  if on_multi_gpu:
    raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
  testing_stream.reset()
  testing_stream._test_method_name = self._testMethodName
def tearDown(self) -> None:
# Wait for all outstanding host callbacks before the base-class teardown,
# presumably so their output does not spill into the next test.
hcb.barrier_wait("HostCallbackCallTest.tearDown")
super().tearDown()
def call_log_testing_stream(self, func, arg, *, result_shape, name=""):
"""Call `func` and log inputs and outputs to the testing stream"""
# Wraps `func` in a host function that writes "Call {name}(...)" before the
# call and " = ..." after it, then invokes it with hcb.call.
def call_log(arg):
def val2str(v):
# NOTE(review): this formats the closed-over `arg`, not `v`, so the
# " = ..." line repeats the input instead of showing the result.
# test_call_no_result (and possibly tests outside this view) expect
# this output, so any fix must update those expectations too.
return np.array2string(np.array(arg))
testing_stream.write(f"Call {name}({val2str(arg)})\n")
res = func(arg)
testing_stream.write(f" = {val2str(res)}\n")
return res
return hcb.call(call_log, arg, result_shape=result_shape)
def test_call_simple(self):
  """Call a host function with hcb.call and use its result."""
  def double_on_host(x):
    return 2 * x

  def fun(x):
    doubled = hcb.call(double_on_host, x + 1, result_shape=x)
    return 3 * (1 + doubled)

  arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
  expected = 3 * (1 + 2 * (arg + 1))
  self.assertAllClose(expected, fun(arg))
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(testcase_name=f"_{np.dtype(dtype).name}", dtype=dtype)
        for dtype in jtu.dtypes.all
        if dtype != np.bool_))
def test_call_types(self, dtype=np.float64):
  """hcb.call preserves every non-bool dtype."""
  def f_outside(x):
    # Use x + x to ensure that the result type is the same
    return x + x

  arg = np.arange(24, dtype=dtype).reshape((2, 3, 4))
  expected = arg + arg + arg + arg
  actual = hcb.call(f_outside, arg + arg, result_shape=arg)
  self.assertAllClose(expected, actual, check_dtypes=True)
def test_call_types_bool(self, dtype=np.float64):
  """hcb.call handles boolean arrays."""
  del dtype  # unused
  arg = self.rng().choice(a=[True, False], size=(2, 3, 4))
  expected = np.invert(arg)
  self.assertAllClose(expected, hcb.call(np.invert, arg, result_shape=arg))
def test_call_tuples(self):
  """hcb.call with a tuple argument and a tuple result."""
  def swap(args):
    first, second = args
    return second, first  # Swap the tuple

  def fun(x):
    a, b = hcb.call(swap, (x, x + 1), result_shape=(x, x))
    return 2 * a + 3 * b

  arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
  self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg))
def test_call_empty_arg(self):
  """Call with empty array."""
  result = np.ones((2,), dtype=np.float32)
  result_struct = jax.ShapeDtypeStruct(result.shape, result.dtype)

  def fun(x):
    from_host = hcb.call(lambda _: result, (), result_shape=result_struct)
    return x + from_host

  self.assertAllClose(2. + result, fun(2.))
def test_call_empty_result(self):
  """Call returning empty array."""
  result_shape = (2, 0)

  def f_outside(_):
    return np.ones(result_shape, dtype=np.float32)

  def fun(x):
    empty = hcb.call(f_outside, 1.,
                     result_shape=jax.ShapeDtypeStruct(result_shape, np.float32))
    return x + empty

  self.assertAllClose(f_outside(0.), fun(2.))
def test_call_empty_result_inside_pytree(self):
  """Call returning a tuple with an empty array and a non-empty one."""
  result_shape_0 = (2, 0)
  result_shape_2 = (0,)
  all_shapes = (result_shape_0, (1,), result_shape_2)

  def f_outside(_):
    return tuple(np.ones(s, dtype=np.float32) for s in all_shapes)

  def fun(x):
    structs = tuple(jax.ShapeDtypeStruct(s, np.float32) for s in all_shapes)
    res = hcb.call(f_outside, 1., result_shape=structs)
    self.assertEqual(result_shape_0, res[0].shape)
    self.assertEqual(result_shape_2, res[2].shape)
    return x + res[1]

  self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.))
def test_call_empty_result_all_pytree(self):
  """Call returning a tuple of empty arrays."""
  result_shape = (2, 0)

  def f_outside(_):
    return (np.ones(result_shape, dtype=np.float32),
            np.ones(result_shape, dtype=np.float32))

  def fun(x):
    struct = jax.ShapeDtypeStruct(result_shape, np.float32)
    first, second = hcb.call(f_outside, 1., result_shape=(struct, struct))
    return x + first + second

  expected = np.ones(result_shape, dtype=np.float32)
  self.assertAllClose(expected, fun(2.))
def test_call_no_result(self):
# A host call whose function returns None, declared with result_shape=None.
def f_outside(arg):
self.call_log_testing_stream(lambda x: None, arg,
result_shape=None,
name="outside")
return arg
self.assertAllClose((3., 4.), f_outside((3., 4.)))
hcb.barrier_wait()
# NOTE(review): the " = [3. 4.]" line echoes the input, not the callback's
# None result, because call_log_testing_stream's val2str formats `arg`
# rather than its parameter. Update this expectation if that helper is fixed.
expected = """
Call outside([3. 4.])
= [3. 4.]"""
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_call_cond(self):
  """hcb.call inside lax.cond inside fori_loop matches the pure version under jit."""
  def f_outside(args):
    x, y = args
    return x * y

  def loop(x, use_outside=True):
    def multiply(acc, i):
      if use_outside:
        return hcb.call(f_outside, (acc, i), result_shape=acc)
      return f_outside((acc, i))

    def body(i, acc):
      return lax.cond(i % 2 == 1,
                      lambda _: multiply(acc, i),
                      lambda _: acc,
                      None)

    return lax.fori_loop(0, 18, body, x)

  res_inside = loop(1.2, use_outside=False)
  self.assertAllClose(res_inside, jax.jit(loop)(1.2))
def test_call_jit_scan_call(self):
  """hcb.call inside lax.scan under jit matches the inline computation."""
  def f_outside(x):
    return x

  def loop(x, use_outside=True):
    def body(carry, i):
      increment = (hcb.call(f_outside, i, result_shape=i)
                   if use_outside else i)
      return carry + increment, None

    return lax.scan(body, 0, x)

  xs = np.arange(5, dtype=np.int32)
  res_outside = jax.jit(partial(loop, use_outside=True))(xs)
  self.assertAllClose(res_outside, loop(xs, use_outside=False))
def test_call_doc_example1(self):
  """Documentation example: compute eigenvalues on the host with hcb.call."""
  def host_eig(x):
    return np.linalg.eigvals(x)

  m = np.ones((2, 5, 4, 4), dtype=np.float32)

  def fun(m):
    # eigvals drops the last axis of the (..., n, n) input.
    eig_spec = jax.ShapeDtypeStruct(m.shape[:-1], m.dtype)
    return hcb.call(host_eig, m, result_shape=eig_spec)

  self.assertAllClose(np.linalg.eigvals(m), fun(m))
def test_call_doc_example_hlo(self):
  """Examples from the documentation: print the optimized HLO of hcb.call."""
  m = np.ones((2,), np.float32)

  def fun1(m):
    # The callback must return cos(x). The previous `lambda x: np.cos`
    # returned the ufunc object itself, which cannot match `result_shape=m`.
    return jnp.sin(hcb.call(lambda x: np.cos(x),
                            jnp.cos(m),
                            result_shape=m))

  helper_print_optimized_hlo(fun1, m)

  def fun2(m):
    # NOTE(review): the callback returns None while result_shape is the empty
    # tuple; confirm these pytrees are meant to agree if this is ever executed.
    x = hcb.call(lambda x: None, 2, result_shape=())
    return x

  helper_print_optimized_hlo(fun2, m)
def test_call_with_device(self):
  """With call_with_device=True the callback receives the device as a kwarg."""
  def callback_func(x, device=None):
    testing_stream.write(f"device: {device}\n Called with {x}")
    return x

  def func(x):
    return hcb.call(callback_func, x, result_shape=x, call_with_device=True)

  self.assertEqual(3., func(3.))
  expected_output = """
device: cpu:0
Called with 3.00"""
  assertMultiDeviceOutputEqual(self, expected_output)
def test_call_pmap(self):
  """hcb.call under pmap; works with either 1 or 2 local devices."""
  def callback_func(x, device=None):
    testing_stream.write(f"device: {device}\n Called with {x}")
    return x * np.array(3, np.int32)

  def fun(x):  # x: i32
    doubled = x * 2
    return hcb.call(callback_func, doubled,
                    result_shape=x,
                    call_with_device=True)

  xs = jnp.arange(len(local_devices()), dtype=jnp.int32)
  res = jax.pmap(fun)(xs)
  # (x * 2) * 3 on the host.
  self.assertAllClose(jax.pmap(lambda x: x * 6)(xs), res)
  # The expected text is written for 2 devices; it also matches with 1.
  expected_output = """
device: cpu:0
Called with 0
device: cpu:1
Called with 2"""
  assertMultiDeviceOutputEqual(self, expected_output)
def test_call_vmap(self):
  """vmap over hcb.call is rejected; only id_tap has a batching rule."""
  def f_outside(x):
    return x

  def fun(x):
    return hcb.call(f_outside, x, result_shape=x)

  batched_fun = jax.vmap(fun)
  with self.assertRaisesRegex(NotImplementedError,
      "batching rules are implemented only for id_tap, not for call"):
    batched_fun(np.ones((2, 3)))
@jtu.skip_on_devices("cpu", "gpu")
# TODO(necula): file XLA:GPU bug for the 'Sharding' CustomCall
def test_call_pjit(self):
  """hcb.call with call_with_device=True inside a pjit-ed function."""
  devices = np.array(local_devices())
  nr_devices = len(devices)
  if nr_devices < 2:
    raise SkipTest("test requires at least 2 devices")

  print(f"test_call_pjit is running on devices {devices}.")
  # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
  x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
  # y: i32[3, 4]
  y = jnp.ones((3, 4), np.int32)
  five = np.array(5, np.int32)

  def callback_x5_func(x, device=None):
    testing_stream.write(f"device: {device}\n Called with {x}")
    return x * five

  def fun(x):
    xy = jnp.dot(x, y)
    return hcb.call(callback_x5_func, xy,
                    result_shape=xy, call_with_device=True)

  pjit_fun = pjit.pjit(fun,
                       in_axis_resources=(P("d"),),
                       out_axis_resources=P("d"))
  with maps.mesh(devices, ["d"]):
    # Log the internal IR before executing.
    helper_log_ir(f"{self._testMethodName}.pjit", pjit_fun, x,
                  num_partitions=nr_devices)
    res = pjit_fun(x)

  self.assertAllClose(jnp.dot(x, y) * five, res, check_dtypes=False)

  hcb.barrier_wait("before assertion")
  # The expected text is written for 2 devices (also works for 1 device).
  expected_output = """
device: cpu:0
Called with [[ 3 3 3 3]
[33 33 33 33]]"""
  assertMultiDeviceOutputEqual(self, expected_output)
def test_call_error_bad_result_shape(self):
  """result_shape values without shape/dtype are rejected with ValueError."""
  for bad_result_shape in ("string", lambda x: x):
    with self.assertRaisesRegex(
        ValueError,
        "The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"):
      hcb.call(lambda x: x, 3., result_shape=bad_result_shape)
  hcb.barrier_wait("wait for error")
def helper_check_callback_errors(self, thunk: Callable,
                                 expected_exc_txt: str):
  """Runs `thunk()` and checks that callback errors are reported.

  Skips on CPU and TPU, whose runtimes do not report the error usefully.
  On GPU, `thunk()` must raise a RuntimeError about the infeed shape mismatch.
  Then `hcb.barrier_wait` must raise a CallbackException whose message
  contains `expected_exc_txt` as the last exception.
  """
  platform = jtu.device_under_test()
  skip_reasons = {
      # On CPU the runtime crashes, and the tests are all aborted.
      "cpu": "TODO: CPU runtime crashes on unexpected infeed",
      # On TPU no error is reported at all.
      "tpu": "TODO: TPU runtime does not check infeed, and just computes with garbage",
  }
  if platform in skip_reasons:
    raise SkipTest(skip_reasons[platform])

  if platform == "gpu":
    # On GPU the error comes back to Python.
    with self.assertRaisesRegex(
        RuntimeError,
        "RET_CHECK failure .* Mismatch between infeed source buffer shape s8.12345."):
      thunk()

  # The callback failure is also reported by barrier_wait at the end of the
  # test; run a barrier_wait now to consume that error.
  last_exception_re = re.compile(
      "There were exceptions during callback processing.*Last one was:.*" +
      expected_exc_txt,
      re.DOTALL)
  with self.assertRaisesRegex(hcb.CallbackException, last_exception_re):
    hcb.barrier_wait("Waiting for error")
def test_call_error_callback_throws_exception(self):
  """A Python exception raised in the callback is reported back."""
  def f_outside(x):
    raise ValueError("user exception")

  def fun(x):
    return hcb.call(f_outside, x, result_shape=x)

  self.helper_check_callback_errors(partial(fun, 3.),
                                    "ValueError: user exception")
def test_call_error_callback_returns_unexpected_shape(self):
  """A callback result whose pytree differs from result_shape is an error."""
  def fun(x):
    # The callback returns a pair, but result_shape describes a single value.
    return hcb.call(lambda x: (x, x), x, result_shape=x)

  self.helper_check_callback_errors(
      partial(fun, 3.),
      "Callback func .* should have returned a result with pytree")
def test_call_error_then_compute(self):
  """Computation on device continues after the callback raises."""
  def f_outside(x):
    raise ValueError("user exception")

  def fun(x):
    return hcb.call(f_outside, x, result_shape=x)

  arg = np.arange(3, dtype=np.int32)

  def compute_and_check():
    self.assertAllClose(arg, fun(arg))

  self.helper_check_callback_errors(compute_and_check,
                                    "ValueError: user exception")
def call_jax_other_device(jax_outside_fun, arg, *, device):
  """Runs `jax_outside_fun(arg)` on `device` via `hcb.call`, with reverse AD.

  The primal computation is dispatched through `hcb.call` to a host callback
  that places `arg` on `device` and runs a jitted `jax_outside_fun` there.
  Gradients come from a custom VJP whose backward pass is itself routed
  through `call_jax_other_device`, which is what allows higher-order
  reverse-mode differentiation.
  """
  jitted_outside_fun = jax.jit(jax_outside_fun)

  def run_on_device(x):
    return jitted_outside_fun(jax.device_put(x, device))

  @jax.custom_vjp
  def make_call(x):
    out_shape = jax.eval_shape(jax_outside_fun, x)
    return hcb.call(run_on_device, x, result_shape=out_shape)

  def make_call_fwd(x):
    # Residual = primal input. The primal output goes through `make_call`
    # (rather than a direct hcb.call) so it can be differentiated again.
    return make_call(x), x

  def make_call_bwd(residual, ct_out):
    def vjp_on_device(primal_and_ct):
      primal, ct = primal_and_ct
      _, pullback = jax.vjp(jax_outside_fun, primal)
      ct_in, = pullback(ct)
      return ct_in

    return (call_jax_other_device(vjp_on_device, (residual, ct_out),
                                  device=device),)

  make_call.defvjp(make_call_fwd, make_call_bwd)
  return make_call(arg)
class CallJaxTest(jtu.JaxTestCase):
  """Tests for `call_jax_other_device` (JAX code run on another device)."""

  def setUp(self):
    platform = jtu.device_under_test()
    if platform == "gpu" and jax.device_count() > 1:
      raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
    if platform == "cpu":
      cpu_devices = jax.devices("cpu")
      # The "outside" device must differ from the default CPU device.
      if len(cpu_devices) == 1:
        raise SkipTest("Test needs at least two devices. On CPU use XLA_FLAGS=--xla_force_host_platform_device_count=2")
      self.outside_device = cpu_devices[1]
    else:
      assert jax.devices("cpu")
      self.outside_device = jax.devices("cpu")[0]
    super().setUp()

  def test_jax_impl(self):
    def f_jax(x):
      return jnp.sin(x)

    def f_outside(x):
      return call_jax_other_device(f_jax, x, device=self.outside_device)

    expected = f_jax(3.)
    self.assertAllClose(expected, f_outside(3.))
    self.assertAllClose(expected, jax.jit(f_outside)(3.))

  def test_jax_impl_pytree(self):
    def f_jax(x):
      # x: dict(a=..., b=...); the output is a list of two elements.
      return [jnp.sin(x["a"]), jnp.sin(x["b"])]

    def f_outside(x):
      return call_jax_other_device(f_jax, x, device=self.outside_device)

    x = dict(a=3., b=4.)
    expected = f_jax(x)
    self.assertAllClose(expected, f_outside(x))

  def test_jax_grad(self):
    def f_jax(x):
      return 2. * jnp.sin(x)

    def f_outside(x):
      return 2. * call_jax_other_device(jnp.sin, x, device=self.outside_device)

    expected = jax.grad(f_jax)(3.)
    self.assertAllClose(expected, jax.grad(f_outside)(3.))

  def test_jax_grad_pytree(self):
    def f_jax(x):
      # x: dict(a=..., b=...); the output is a float.
      return 3. * jnp.sin(x["a"]) + jnp.sin(x["b"])

    def f_outside(x):
      return call_jax_other_device(f_jax, x, device=self.outside_device)

    x = dict(a=3., b=4.)
    expected = jax.grad(f_jax)(x)
    self.assertAllClose(expected, jax.grad(f_outside)(x))

  def test_jax_grad_of_grad(self):
    def f_jax(x):
      return 2. * x * x * x

    def cube(x):
      return x * x * x

    def f_outside(x):
      return 2. * call_jax_other_device(cube, x, device=self.outside_device)

    expected = jax.grad(jax.grad(f_jax))(5.)
    self.assertAllClose(expected, jax.grad(jax.grad(f_outside))(5.))
class OutfeedRewriterTest(jtu.JaxTestCase):
  """Tests for the jaxpr rewriting done by `hcb._rewrite_closed_jaxpr`.

  Each test traces a function with `jax.make_jaxpr` and rewrites the result.
  The expected jaxprs show extra inputs/outputs (presumably tokens) being
  threaded through every `outside_call`. They are kept as documentation of
  the intended rewrite, but are not compared by default (see `assertRewrite`).
  """

  def setUp(self):
    # host_callback is known to be broken on multi-GPU platforms (#6447).
    if jtu.device_under_test() == "gpu" and jax.device_count() > 1:
      raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
    super().setUp()

  def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                    has_input_token=True, has_output_token=True):
    """Check that the rewrite of func(*args) matches expected.

    Args:
      expected: the expected rewritten jaxpr, as printed text. Currently not
        compared, because the check below is commented out.
      func: the function to trace with `jax.make_jaxpr`.
      args: the positional arguments used for tracing.
      has_input_token: forwarded to `hcb._rewrite_closed_jaxpr`.
      has_output_token: forwarded to `hcb._rewrite_closed_jaxpr`.

    As written, this only checks that tracing and rewriting do not raise.
    """
    jaxpr = jax.make_jaxpr(func)(*args)
    rewritten = hcb._rewrite_closed_jaxpr(jaxpr,  # noqa: F841
                                          has_input_token, has_output_token)
    # Updating the jaxpr assertions whenever the jaxpr printing changes is
    # tedious, so they are not checked by default. Before changing the code
    # generation or the jaxpr rewriting, it is recommended to turn the check
    # on, update the expected jaxprs, and only then make the changes.
    # assertMultiLineStrippedEqual(self, expected, str(rewritten))
    del rewritten

  def test_no_outfeed(self):
    # A function without outside calls, rewritten with the token flags
    # (input, output) = (False, False), (True, False) and (True, True).
    self.assertRewrite("""
{ lambda ; a.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False,
                       has_output_token=False)
    self.assertRewrite("""
{ lambda ; a d e.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_output_token=False)
    self.assertRewrite("""
{ lambda ; a d e.
let b = mul a a
c = add a b
in (c, d, e) }""", lambda x: x + x * x, [0])

  def test_simple_outfeed(self):
    # A single id_print, with both input and output tokens.
    self.assertRewrite("""
{ lambda ; a d e.
let b = add a a
c f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] b d e
in (c, f, g) }""", lambda x: hcb.id_print(x + x), [0])

  def test_simple_outfeed_without_input_token(self):
    # Without an input token the expected jaxpr creates fresh tokens
    # with create_token.
    self.assertRewrite("""
{ lambda ; a b.
let e = create_token a b
f = create_token a b
c = add a b
d g h = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] c e f
in (d,) }""", lambda x1, x2: hcb.id_print(x1 + x2), [1, 2],
                       has_input_token=False, has_output_token=False)

  def test_simple_outfeed_without_input_token_nor_invars(self):
    # Same as above, for a function with no inputs at all.
    self.assertRewrite("""
{ lambda ; .
let b = create_token
c = create_token
a d e = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] 42 b c
in (a,) }""", lambda: hcb.id_print(42), [],
                       has_input_token=False, has_output_token=False)

  def test_multiple_tap_without_dependencies(self):
    # Two taps whose results are unused. In the expected jaxpr the second tap
    # consumes the tokens produced by the first one, which orders them.
    def f(x):
      hcb.id_print(x, what="x")
      hcb.id_print(x + 1, what="x + 1")
      return 2

    self.assertRewrite("""
{ lambda ; a c d.
let _ e f = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a c d
b = add a 1
_ g h = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] b e f
in (2, g, h) }""", f, [1])

  def test_cond(self):
    y = jnp.ones(5)  # captured const

    def func(x, z):
      # Uses the 5-argument form of lax.cond: (pred, true_operand, true_fun,
      # false_operand, false_fun). Only the false branch has an outside call.
      return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)),
                      z, lambda a: (hcb.id_print(a), y))

    self.assertRewrite("""
{ lambda a ; b c h i.
let d = gt c 0
e = convert_element_type[ new_dtype=int32 ] d
f g j k =
cond[ branches=( { lambda ; a b c d f g.
let e h i = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] d f g
in (e, a, h, i) }
{ lambda ; f_ a b c g h.
let d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(5,) ] 0.00
in (a, d, g, h) } )
linear=(False, False, False, False, False, False) ] e a 1 2 c h i
in (f, g, j, k) }""", func, [y, 5])

  def test_while(self):
    # An outside call in the body of a while loop (not in the predicate).
    ct_body = jnp.ones(5, np.float32)  # captured const for the body
    ct_cond = jnp.ones(5, np.float32)  # captured const for the conditional

    def func(x):
      # x: f32[5]
      # c: (f32[5], f32)
      return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
                            lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
                            (x, np.float32(1.)))

    self.assertRewrite("""
{ lambda a b ; c f g.
let d e h i =
while[ body_jaxpr={ lambda ; a b c f g.
let d h i = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] c f g
e = add d 1.00
in (a, e, h, i) }
body_nconsts=1
cond_jaxpr={ lambda ; a b c g h.
let d = add b a
e = reduce_sum[ axes=(0,) ] d
f = lt c e
in (f,) }
cond_nconsts=1 ] a b c 1.00 f g
in (d, e, h, i) }""", func, [ct_body])

  def test_while_pred_outfeed(self):
    """A while loop with an outside call in the predicate.

    In the expected jaxpr the predicate is evaluated once before the loop
    (`cond_before`) and again at the end of each body (`cond_body`); the
    loop's own cond_jaxpr then just returns the carried boolean.
    """
    ct_body = jnp.ones(5)  # captured const for the body
    ct_cond = jnp.ones(2)  # captured const for the conditional

    def func(x):
      return lax.while_loop(lambda c: hcb.id_print(ct_cond, result=c[1]) < 5,
                            lambda c: (ct_body, hcb.id_print(c[1]) + 1),
                            (x, 1))

    self.assertRewrite("""
{ lambda a b ; c f g.
let j k l = xla_call[ call_jaxpr={ lambda ; a b c g h.
let d i j = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a g h
e = id_tap_dep c d
f = lt e 5
in (f, i, j) }
donated_invars=(False, False, False, False, False)
name=cond_before ] a c 1 f g
bf d e h i =
while[ body_jaxpr={ lambda ; r s t u v w x.
let y z ba bb =
xla_call[ call_jaxpr={ lambda ; a b c f g.
let d h i = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] c f g
e = add d 1
in (a, e, h, i) }
donated_invars=(False, False, False, False, False)
name=body ] s u v w x
bc bd be =
xla_call[ call_jaxpr={ lambda ; a b c g h.
let d i j = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a g h
e = id_tap_dep c d
f = lt e 5
in (f, i, j) }
donated_invars=(False, False, False, False, False)
name=cond_body ] r y z ba bb
in (bc, y, z, bd, be) }
body_nconsts=2
cond_jaxpr={ lambda ; m n o p q.
let
in (m,) }
cond_nconsts=0 ] a b j c 1 k l
in (d, e, h, i) }""", func, [ct_body])

  def test_scan(self):
    # An outside call on a pytree (tuple) carry inside lax.scan.
    y = jnp.ones(5)  # captured const

    def func(x):
      return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)

    self.assertRewrite("""
{ lambda a ; b f g.
let c d h i e =
scan[ jaxpr={ lambda ; a b c g h d.
let e f i j =
outside_call[ arg_treedef=PyTreeDef(tuple, [*,*])
callback=...
has_token=True
identity=True ] b c g h
in (e, f, i, j, a) }
length=5
linear=(False, False, False, False, False, False)
num_carry=4
num_consts=1
reverse=False
unroll=1 ] a 1 2 f g b
in (c, d, e, h, i) }""", func, [y])

  def test_scan_custom_jvp(self):
    """A custom JVP function inside lax.scan.

    This exercises the custom_jvp_call_jaxpr primitives, both for the
    primal function and for its gradient.
    """

    @jax.custom_jvp
    def f(x):
      return x * hcb.id_print(x)

    @f.defjvp
    def f_jvp(primals, tangents):
      x, = primals
      x_dot, = tangents
      primal_out = f(x)
      # The tangent computation also contains an outside call.
      tangent_out = 3. * x * hcb.id_print(x_dot)
      return primal_out, tangent_out

    def g(x):
      # Sum f(x_i)
      return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                      np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                      x)[0]

    arg = np.full((5,), 0.7)
    self.assertRewrite("""
{ lambda ; a c d.
let b e f _ =
scan[ jaxpr={ lambda ; a e f b.
let c g h = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e.
let b f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a d e
c = mul a b
in (c, f, g) }
num_consts=0 ] b e f
d = add a c
in (d, g, h, 0.00) }
length=5
linear=(False, False, False, False)
num_carry=3
num_consts=0
reverse=False
unroll=1 ] 0.00 c d a
in (b, e, f) }""", g, [arg])
    # The gradient: a forward scan followed by a reversed (transposed) scan.
    self.assertRewrite("""
{ lambda ; a d e.
let _ _ f g _ b =
scan[ jaxpr={ lambda ; a b h i c d.
let e j k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e.
let b f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a d e
c = mul a b
in (c, f, g) }
num_consts=0 ] c h i
f = add a e
g = mul c 3.00
in (f, *, j, k, 0.00, g) }
length=5
linear=(False, True, False, False, False, True)
num_carry=4
num_consts=0
reverse=False
unroll=1 ] 0.00 * d e a *
_ _ h i _ c =
scan[ jaxpr={ lambda ; a b g h c d.
let e = mul b d
f i j = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True
transforms=(('transpose',),) ] e g h
in (*, b, i, j, *, f) }
length=5
linear=(True, True, False, False, True, False)
num_carry=4
num_consts=0
reverse=True
unroll=1 ] * 1.00 f g * b
in (c, h, i) }""", jax.grad(g), [arg])

  def test_scan_custom_vjp(self):
    """A custom VJP function inside lax.scan.

    This exercises the custom_vjp_call_jaxpr primitives, both for the
    primal function and for its gradient.
    """

    @jax.custom_vjp
    def f(x):
      return x * hcb.id_print(x)

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), 3. * x

    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * hcb.id_print(ct_b),

    f.defvjp(f_fwd, f_bwd)

    def g(x):
      # Sum f(x_i)
      return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                      np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                      x)[0]

    arg = np.full((2,), 0.7)
    self.assertRewrite("""
{ lambda ; a c d.
let b e f _ =
scan[ jaxpr={ lambda ; a e f b.
let c g h = custom_vjp_call_jaxpr[
fun_jaxpr={ lambda ; a d e.
let b f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a d e
c = mul a b
in (c, f, g) }
num_consts=0
] b e f
d = add a c
in (d, g, h, 0.00) }
length=2
linear=(False, False, False, False)
num_carry=3
num_consts=0
reverse=False
unroll=1 ] 0.00 c d a
in (b, e, f) }""", g, [arg])
    # The gradient: a forward scan followed by a reversed scan that runs f_bwd.
    self.assertRewrite("""
{ lambda ; a d e.
let _ _ f g _ b =
scan[ jaxpr={ lambda ; a b h i c d.
let e j k = custom_vjp_call_jaxpr[
fun_jaxpr={ lambda ; a d e.
let b f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a d e
c = mul a b
in (c, f, g) }
num_consts=0
] c h i
f = add a e
g = mul c 3.00
in (f, *, j, k, 0.00, g) }
length=2
linear=(False, True, False, False, False, True)
num_carry=4
num_consts=0
reverse=False
unroll=1 ] 0.00 * d e a *
_ _ h i _ c =
scan[ jaxpr={ lambda ; a b g h c d.
let e i j = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] b g h
f = mul d e
in (*, b, i, j, *, f) }
length=2
linear=(True, True, False, False, True, False)
num_carry=4
num_consts=0
reverse=True
unroll=1 ] * 1.00 f g * b
in (c, h, i) }""", jax.grad(g), [arg])

  def test_remat_loop(self):
    # An outside call inside a jax.remat function used as a fori_loop body.
    def f(k, x):
      x = hcb.id_print(k + x)
      return -k * x

    def loss(k):
      return lax.fori_loop(0, 1, jax.remat(f), k)

    self.assertRewrite("""
{ lambda ; a c d.
let _ _ b e f =
while[ body_jaxpr={ lambda ; a b c f g.
let d = add a 1
e h i = remat_call[ call_jaxpr={ lambda ; a b g h.
let c = add a b
d i j = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] c g h
e = neg a
f = mul e d
in (f, i, j) }
concrete=False
name=f ] a c f g
in (d, b, e, h, i) }
body_nconsts=0
cond_jaxpr={ lambda ; a b c e f.
let d = lt a b
in (d,) }
cond_nconsts=0 ] 0 1 a c d
in (b, e, f) }""", loss, [2])

  def test_named_call(self):
    # A jax.named_call step function inside lax.scan. `maybe_print` is called
    # with do_print=False here; the expected jaxpr has no outside_call.
    def tap_scalar(init, do_print=False):
      @partial(jax.named_call, name="step")
      def step(acc, step_nr):
        acc = acc + step_nr
        maybe_print(do_print, step_nr, what="step_nr")
        return acc, None

      return lax.scan(step, init, np.arange(2, dtype=np.int32))

    self.assertRewrite("""
{ lambda a ; b d e.
let c = scan[ jaxpr={ lambda ; a b.
let c = named_call[ call_jaxpr={ lambda ; a b.
let c = add a b
in (c,) }
name=step ] a b
in (c,) }
length=2
linear=(False, False)
num_carry=1
num_consts=0
reverse=False
unroll=1 ] b a
in (c, d, e) }""", tap_scalar, [np.int32(3)])

  def test_pmap(self):
    # An id_print with tap_with_device=True inside jax.pmap. `f` does not
    # return the pmap result.
    def f(xv):
      jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
               axis_name="i")(xv)

    self.assertRewrite("""
{ lambda ; a b c.
let _ d e = xla_pmap[ axis_name=i
axis_size=1
backend=None
call_jaxpr={ lambda ; a d e.
let b f g = outside_call[ arg_treedef=*
callback=...
has_token=True
identity=True ] a d e
c = sin b
in (c, f, g) }
devices=None
donated_invars=(False, False, False)
global_arg_shapes=(None,)
global_axis_size=None
in_axes=(0, 0, 0)
name=<lambda>
out_axes=(0, 0, 0) ] a b c
in (d, e) }""", f, [np.array([2.], dtype=np.float32)])
if __name__ == "__main__":
  # Run every test in this module with JAX's test loader.
  _test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=_test_loader)