# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from functools import partial
import itertools
import logging
import os
import re
import threading
import time
from typing import Callable, Optional
import unittest
from unittest import skip, SkipTest

from absl.testing import absltest

import jax
from jax import ad_checkpoint
from jax import config
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental import host_callback as hcb
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import xla_bridge
from jax._src import test_util as jtu
from jax._src.lib import xla_client

xops = xla_client.ops

import numpy as np

config.parse_flags_with_absl()


class _TestingOutputStream:
  """Use as `output_stream` for tests."""

  def __init__(self):
    self._output = []
    self._test_method_name = None

  def write(self, what: str) -> None:
    logging.info(f"output_stream[{self._test_method_name}]: {what}")
    self._output.append(what)

  @property
  def output(self):
    return "".join(self._output)

  @property
  def output_sorted_by_device(self):
    # Assume that the output is a sequence of strings including metadata
    # and data, with metadata containing `device: xxx`.
    by_device = []  # each element is a pair (device, str_list)
    for s in self._output:
      m = re.match(r".*device: (\S+)", s)
      if m:
        by_device.append((m.group(1), []))
      assert by_device, f"output does not include 'device:': {self._output}"
      by_device[-1][1].append(s)

    sorted_by_device = sorted(by_device, key=lambda x: x[0])
    return "\n".join(itertools.chain(*[s[1] for s in sorted_by_device]))

  def __str__(self):
    return "TestingOutputStream"

  def reset(self):
    self._output = []


testing_stream = _TestingOutputStream()


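# Note on the pattern used throughout these tests: taps write to
# `testing_stream` via `output_stream=testing_stream`, the test calls
# `hcb.barrier_wait()` to flush any pending host callbacks, and it then
# compares `testing_stream.output` against an expected string.

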
def fun1(a):
  """Function used for several `id_tap` tests."""
  y = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
  y = hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream, result=y)
  return y ** 2  # Some computation to make the gradient interesting


def fun1_equiv(a):  # Numerical equivalent of fun1
  return (a * 2.) ** 2


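# A minimal illustrative sketch (added for exposition; it is not part of the
# original test suite and is not called by any test). The helper name
# `_fun1_tap_example` is ours. It shows fun1's two taps reaching the stream,
# mirroring the assertions in `test_tap_eval` below.
def _fun1_tap_example():
  y = fun1(5.)        # taps "what: a * 2" -> 10.00 and "what: y * 3" -> 30.00
  hcb.barrier_wait()  # wait until the host has processed the taps
  return y, testing_stream.output  # y == (5. * 2.) ** 2 == 100.

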
def maybe_print(do_print: bool,
                arg,
                what: str,
                tap_with_device: Optional[bool] = False,
                device_index: int = 0):
  """Conditionally print on testing_stream."""
  if do_print:
    return hcb.id_print(
        arg,
        what=what,
        output_stream=testing_stream,
        tap_with_device=tap_with_device,
        device_index=device_index)
  else:
    return arg


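# When `do_print` is False the helper above is just the identity, which lets
# the pmap/vmap tests below run the same function twice: once with
# do_print=True to capture the taps and once with do_print=False to compute
# the expected (tap-free) reference values.

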
def local_devices():
  # The tests must use at most 2 devices.
  return jax.local_devices()[:2]


ignore_jit_of_pmap_warning = partial(
    jtu.ignore_warning, message=".*jit-of-pmap.*")


def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase,
                                 expected: str, what: str):
  """A variant that preprocesses the string to eliminate non-determinism in
  floating-point values and several uninteresting id_tap primitive params.
  """

  # Sometimes we get floating points in the output; we round them
  def repl_floats(match_group):
    matched = match_group.group(0)
    if matched == ".": return matched
    x = np.around(float(matched), decimals=2)
    return f"{x:.2f}"

  what = re.sub(r"\-?\d+\.[\-\def]*", repl_floats, what)
  what = re.sub(r"output_stream=[^\]\n,]*,?", "", what)
  what = re.sub(r"threshold=[^\]\n,]*,?", "", what)
  what = re.sub(r"bwd=[^\]\n]*", "", what)
  what = re.sub(r"out_trees=[^\]\n]*", "", what)
  what = re.sub(r"fwd_jaxpr_thunk=[^\]\n]*", "", what)
  what = re.sub(r"jvp_jaxpr_thunk=[^\]\n]*", "", what)
  # Empty lines
  what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)

  def repl_func(match_group):
    matched = match_group.group(3)
    if "function _print_consumer" in matched:
      return match_group.group(1) + "=_print"
    else:
      return match_group.group(1) + "=..."

  what = re.sub(r"((tap_func_)|(callback))=([^\]\n,]*),?", repl_func, what)
  tst.assertMultiLineStrippedEqual(expected, what)


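# For example, the normalization above rounds a tapped value such as
# "10.123456" to "10.12" and collapses parameters like
# "callback=<function ...>" to a short placeholder (e.g. "callback=..."),
# so that the expected strings in the tests stay deterministic.

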
def helper_set_hlo_dump():
  flags_str = os.getenv("XLA_FLAGS", "")
  import shutil
  dump_dir = "/tmp/xla_dump"
  os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to={dump_dir}"
  if os.path.isdir(dump_dir):
    logging.warning("Deleting old XLA dump directory %s", dump_dir)
    shutil.rmtree(dump_dir)
  logging.warning("Setting XLA dump directory %s", dump_dir)
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()


def helper_print_optimized_hlo(fun, *args):
  backend = xla_bridge.get_backend(platform=jtu.device_under_test())
  c = jax.jit(fun, backend=backend.platform).lower(*args)
  logging.info(re.sub(r", metadata.*", "", c.compile().as_text()))


def helper_log_ir(name,
                  f_jax,
                  *args,
                  num_partitions=None,
                  strip_metadata=False):
  logging.info(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
  jax_comp = f_jax.lower(*args)
  logging.info(f"HLO[{name}]: {jax_comp.compiler_ir(dialect='hlo').as_hlo_text()}")
  jax_optimized_hlo = jax_comp.compile().as_text()
  if strip_metadata:
    jax_optimized_hlo = re.sub(r", metadata.*", "", jax_optimized_hlo)
  logging.info(f"Optimized HLO[{name}]: {jax_optimized_hlo}")


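# The three helpers above are debugging aids rather than assertions:
# helper_set_hlo_dump turns on XLA HLO dumping, helper_print_optimized_hlo
# logs the compiled HLO for a function, and helper_log_ir logs the jaxpr and
# the (optionally metadata-stripped) optimized HLO; test_tap_pjit below uses
# helper_log_ir to inspect the lowered computation.

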
prev_xla_flags = None


def setUpModule():
  global prev_xla_flags
  # This will control the CPU devices. On TPU we always have 2 devices.
  prev_xla_flags = jtu.set_host_platform_device_count(2)


# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  prev_xla_flags()


def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
                                 expected_2CPUs: str):
  """Check that the multi-device output is equal to the expected.

  The tests run with 2 devices if available, otherwise 1 device.
  We adjust the expected output here for 1 device.

  Args:
    expected_2CPUs: the expected output for 2 CPUs. If there is only
      one device, this is trimmed to the first device. If the current
      device_under_test is not a CPU, then we rewrite the device names
      to match the actual devices.
  """
  expected = expected_2CPUs
  if len(local_devices()) == 1:
    start_device_1 = expected.find('device: cpu:1')
    if start_device_1 >= 0:
      expected = expected[0:start_device_1]

  def replace_device_name(m) -> str:
    return str(local_devices()[int(m.group(1))])

  expected = re.sub(r'cpu:(\d+)', replace_device_name, expected)
  what = testing_stream.output_sorted_by_device
  return assertMultiLineStrippedEqual(tst, expected, what)


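# For example, when running on a non-CPU backend the expected text
# "device: cpu:0" is rewritten to the string form of the actual first local
# device (the exact name depends on the platform), and with a single device
# everything starting at "device: cpu:1" is dropped from the expectation.

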
class HostCallbackTapTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
      raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
    if xla_bridge.using_pjrt_c_api():
      raise SkipTest("host_callback not implemented in PJRT C API")

    testing_stream.reset()
    testing_stream._test_method_name = self._testMethodName
    self.old_flags = os.getenv("XLA_FLAGS", "")

  def tearDown(self) -> None:
    if os.getenv("XLA_FLAGS") != self.old_flags:
      os.environ["XLA_FLAGS"] = self.old_flags
      xla_bridge.get_backend.cache_clear()
    hcb.barrier_wait("HostCallbackTapTest.tearDown")
    super().tearDown()

  def test_tap_eval(self):
    self.assertAllClose((5. * 2.) ** 2, fun1(5.))
    hcb.barrier_wait()
    assertMultiLineStrippedEqual(self, """
        what: a * 2
        10.00
        what: y * 3
        30.00""", testing_stream.output)

def test_tap_with_tuple_results(self):
|
|
def func2(x):
|
|
x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
|
|
return x1 + y1
|
|
|
|
self.assertEqual(3. * (2. + 3.), func2(3.))
|
|
hcb.barrier_wait()
|
|
|
|
assertMultiLineStrippedEqual(self, """
|
|
( 6.00 9.00 )""", testing_stream.output)
|
|
|
|
def test_tap_with_dict_results(self):
|
|
def func2(x):
|
|
res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
|
|
return res["a"] + res["b"]
|
|
|
|
self.assertEqual(3. * (2. + 3.), func2(3.))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
{ a=6.00 b=9.00 }""", testing_stream.output)
|
|
|
|
def test_tap_with_result(self):
|
|
def func2(x):
|
|
x1 = hcb.id_print((x * 2., x * 3.), result=x * 4.,
|
|
output_stream=testing_stream)
|
|
return x1
|
|
|
|
self.assertEqual(3. * 4., func2(3.))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
( 6.00 9.00 )""", testing_stream.output)
|
|
|
|
def test_tap_with_result_no_arg(self):
|
|
def tap_func(arg, transforms):
|
|
testing_stream.write(f"called tap_func with {arg}")
|
|
|
|
def func2(x):
|
|
x1 = hcb.id_tap(tap_func, None, result=x)
|
|
return x1
|
|
|
|
self.assertEqual(3., func2(3.))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, "called tap_func with None",
|
|
testing_stream.output)
|
|
|
|
def test_tap_result_unused(self):
|
|
def tap_func(arg, transforms):
|
|
testing_stream.write(f"called tap_func with {arg}")
|
|
def func2(x):
|
|
hcb.id_tap(tap_func, None)
|
|
return x
|
|
|
|
self.assertEqual(3., func2(3.))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, "called tap_func with None",
|
|
testing_stream.output)
|
|
|
|
def test_tap_with_device(self):
|
|
def func2(x):
|
|
x1 = hcb.id_print((x * 2., x * 3.), result=x * 4.,
|
|
output_stream=testing_stream,
|
|
tap_with_device=True)
|
|
return x1
|
|
|
|
self.assertEqual(3. * 4., func2(3.))
|
|
hcb.barrier_wait()
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0
|
|
( 6.00 9.00 )""")
|
|
|
|
def test_tap_eval_exception(self):
|
|
if not hcb._HOST_CALLBACK_OUTFEED.value:
|
|
raise SkipTest("TODO: implement error handling for customcall")
|
|
# Simulate a tap error
|
|
def tap_err(*args, **kwargs):
|
|
raise ValueError("Some user message")
|
|
|
|
def func(x):
|
|
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
|
|
x2 = hcb.id_tap(tap_err, x1 + 1)
|
|
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
|
return x3
|
|
|
|
with self.assertRaisesRegex(
|
|
hcb.CallbackException,
|
|
re.compile("There were exceptions during callback processing. Last one was:.*"
|
|
"ValueError: Some user message", re.DOTALL)):
|
|
func(0)
|
|
hcb.barrier_wait()
|
|
|
|
# We should have received everything before the error
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: x1
|
|
1
|
|
what: x3
|
|
3""", testing_stream.output)
|
|
|
|
def test_tap_empty(self):
|
|
"""Tap empty arrays."""
|
|
hcb.id_print((), output_stream=testing_stream)
|
|
hcb.id_print((1., np.ones((2, 0))), what="second", output_stream=testing_stream)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
( )
|
|
what: second
|
|
( 1.00 [] )""", testing_stream.output)
|
|
|
|
def test_tap_jit_simple(self):
|
|
jit_fun1 = jax.jit(lambda x: 3. * hcb.id_print(
|
|
2. * x, what="here", output_stream=testing_stream))
|
|
self.assertAllClose(6. * 5., jit_fun1(5.))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: here
|
|
10.00""", testing_stream.output)
|
|
|
|
def test_tap_jit_no_invars(self):
|
|
def func(): # jitted function does not take arguments
|
|
return hcb.id_print(42, output_stream=testing_stream)
|
|
|
|
self.assertAllClose(42, jax.jit(func)())
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
42""", testing_stream.output)
|
|
|
|
def test_tap_jit_multiple_invars(self):
|
|
def func(x1, x2):
|
|
return hcb.id_print(x1 + x2, output_stream=testing_stream)
|
|
|
|
self.assertAllClose(42, jax.jit(func)(40, 2))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
42""", testing_stream.output)
|
|
|
|
def test_tap_jit_constant(self):
|
|
def func(x):
|
|
return hcb.id_print(42, result=x, output_stream=testing_stream)
|
|
|
|
self.assertAllClose(5, jax.jit(func)(5))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
42""", testing_stream.output)
|
|
|
|
def test_tap_jit_sequence1(self):
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
|
|
|
logging.info("%s: %s", self._testMethodName,
|
|
jax.make_jaxpr(func)(1))
|
|
logging.info("%s: %s", self._testMethodName,
|
|
jax.xla_computation(func, backend=jtu.device_under_test())(1).as_hlo_text())
|
|
self.assertEqual(2, jax.jit(func)(1))
|
|
hcb.barrier_wait()
|
|
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2""", testing_stream.output)
|
|
|
|
def test_tap_jit2(self):
|
|
"""A sequence of JIT."""
|
|
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
|
return x2
|
|
|
|
self.assertEqual(2, jax.jit(func)(1))
|
|
self.assertEqual(11, jax.jit(func)(10))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2
|
|
where: 1
|
|
10
|
|
where: 2
|
|
11""", testing_stream.output)
|
|
|
|
def test_tap_jit_result_unused(self):
|
|
"""We can id_print even if we don't use the result."""
|
|
|
|
def func(x):
|
|
hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
hcb.id_print(x + 1, where="2", output_stream=testing_stream)
|
|
return x + 1
|
|
|
|
self.assertEqual(2, jax.jit(func)(1))
|
|
self.assertEqual(11, jax.jit(func)(10))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2
|
|
where: 1
|
|
10
|
|
where: 2
|
|
11""", testing_stream.output)
|
|
|
|
def test_tap_jit_nested(self):
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
|
|
def func_nested(x):
|
|
x2 = hcb.id_print(x + 1, where="nested", output_stream=testing_stream)
|
|
return x2
|
|
|
|
x3 = jax.jit(func_nested)(x1)
|
|
return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)
|
|
|
|
self.assertEqual(3, jax.jit(func)(1))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: nested
|
|
2
|
|
where: 3
|
|
3""", testing_stream.output)
|
|
|
|
def test_tap_jit_devices(self):
|
|
"""Running on multiple devices."""
|
|
logging.info("%s: has devices %s", self._testMethodName, local_devices())
|
|
|
|
def func(x, device_id):
|
|
x1 = hcb.id_print(x, dev=str(device_id), output_stream=testing_stream)
|
|
x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream)
|
|
return x2
|
|
|
|
for d in local_devices():
|
|
self.assertEqual(112, jax.jit(func, device=d, static_argnums=1)(111, d.id))
|
|
hcb.barrier_wait()
|
|
logging.info("%s: found output %s", self._testMethodName,
|
|
testing_stream.output)
|
|
self.assertEqual(
|
|
len(local_devices()), len(re.findall(r"111", testing_stream.output)))
|
|
self.assertEqual(
|
|
len(local_devices()), len(re.findall(r"112", testing_stream.output)))
|
|
|
|
@jtu.sample_product(with_jit=[True, False])
|
|
def test_tap_pytree(self, with_jit=False):
|
|
def func(x, what=""):
|
|
"""Returns some pytrees depending on x"""
|
|
if what == "pair_1_x":
|
|
return (1, x)
|
|
elif what == "pair_x_2x":
|
|
return (x, 2 * x)
|
|
elif what == "dict":
|
|
return dict(a=2 * x, b=3 * x)
|
|
else:
|
|
assert False
|
|
|
|
tap_count = 0
|
|
|
|
def tap_func(a, _, *, what=""):
|
|
nonlocal tap_count
|
|
tap_count += 1
|
|
self.assertEqual(func(5, what), a)
|
|
|
|
transform = jax.jit if with_jit else lambda f: f
|
|
for what in ("pair_1_x", "pair_x_2x", "dict"):
|
|
transformed = transform(
|
|
lambda x: hcb.id_tap(
|
|
partial(tap_func, what=what),
|
|
func(x, what),
|
|
result=func(x * 2, what))
|
|
)(5)
|
|
self.assertEqual(func(10, what), transformed)
|
|
hcb.barrier_wait() # Wait for receivers to be done
|
|
self.assertEqual(3, tap_count)
|
|
|
|
@jtu.sample_product(concurrent=[True, False])
|
|
def test_tap_multiple(self, concurrent=False):
|
|
"""Call id_tap multiple times, concurrently or in sequence. """
|
|
if concurrent and jtu.test_device_matches(["cpu", "gpu"]):
|
|
# TODO(necula): if there is device side concurrency, outfeeds from
|
|
# different computations can be interleaved. For example, it seems that
|
|
# on GPU if multiple host threads run a jit computation, the multiple
|
|
# computations are interleaved on the GPU. This can result in the outfeed
|
|
# trains being interleaved, which will trigger an error.
|
|
# The solution is to fix on GPU the receiving logic so that we can outfeed
|
|
# the train as one tuple, and receive it one piece as a time. Then the
|
|
# trains should be atomic.
|
|
# See also b/160692602.
|
|
raise SkipTest("concurrent id_tap not supported on CPU, GPU")
|
|
|
|
received = set()
|
|
count = 5
|
|
|
|
def pause_tap(idx, _):
|
|
received.add(int(idx))
|
|
logging.info("Starting do_tap %s. Sleeping 1sec ...", idx)
|
|
time.sleep(0.3)
|
|
logging.info("Finish do_tap %s", idx)
|
|
|
|
def do_tap(idx):
|
|
jax.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)
|
|
|
|
if concurrent:
|
|
threads = [
|
|
threading.Thread(
|
|
name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,))
|
|
for idx in range(count)
|
|
]
|
|
[t.start() for t in threads]
|
|
[t.join() for t in threads]
|
|
else:
|
|
for idx in range(count):
|
|
do_tap(idx)
|
|
|
|
hcb.barrier_wait()
|
|
self.assertEqual(received, set(range(count)))
|
|
|
|
# TODO(necula): see comment for test_multiple_tap. Here we disable also
|
|
# on TPU, because the barrier_wait runs on all devices, including on the CPU
|
|
# where it would run into concurrency problems.
|
|
@skip("Concurrency not supported")
|
|
def test_tap_multiple_barriers(self):
|
|
"""Call barrier_wait concurrently."""
|
|
|
|
def pause_tap(*args, **kwargs):
|
|
logging.info("pause_tap waiting")
|
|
time.sleep(0.3)
|
|
logging.info("pause_tap done")
|
|
|
|
def long_run(x):
|
|
return hcb.id_tap(pause_tap, x)
|
|
|
|
jax.jit(long_run)(5.)
|
|
|
|
def try_barrier(idx):
|
|
logging.info("Starting test barrier %s", idx)
|
|
hcb.barrier_wait()
|
|
logging.info("Finished test barrier %s", idx)
|
|
|
|
threads = [
|
|
threading.Thread(
|
|
name=f"barrier_{idx}", target=try_barrier, args=(idx,))
|
|
for idx in range(3)
|
|
]
|
|
[t.start() for t in threads]
|
|
[t.join() for t in threads]
|
|
|
|
@jtu.sample_product(with_jit=[True, False])
|
|
def test_tap_cond(self, with_jit=False):
|
|
"""A conditional"""
|
|
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
|
|
|
x4 = lax.cond(x % 2 == 0,
|
|
lambda x: hcb.id_print(x, where="cond_t",
|
|
output_stream=testing_stream),
|
|
lambda x: hcb.id_print(-1, where="cond_f", result=x,
|
|
output_stream=testing_stream),
|
|
x2 + 1)
|
|
x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream)
|
|
return x5
|
|
|
|
transform = jax.jit if with_jit else lambda f: f
|
|
self.assertEqual(4, transform(func)(1))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2
|
|
where: cond_f
|
|
-1
|
|
where: end
|
|
4""", testing_stream.output)
|
|
|
|
@jtu.sample_product(with_jit=[True, False])
|
|
def test_tap_while_cond(self, with_jit=False):
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
|
|
|
def body(x):
|
|
x3 = hcb.id_print(x, where="w_b_1", output_stream=testing_stream)
|
|
x4 = lax.cond(x % 2 == 0,
|
|
lambda x: hcb.id_print(x, where="w_b_t",
|
|
output_stream=testing_stream),
|
|
lambda x: hcb.id_print(-1, where="w_b_f",
|
|
result=x, output_stream=testing_stream),
|
|
x3 + 1)
|
|
return hcb.id_print(x4, where="w_b_2", output_stream=testing_stream)
|
|
|
|
x10 = lax.while_loop(lambda x: x <= 3, body, x2)
|
|
res = hcb.id_print(x10, where="end", output_stream=testing_stream)
|
|
return res
|
|
|
|
transform = jax.jit if with_jit else lambda f: f
|
|
self.assertEqual(4, transform(func)(1))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2
|
|
where: w_b_1
|
|
2
|
|
where: w_b_t
|
|
3
|
|
where: w_b_2
|
|
3
|
|
where: w_b_1
|
|
3
|
|
where: w_b_f
|
|
-1
|
|
where: w_b_2
|
|
4
|
|
where: end
|
|
4""", testing_stream.output)
|
|
|
|
def test_tap_jit_while_pred_tap(self):
|
|
"""While with printing in the conditional."""
|
|
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1")
|
|
x10 = lax.while_loop(lambda x: hcb.id_print(x < 3,
|
|
where="w_p",
|
|
output_stream=testing_stream),
|
|
lambda x: hcb.id_print(x + 1, where="w_b",
|
|
output_stream=testing_stream),
|
|
x1)
|
|
res = hcb.id_print(x10, where="3", output_stream=testing_stream)
|
|
return res
|
|
|
|
self.assertEqual(3, jax.jit(func)(1))
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self,
|
|
"""
|
|
where: w_p
|
|
True
|
|
where: w_b
|
|
2
|
|
where: w_p
|
|
True
|
|
where: w_b
|
|
3
|
|
where: w_p
|
|
False
|
|
where: 3
|
|
3""", testing_stream.output)
|
|
|
|
@jtu.sample_product(with_jit=[True, False])
|
|
def test_tap_scan_cond(self, with_jit=True):
|
|
def func(x):
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
|
|
|
|
def body(c, x):
|
|
x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream)
|
|
x4 = lax.cond(x % 2 == 0,
|
|
lambda x: hcb.id_print(x, where="s_t", output_stream=testing_stream),
|
|
lambda x: hcb.id_print(-1, where="s_f", result=x, output_stream=testing_stream),
|
|
x3 + 1)
|
|
return (c, hcb.id_print(x4, where="s_2", output_stream=testing_stream))
|
|
|
|
_, x10 = lax.scan(body, x2, jnp.arange(3))
|
|
res = hcb.id_print(x10, where="10", output_stream=testing_stream)
|
|
return res
|
|
|
|
if with_jit:
|
|
func = jax.jit(func)
|
|
res = func(1)
|
|
self.assertAllClose(jnp.arange(1, 4), res)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
where: 1
|
|
1
|
|
where: 2
|
|
2
|
|
where: s_1
|
|
0
|
|
where: s_t
|
|
1
|
|
where: s_2
|
|
1
|
|
where: s_1
|
|
1
|
|
where: s_f
|
|
-1
|
|
where: s_2
|
|
2
|
|
where: s_1
|
|
2
|
|
where: s_t
|
|
3
|
|
where: s_2
|
|
3
|
|
where: 10
|
|
[1 2 3]""", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
@jtu.sample_product(
|
|
nr_args=[1, 2],
|
|
shape=[(), (2,), (2, 3), (2, 3, 4)],
|
|
dtype=jtu.dtypes.all,
|
|
)
|
|
def test_tap_jit_dtypes(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
|
|
if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
|
|
raise SkipTest(f"host_callback not implemented for {dtype}.")
|
|
if dtype == np.bool_:
|
|
args = [self.rng().choice(a=[True, False], size=shape)]
|
|
else:
|
|
args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
|
|
if nr_args > 1:
|
|
args = args * nr_args
|
|
jit_fun1 = jax.jit(lambda xs: hcb.id_print(
|
|
xs,
|
|
a_new_test="************",
|
|
testcase_name=f"{shape=}_{dtype=}_{nr_args=}"))
|
|
|
|
res = jit_fun1(args)
|
|
self.assertAllClose(args, res, check_dtypes=True)
|
|
|
|
def test_tap_jit_large(self):
|
|
arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
|
|
jax.jit(hcb.id_print)(arg)
|
|
|
|
def test_tap_jit_several_together(self):
|
|
arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
|
|
jax.jit(lambda x, y: hcb.id_print((x, y, x * 2)))(arg, jnp.ones(100, dtype=jnp.int32))
|
|
|
|
def test_tap_jit_interleaving(self):
|
|
# Several jit's without data dependencies; they may interfere
|
|
count = 0 # Count tap invocations
|
|
nr_arrays = 5
|
|
|
|
def tap_func(arg, _):
|
|
nonlocal count
|
|
assert len(arg) == nr_arrays
|
|
count += 1
|
|
|
|
# This is the function that we'll run multiple times
|
|
def func(x, count):
|
|
for i in range(count):
|
|
x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)])[-1]
|
|
return x
|
|
|
|
x = jnp.array(1, dtype=np.int32)
|
|
res = 0
|
|
for _ in range(10):
|
|
# No dependencies between the jit invocations
|
|
res += jax.jit(lambda x: func(x, 10))(x)
|
|
hcb.barrier_wait()
|
|
self.assertEqual(100, count)
|
|
|
|
def test_tap_jit_tap_exception(self):
|
|
if not hcb._HOST_CALLBACK_OUTFEED.value:
|
|
raise SkipTest("TODO: implement error handling for customcall")
|
|
# Simulate a tap error
|
|
def tap_err(*args, **kwargs):
|
|
raise NotImplementedError
|
|
|
|
def func(x):
|
|
x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
|
|
x2 = hcb.id_tap(tap_err, x1 + 1)
|
|
x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
|
|
return x3
|
|
|
|
res = jax.jit(func)(0) # No error yet
|
|
with self.assertRaises(hcb.CallbackException):
|
|
hcb.barrier_wait()
|
|
|
|
# Even though the receiver thread raised, the main thread should still
|
|
# return 3.
|
|
self.assertEqual(3, res)
|
|
# We should have received all others
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: x1
|
|
1
|
|
what: x3
|
|
3""", testing_stream.output)
|
|
|
|
def test_tap_while(self):
|
|
"""Executing while, even without JIT uses compiled code"""
|
|
y = jnp.ones(5) # captured const
|
|
|
|
def func(x):
|
|
return lax.while_loop(
|
|
lambda c: c[1] < 5,
|
|
lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
|
|
(x, 1))
|
|
|
|
func(y)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
1
|
|
2
|
|
3
|
|
4""", testing_stream.output)
|
|
|
|
def test_tap_jvp(self):
|
|
jvp_fun1 = lambda x, xt: jax.jvp(fun1, (x,), (xt,))
|
|
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
|
|
self.assertAllClose(100., res_primals, check_dtypes=False)
|
|
self.assertAllClose(4., res_tangents, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: a * 2
|
|
10.00
|
|
what: y * 3
|
|
30.00""", testing_stream.output)
|
|
|
|
def test_tap_grad_primal_unused(self):
|
|
# The output of id_print is not needed for the backward pass
|
|
def func(x):
|
|
return 2. * hcb.id_print(x * 3., what="x * 3",
|
|
output_stream=testing_stream)
|
|
|
|
grad_func = jax.grad(func)
|
|
arg = jnp.float32(5.)
|
|
jaxpr = str(jax.make_jaxpr(grad_func)(arg))
|
|
# making the Jaxpr does not print anything
|
|
hcb.barrier_wait()
|
|
|
|
treedef = tree_util.tree_structure(arg)
|
|
assertMultiLineStrippedEqual(
|
|
self, f"""
|
|
{{ lambda ; a:f32[]. let
|
|
b:f32[] = mul a 3.00
|
|
c:f32[] = outside_call[
|
|
arg_treedef={treedef}
|
|
callback=...
|
|
device_index=0
|
|
identity=True
|
|
] b
|
|
_:f32[] = mul 2.00 c
|
|
d:f32[] = mul 2.00 1.00
|
|
e:f32[] = mul d 3.00
|
|
in (e,) }}""", jaxpr)
|
|
assertMultiLineStrippedEqual(self, "", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
res_grad = grad_func(arg)
|
|
hcb.barrier_wait()
|
|
|
|
self.assertAllClose(6., res_grad, check_dtypes=False)
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: x * 3
|
|
15.00""", testing_stream.output)
|
|
|
|
def test_tap_grad_simple(self):
|
|
def func(x):
|
|
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
|
return x * hcb.id_print(y * 3., what="y * 3",
|
|
output_stream=testing_stream)
|
|
|
|
grad_func = jax.grad(func)
|
|
|
|
res_grad = grad_func(jnp.float32(5.))
|
|
self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: x * 2
|
|
10.00
|
|
what: y * 3
|
|
30.00""", testing_stream.output)
|
|
|
|
def test_tap_grad_grad(self):
|
|
def func(x):
|
|
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
|
return x * (y * 3.)
|
|
|
|
grad_func = jax.grad(jax.grad(func))
|
|
# making the Jaxpr does not print anything
|
|
_ = jax.make_jaxpr(grad_func)(5.)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, "", testing_stream.output)
|
|
|
|
res_grad = grad_func(jnp.float32(5.))
|
|
|
|
self.assertAllClose(12., res_grad, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: x * 2
|
|
10.00""", testing_stream.output)
|
|
|
|
def test_tap_grad_pytree(self):
|
|
def func(x):
|
|
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
|
|
result=(x * 4., x * 5.),
|
|
output_stream=testing_stream)
|
|
return x4 + 2. * x5
|
|
|
|
x = jnp.float32(5.)
|
|
grad_func = jax.grad(func)
|
|
print(jax.make_jaxpr(grad_func)(x))
|
|
res_grad = grad_func(x)
|
|
self.assertAllClose(14., res_grad, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: pair
|
|
( 10.00 15.00 )""", testing_stream.output)
|
|
|
|
def test_tap_jvp_float0(self):
|
|
def f(x, yint):
|
|
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
|
|
return x * yint
|
|
|
|
res = jax.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
|
|
self.assertAllClose((6., 0.6), res)
|
|
|
|
def test_tap_grad_float0(self):
|
|
def func(x, yint):
|
|
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
|
|
return x * yint.astype(x.dtype)
|
|
|
|
grad_func = jax.grad(func)
|
|
|
|
res_grad = grad_func(jnp.float32(5.), jnp.int32(2))
|
|
self.assertAllClose(2., res_grad, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
what: pair
|
|
( 5.00 2 )""", testing_stream.output)
|
|
|
|
def test_tap_grad_float0_result(self):
|
|
# https://github.com/google/jax/issues/7340
|
|
# x is a Tuple[f32[2], s32[3]]
|
|
x = (np.array([.7, .8], dtype=np.float32),
|
|
np.array([11, 12, 13], dtype=np.int32))
|
|
def f_jax(x):
|
|
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
|
|
return (3. * x[0], x[1])
|
|
|
|
def f_jax_vjp(x):
|
|
res, pullback = jax.vjp(f_jax, x)
|
|
g, = pullback((np.ones(x[0].shape, dtype=x[0].dtype),
|
|
np.zeros(x[1].shape, dtype=dtypes.float0)))
|
|
return g
|
|
|
|
g = f_jax_vjp(x)
|
|
self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0])
|
|
self.assertEqual(dtypes.float0, g[1].dtype)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
|
|
|
def test_tap_higher_order_grad_float0_result(self):
|
|
# https://github.com/google/jax/issues/7340
|
|
# x is a Tuple[f32[2], s32[3]]
|
|
x = (np.array([.7, .8], dtype=np.float32),
|
|
np.array([11, 12, 13], dtype=np.int32))
|
|
def f_jax(x):
|
|
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
|
|
return (jnp.sin(x[0]), x[1])
|
|
|
|
def wrap_vjp(f, args, res_f_of_args):
|
|
# Given a function "f" and "args" return the f_vjp and args_vjp
|
|
def make_ct(res):
|
|
res_dtype = np.result_type(res)
|
|
if res_dtype == dtypes.float0:
|
|
return res
|
|
ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype)
|
|
return np.ones(np.shape(res), dtype=ct_dtype)
|
|
cts = tree_util.tree_map(make_ct, res_f_of_args)
|
|
def f_vjp(args, cts):
|
|
res, pullback = jax.vjp(f, *args)
|
|
return pullback(cts)
|
|
return (f_vjp, (args, cts))
|
|
|
|
res = f_jax(x)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
# 1st order
|
|
f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res)
|
|
res_vjp1 = f_jax_vjp1(*args_vjp1)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
( [0.70 0.80] [11 12 13] )""", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
# 2nd order
|
|
f_jax_vjp2, args_vjp2 = wrap_vjp(f_jax_vjp1, args_vjp1, res_vjp1)
|
|
res_vjp2 = f_jax_vjp2(*args_vjp2)
|
|
|
|
# 3rd order
|
|
f_jax_vjp3, args_vjp3 = wrap_vjp(f_jax_vjp2, args_vjp2, res_vjp2)
|
|
_ = f_jax_vjp3(*args_vjp3)
|
|
|
|
def test_tap_vmap(self):
|
|
vmap_fun1 = jax.vmap(fun1)
|
|
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
|
vmap_fun1(vargs)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
|
|
[ 8.00 10.00]
|
|
transforms: [('batch', {'batch_dims': (0,)})] what: y * 3
|
|
[24.00 30.00]""", testing_stream.output)
|
|
|
|
def test_tap_vmap_not_batched(self):
|
|
x = 3.
|
|
|
|
def func(y):
|
|
# x is not mapped, y is mapped
|
|
_, y = hcb.id_print((x, y), output_stream=testing_stream)
|
|
return x + y
|
|
|
|
vmap_func = jax.vmap(func)
|
|
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
|
|
_ = vmap_func(vargs)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
transforms: [('batch', {'batch_dims': (None, 0)})]
|
|
( 3.00 [4.00 5.00] )""", testing_stream.output)
|
|
|
|
def test_tap_vmap_vmap(self):
|
|
# A 2D tensor with x[i, j] = i + j using 2 vmap
|
|
def sum(x, y):
|
|
return hcb.id_print(x + y, output_stream=testing_stream)
|
|
|
|
def sum_rows(xv, y):
|
|
return jax.vmap(sum, in_axes=(0, None))(xv, y)
|
|
|
|
def sum_all(xv, yv):
|
|
return jax.vmap(sum_rows, in_axes=(None, 0))(xv, yv)
|
|
|
|
xv = jnp.arange(5, dtype=np.int32)
|
|
yv = jnp.arange(3, dtype=np.int32)
|
|
# assertMultiLineStrippedEqual(self, "", str(jax.make_jaxpr(sum_all)(xv, yv)))
|
|
_ = sum_all(xv, yv)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
|
|
[[0 1 2 3 4]
|
|
[1 2 3 4 5]
|
|
[2 3 4 5 6]]""", testing_stream.output)
|
|
|
|
def test_tap_vmap_while(self):
|
|
"""Vmap of while."""
|
|
|
|
def func(x):
|
|
# like max(x, 2)
|
|
x1 = hcb.id_print(x, where="before:x", output_stream=testing_stream)
|
|
x2 = lax.while_loop(
|
|
lambda x: x < 2, lambda x: hcb.id_print(
|
|
x + 1, where="body:x+1", output_stream=testing_stream), x1)
|
|
res = hcb.id_print(x2, where="after:x", output_stream=testing_stream)
|
|
return res
|
|
|
|
inputs = np.arange(5, dtype=np.int32)
|
|
self.assertAllClose(
|
|
np.array([2, 2, 2, 3, 4]),
|
|
jax.jit(jax.vmap(func))(inputs),
|
|
check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(
|
|
self, """
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: before:x
|
|
[0 1 2 3 4]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
|
|
[1 2 3 4 5]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: body:x+1
|
|
[2 3 3 4 5]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: after:x
|
|
[2 2 2 3 4]""", testing_stream.output)
|
|
|
|
def test_tap_vmap_while_tap_cond(self):
|
|
"""Vmap of while, with a tap in the conditional."""
|
|
|
|
def func(x):
|
|
# like max(x, 2)
|
|
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
|
x2 = lax.while_loop(lambda x: hcb.id_print(x < 2, where="w_c",
|
|
output_stream=testing_stream),
|
|
lambda x: hcb.id_print(x + 1, where="w_b",
|
|
output_stream=testing_stream),
|
|
x1)
|
|
res = hcb.id_print(x2, where="3", output_stream=testing_stream)
|
|
return res
|
|
|
|
inputs = np.arange(5, dtype=np.int32)
|
|
res = jax.jit(jax.vmap(func))(inputs)
|
|
hcb.barrier_wait()
|
|
self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
|
|
assertMultiLineStrippedEqual(self, """
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: 1
|
|
[0 1 2 3 4]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
|
|
[ True True False False False]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
|
|
[1 2 3 4 5]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
|
|
[ True False False False False]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
|
|
[2 3 3 4 5]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
|
|
[False False False False False]
|
|
transforms: [('batch', {'batch_dims': (0,)})] where: 3
|
|
[2 2 2 3 4]""", testing_stream.output)
|
|
|
|
def test_tap_transforms_doc(self):
|
|
# Examples from the documentation
|
|
def power3(x):
|
|
y = x * x
|
|
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
|
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
|
return y * x
|
|
|
|
print(f"impl = {power3(3.)}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
print(f"jvp = {jax.jvp(power3, (3.,), (0.1,))}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
@jax.custom_jvp
|
|
def print_tangents(arg):
|
|
return None
|
|
|
|
@print_tangents.defjvp
|
|
def print_tangents_jvp(primals, tangents):
|
|
arg_dot, = tangents
|
|
hcb.id_print(arg_dot, what="tangents", output_stream=testing_stream)
|
|
return primals, tangents
|
|
|
|
def power3_with_tangents(x):
|
|
y = x * x
|
|
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
|
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
|
print_tangents((x, y))
|
|
return y * x
|
|
|
|
print(f"jvp = {jax.jvp(power3_with_tangents, (3.,), (0.1,))}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )
|
|
what: tangents
|
|
( 0.1 0.6 )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
print(f"grad = {jax.grad(power3)(3.)}")
|
|
hcb.barrier_wait()
|
|
# Only the primals by default
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
@jax.custom_vjp
|
|
def print_cotangents(arg):
|
|
# Must return the argument for which we want the cotangent.
|
|
return arg
|
|
|
|
# f_fwd: a -> (b, residual)
|
|
def print_cotangents_fwd(arg):
|
|
return print_cotangents(arg), None
|
|
# f_bwd: (residual, CT b) -> [CT a]
|
|
def print_cotangents_bwd(residual, ct_b):
|
|
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
|
|
return ct_b,
|
|
|
|
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
|
|
|
|
def power3_with_cotangents(x):
|
|
y = x * x
|
|
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
|
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
|
|
# Must use the output of print_cotangents
|
|
(x1, y1) = print_cotangents((x, y))
|
|
return y1 * x1
|
|
|
|
print(f"grad = {jax.grad(power3_with_cotangents)(3.)}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )
|
|
what: cotangents
|
|
( 9. 3. )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
# TODO: grad of grad
|
|
|
|
print(f"vmap = {jax.vmap(power3)(np.array([2., 3.]))}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
|
( [2. 3.] [4. 9.] )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
print(f"vmap o grad {jax.vmap(jax.grad(power3))(np.array([2., 3.]))}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
|
( [2. 3.] [4. 9.] )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
print(f"vmap o grad {jax.vmap(jax.grad(power3_with_cotangents))(np.array([2., 3.]))}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
transforms: [('batch', {'batch_dims': (0, 0)})] what: x,x^2
|
|
( [2. 3.] [4. 9.] )
|
|
transforms: [('batch', {'batch_dims': (0, 0)})] what: cotangents
|
|
( [4. 9.] [2. 3.] )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
print(f"grad o remat = {jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)}")
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: x,x^2
|
|
( 3. 9. )
|
|
what: x,x^2
|
|
( 27. 729. )
|
|
what: x,x^2
|
|
( 3. 9. )"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
def test_tap_pmap(self):
|
|
if len(local_devices()) < 2:
|
|
raise SkipTest("test requires at least 2 devices")
|
|
|
|
def power3(x):
|
|
y = x * x
|
|
# Print both 'x' and 'x^2'. Must pack as a tuple.
|
|
_, y = hcb.id_print((x, y),
|
|
what="x,x^2",
|
|
output_stream=testing_stream,
|
|
tap_with_device=True)
|
|
return y * x
|
|
|
|
pmap_power3 = jax.pmap(power3, devices=local_devices())
|
|
xv = np.array([3, 4], dtype=np.int32)
|
|
res = pmap_power3(xv)
|
|
hcb.barrier_wait()
|
|
self.assertAllClose(xv * xv * xv, res, check_dtypes=False)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(
|
|
self, """
|
|
device: cpu:0 what: x,x^2
|
|
( 3 9 )
|
|
device: cpu:1 what: x,x^2
|
|
( 4 16 )""")
|
|
|
|
def test_tap_pmap_vmap(self):
|
|
# A matrix M[ij] = i * 10 + j
|
|
nr_devices = len(local_devices())
|
|
shape = (nr_devices, 3)
|
|
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
|
|
dtype=np.int32)
|
|
|
|
def fun1(x, do_print=False): # x: i32
|
|
return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)
|
|
|
|
pmap_vmap_fun1 = jax.pmap(
|
|
jax.vmap(partial(fun1, do_print=True)), devices=local_devices())
|
|
|
|
res = pmap_vmap_fun1(matrix)
|
|
hcb.barrier_wait()
|
|
expected_res = jax.pmap(
|
|
jax.vmap(partial(fun1, do_print=False)), devices=local_devices())(
|
|
matrix)
|
|
self.assertAllClose(expected_res, res, check_dtypes=False)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[0.00 2.00 4.00]
|
|
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[20.00 22.00 24.00]""")
|
|
|
|
def test_tap_pmap_pmap_vmap(self):
|
|
# A matrix M[ijk] = i * 100 + j * 10 + k
|
|
nr_devices = len(local_devices())
|
|
if nr_devices % 2 != 0:
|
|
raise SkipTest("test works only on even number of devices")
|
|
|
|
shape = (2, nr_devices // 2, 3)
|
|
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun1(x, do_print=False): # x: f32
|
|
y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)
|
|
return y ** 2
|
|
|
|
pmap_fun1 = jax.pmap(
|
|
jax.pmap(jax.vmap(partial(fun1, do_print=True))),
|
|
devices=local_devices())
|
|
res = pmap_fun1(matrix)
|
|
hcb.barrier_wait()
|
|
expected_res = jax.pmap(
|
|
jax.pmap(jax.vmap(partial(fun1, do_print=False))),
|
|
devices=local_devices())(
|
|
matrix)
|
|
self.assertAllClose(expected_res, res, check_dtypes=False)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[0.00 2.00 4.00]
|
|
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[200.00 202.00 204.00]""")
|
|
|
|
@ignore_jit_of_pmap_warning()
|
|
def test_tap_pmap_pmap_extra(self):
|
|
"""pmap of a pmap surrounded by extra code."""
|
|
# A matrix M[ij] = i * 10 + j
|
|
nr_devices = len(local_devices())
|
|
if nr_devices != 2:
|
|
raise SkipTest("test works only on 2 devices")
|
|
shape = (2, 1, 3)
|
|
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun(xv, do_print=False):
|
|
# This will be printed on all devices, with shape [1, 3]
|
|
xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
|
|
res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
|
|
# This will be printed on all devices, with shape [1, 3]
|
|
return maybe_print(do_print, res + 1., "after", tap_with_device=True)
|
|
|
|
res = jax.pmap(partial(fun, do_print=True))(matrix)
|
|
self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 what: before
|
|
[[1.00 2.00 3.00]]
|
|
device: cpu:0 what: inside
|
|
[2.00 4.00 6.00]
|
|
device: cpu:0 what: after
|
|
[[3.00 5.00 7.00]]
|
|
device: cpu:1 what: before
|
|
[[101.00 102.00 103.00]]
|
|
device: cpu:1 what: inside
|
|
[202.00 204.00 206.00]
|
|
device: cpu:1 what: after
|
|
[[203.00 205.00 207.00]]""")
|
|
|
|
def test_tap_jvp_pmap_vmap(self):
|
|
# A matrix M[ijk] = i * 100 + j * 10 + k
|
|
nr_devices = len(local_devices())
|
|
shape = (nr_devices, 2, 3)
|
|
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun(xv, do_print=False):
|
|
# x: f32[3]
|
|
return jax.jvp(jax.pmap(jax.vmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True))),
|
|
(xv,), (.1 * jnp.ones_like(xv),))
|
|
|
|
res = fun(matrix, do_print=True)
|
|
hcb.barrier_wait()
|
|
expected_res = fun(matrix, do_print=False)
|
|
self.assertAllClose(expected_res, res, check_dtypes=False)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
# Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :]
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[[ 0.00 2.00 4.00]
|
|
[20.00 22.00 24.00]]
|
|
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[[200.00 202.00 204.00]
|
|
[220.00 222.00 224.00]]""")
|
|
|
|
def test_tap_vmap_pmap(self):
|
|
# A matrix M[ijk] = i * 100 + j * 10 + k
|
|
nr_devices = len(local_devices())
|
|
shape = (2, nr_devices, 3)
|
|
matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun(xv, do_print=False):
|
|
# x: f32[3]
|
|
return jax.vmap(jax.pmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)))(xv)
|
|
|
|
res = fun(matrix, do_print=True)
|
|
hcb.barrier_wait()
|
|
expected_res = fun(matrix, do_print=False)
|
|
self.assertAllClose(expected_res, res, check_dtypes=False)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
# Device 0 will get to execute jax.vmap(jax.pmap(...)) for matrix[:, 0, :]
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[[ 0.00 2.00 4.00]
|
|
[200.00 202.00 204.00]]
|
|
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
|
|
[[ 20.00 22.00 24.00]
|
|
[220.00 222.00 224.00]]""")
|
|
|
|
@ignore_jit_of_pmap_warning()
|
|
def test_tap_jit_pmap_extra(self):
|
|
"""jit of a pmap surrounded by extra code."""
|
|
# A matrix M[ij] = i * 10 + j
|
|
nr_devices = len(local_devices())
|
|
assert nr_devices in (1, 2)
|
|
shape = (nr_devices, 3)
|
|
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun(xv, do_print=False):
|
|
# This will be printed on all devices with shape (nr_devices, 3)
|
|
xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
|
|
res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
|
|
# This will be printed on all devices with shape (nr_devices, 3)
|
|
return maybe_print(do_print, res + 1., "after", tap_with_device=True)
|
|
|
|
res = jax.jit(partial(fun, do_print=True))(matrix)
|
|
self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
if len(local_devices()) == 2:
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 what: before
|
|
[[ 1.00 2.00 3.00]
|
|
[11.00 12.00 13.00]]
|
|
device: cpu:0 what: inside
|
|
[2.00 4.00 6.00]
|
|
device: cpu:0 what: after
|
|
[[ 3.00 5.00 7.00]
|
|
[23.00 25.00 27.00]]
|
|
device: cpu:1 what: before
|
|
[[ 1.00 2.00 3.00]
|
|
[11.00 12.00 13.00]]
|
|
device: cpu:1 what: inside
|
|
[22.00 24.00 26.00]
|
|
device: cpu:1 what: after
|
|
[[ 3.00 5.00 7.00]
|
|
[23.00 25.00 27.00]]""")
|
|
else:
|
|
assert len(local_devices()) == 1
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0 what: before
|
|
[[1.00 2.00 3.00]]
|
|
device: cpu:0 what: inside
|
|
[2.00 4.00 6.00]
|
|
device: cpu:0 what: after
|
|
[[3.00 5.00 7.00]]""")
|
|
|
|
@unittest.skip("cond of pmap does not work in JAX. Issue #5178.")
|
|
def test_tap_cond_pmap(self):
|
|
# A matrix M[ij] = i * 10 + j
|
|
nr_devices = len(local_devices())
|
|
shape = (nr_devices, 3)
|
|
matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
|
|
dtype=np.float32)
|
|
|
|
def fun1(x, do_print=False):
|
|
return maybe_print(do_print, x * 2., "x * 2")
|
|
|
|
def fun2(cond, xv, do_print=False):
|
|
return lax.cond(cond, jax.pmap(partial(fun1, do_print=do_print)),
|
|
lambda xv: xv, xv)
|
|
|
|
res = fun2(True, matrix)
|
|
self.assertAllClose(fun2(True, matrix, do_print=False), res, check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
assertMultiLineStrippedEqual(self, """
|
|
TBD""", testing_stream.output)
|
|
|
|
@jtu.sample_product(device_index=[0, 1])
|
|
def test_tap_pjit(self, device_index=0):
|
|
if (device_index != 0 and
|
|
not hcb._HOST_CALLBACK_OUTFEED.value and
|
|
jtu.test_device_matches(["cpu"])):
|
|
# See comment in host_callback.py.
|
|
raise SkipTest("device_index works only with outfeed on CPU")
|
|
|
|
devices = np.array(local_devices())
|
|
nr_devices = len(devices)
|
|
if nr_devices < 2:
|
|
raise SkipTest("test requires at least 2 devices")
|
|
|
|
logging.info(f"test_tap_pjit is running on devices {devices}.")
|
|
# x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
|
|
# y: i32[3, 4]
|
|
x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
|
|
y = jnp.ones((3, 4), np.int32)
|
|
|
|
@partial(jax.named_call, name="fun1") # for xprof debugging
|
|
def fun1(x):
|
|
z = jnp.dot(x, y)
|
|
return hcb.id_print(z, what="z",
|
|
output_stream=testing_stream,
|
|
tap_with_device=True, device_index=device_index)
|
|
|
|
pjit_fun1 = pjit.pjit(fun1, in_shardings=(P("d"),), out_shardings=P("d"))
|
|
|
|
with jax.sharding.Mesh(devices, ["d"]):
|
|
# Print the internal IR
|
|
helper_log_ir(
|
|
f"{self._testMethodName}.pjit",
|
|
pjit_fun1,
|
|
x,
|
|
num_partitions=nr_devices)
|
|
res = pjit_fun1(x)
|
|
|
|
self.assertAllClose(jnp.dot(x, y), res)
|
|
hcb.barrier_wait("before check")
|
|
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
# Note that a single call is made.
|
|
assertMultiDeviceOutputEqual(
|
|
self, f"""
|
|
device: cpu:{device_index} what: z
|
|
[[ 3 3 3 3]
|
|
[33 33 33 33]]""")
|
|
|
|
def test_tap_scan_custom_jvp(self):
|
|
"""custom JVP, inside scan.
|
|
This exercises the custom_jvp_call_jaxpr primitives."""
|
|
|
|
@jax.custom_jvp
|
|
def f(x):
|
|
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
|
|
|
|
@f.defjvp
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
x_dot, = tangents
|
|
primal_out = f(x)
|
|
tangent_out = 3. * x * hcb.id_print(x_dot, output_stream=testing_stream, what="x_dot")
|
|
return primal_out, tangent_out
|
|
|
|
def g(x):
|
|
# Sum f(x_i)
|
|
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
|
|
np.full(x.shape[1:], 0.), # Like x w/o leading dim
|
|
x)[0]
|
|
|
|
arg = np.full((2,), 0.7)
|
|
self.assertAllClose(0.7 * 0.7 * 2, g(arg))
|
|
hcb.barrier_wait()
|
|
self.assertMultiLineStrippedEqual("""
|
|
what: x
|
|
0.7
|
|
what: x
|
|
0.7""", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
self.assertMultiLineStrippedEqual("""
|
|
what: x
|
|
0.7
|
|
what: x
|
|
0.7
|
|
transforms: ['transpose'] what: x_dot
|
|
2.1
|
|
transforms: ['transpose'] what: x_dot
|
|
2.1""", testing_stream.output)
|
|
|
|
def test_tap_scan_custom_vjp(self):
|
|
"""custom VJP, inside scan.
|
|
This exercises the custom_vjp_call_jaxpr primitives."""
|
|
|
|
@jax.custom_vjp
|
|
def f(x):
|
|
return x * hcb.id_print(x, output_stream=testing_stream, what="x")
|
|
|
|
# f_fwd: a -> (b, residual)
|
|
def f_fwd(x):
|
|
return f(x), 3. * x
|
|
|
|
# f_bwd: (residual, CT b) -> [CT a]
|
|
def f_bwd(residual, ct_b):
|
|
return residual * hcb.id_print(ct_b, output_stream=testing_stream, what="ct_b"),
|
|
|
|
f.defvjp(f_fwd, f_bwd)
|
|
|
|
def g(x):
|
|
# Sum f(x_i)
|
|
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
|
|
np.full(x.shape[1:], 0.), # Like x w/o leading dim
|
|
x)[0]
|
|
|
|
arg = np.full((2,), 0.7)
|
|
|
|
self.assertAllClose(0.7 * 0.7 * 2, g(arg))
|
|
hcb.barrier_wait()
|
|
self.assertMultiLineStrippedEqual("""
|
|
what: x
|
|
0.7
|
|
what: x
|
|
0.7""", testing_stream.output)
|
|
testing_stream.reset()
|
|
|
|
self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False)
|
|
hcb.barrier_wait()
|
|
self.assertMultiLineStrippedEqual("""
|
|
what: x
|
|
0.7
|
|
what: x
|
|
0.7
|
|
what: ct_b
|
|
1.
|
|
what: ct_b
|
|
1.""", testing_stream.output)
|
|
|
|
def test_tap_callback_delay(self):
|
|
hcb.callback_extra = lambda dev: time.sleep(1)
|
|
|
|
def func(x):
|
|
for i in range(5):
|
|
x = hcb.id_print(x * i, what="x times i")
|
|
return x
|
|
|
|
jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
|
|
|
def test_tap_callback_delay_barrier(self):
|
|
hcb.callback_extra = lambda dev: time.sleep(2)
|
|
|
|
def func(x):
|
|
for i in range(1, 4):
|
|
x = hcb.id_print(x * i, what=f"x times {i}", output_stream=testing_stream)
|
|
return x
|
|
|
|
jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
|
# Wait for the results
|
|
hcb.barrier_wait("first")
|
|
expected = """
|
|
what: x times 1
|
|
[[0. 1. 2.]
|
|
[3. 4. 5.]]
|
|
what: x times 2
|
|
[[ 0. 2. 4.]
|
|
[ 6. 8. 10.]]
|
|
what: x times 3
|
|
[[ 0. 6. 12.]
|
|
[18. 24. 30.]]"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
testing_stream.reset()
|
|
# Call again
|
|
jax.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
|
|
hcb.barrier_wait("second")
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
def test_tap_error_bad_consumer_id(self):
|
|
"""Try to use reserved consumer ID 0.
|
|
|
|
Check that we get the proper error from the runtime."""
|
|
if not hcb._use_outfeed(jtu.device_under_test()):
|
|
raise SkipTest("test works only for outfeed")
|
|
comp = xla_client.XlaBuilder(self._testMethodName)
|
|
token = hcb.xops.CreateToken(comp)
|
|
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
"Consumer ID cannot be a reserved value: 0"):
|
|
hcb._callback_handler_data.receiver.add_outfeed(
|
|
comp, token, 0,
|
|
[xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0)
|
|
|
|
def test_tap_error_different_shapes(self):
|
|
"""Try to register different shapes for the same consumer ID."""
|
|
if not hcb._use_outfeed(jtu.device_under_test()):
|
|
raise SkipTest("test works only for outfeed")
|
|
comp = xla_client.XlaBuilder(self._testMethodName)
|
|
token = hcb.xops.CreateToken(comp)
|
|
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
|
|
hcb._callback_handler_data.receiver.add_outfeed(
|
|
comp, token, 123,
|
|
[xops.Constant(comp, np.zeros((2, 3), dtype=np.float32))], 0)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, ".*does not match previous shape element_type.*"):
|
|
hcb._callback_handler_data.receiver.add_outfeed(
|
|
comp, token, 123,
|
|
[xops.Constant(comp, np.zeros((2, 3), dtype=np.int32))], 0)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, ".*does not match previous shape element_type.*"):
|
|
hcb._callback_handler_data.receiver.add_outfeed(
|
|
comp, token, 123,
|
|
[xops.Constant(comp, np.zeros((2,), dtype=np.float32))], 0)
|
|
|
|
def test_tap_id_tap_removed_kwargs(self):
|
|
def func(x, transforms, y):
|
|
pass
|
|
|
|
with self.assertRaisesRegex(TypeError, r"Support for \*\*kwargs in ``id_tap``"):
|
|
hcb.id_tap(func, 1, y=2)
|
|
|
|
def test_tap_id_tap_random_key(self):
|
|
# See https://github.com/google/jax/issues/13949
|
|
with jax.enable_custom_prng():
|
|
@jax.jit
|
|
def f(x):
|
|
def tap(tap_x, _): pass
|
|
return hcb.id_tap(tap, x, result=x)
|
|
f(jax.random.PRNGKey(123))
|
|
|
|
def test_tap_odeint(self):
|
|
# TODO: find a smaller repro for bug #4015
|
|
# Seems to be xla_call(scan(xla_call)), all under grad.
|
|
from jax.experimental.ode import odeint
|
|
|
|
def f(x, t, k):
|
|
x = hcb.id_print(x)
|
|
return -k * x
|
|
|
|
def loss(k=1.0):
|
|
t = jnp.linspace(0, 0.001, num=2)
|
|
xs = odeint(f, 1.0, t, k)
|
|
return xs[-1]
|
|
|
|
jax.grad(loss)(1.0) # should not fail
|
|
|
|
def test_tap_remat_0(self):
|
|
def f(i, k):
|
|
x = hcb.id_print(k + i, output_stream=testing_stream)
|
|
return k * x
|
|
|
|
def loss(k):
|
|
return lax.fori_loop(0, 2, jax.remat(f), k)
|
|
|
|
print(loss(3))
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
3
|
|
10"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
@jtu.sample_product(
|
|
use_result=[True, False],
|
|
grad_func=["grad", "value_and_grad"],
|
|
use_remat=["old", "new", "none"],
|
|
)
|
|
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
|
|
if use_remat == "old": raise SkipTest()
|
|
|
|
def f(x):
|
|
id_print_result = hcb.id_print(x, output_stream=testing_stream)
|
|
if use_result:
|
|
x = id_print_result
|
|
return 3. * x
|
|
grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad
|
|
if use_remat == "old":
|
|
trans_f = jax.remat(f)
|
|
elif use_remat == "new":
|
|
trans_f = ad_checkpoint.checkpoint(f)
|
|
else:
|
|
assert use_remat == "none"
|
|
trans_f = f
|
|
print(jax.make_jaxpr(grad_f(trans_f))(2.))
|
|
grad_f(trans_f)(2.)
|
|
|
|
hcb.barrier_wait()
|
|
|
|
if use_remat == "none":
|
|
# GOOD: whether or not we use_result, we get the same callback.
|
|
expected = "2."
|
|
else: # use_remat
|
|
if use_result:
|
|
expected = """
|
|
2.
|
|
2."""
|
|
else:
|
|
if use_remat == "old":
|
|
# TODO: we should see two callbacks
|
|
expected = ""
|
|
else:
|
|
# Good: we see two callbacks, whether or not we use the result.
|
|
expected = """
|
|
2.
|
|
2."""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
def test_tap_named_call(self):
|
|
def tap_scalar(init, do_print=False):
|
|
@partial(jax.named_call, name="step")
|
|
def step(acc, step_nr):
|
|
acc = acc + step_nr
|
|
maybe_print(do_print, step_nr, what="step_nr")
|
|
return acc, None
|
|
|
|
return lax.scan(step, init, np.arange(2))
|
|
|
|
self.assertAllClose(tap_scalar(3, do_print=False), tap_scalar(3, do_print=True))
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
what: step_nr
|
|
0
|
|
what: step_nr
|
|
1"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
|
|
class HostCallbackCallTest(jtu.JaxTestCase):
|
|
"""Tests for hcb.call"""
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
|
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
|
if xla_bridge.using_pjrt_c_api():
|
|
raise SkipTest("host_callback not implemented in PJRT C API")
|
|
|
|
testing_stream.reset()
|
|
testing_stream._test_method_name = self._testMethodName
|
|
|
|
def tearDown(self) -> None:
|
|
hcb.barrier_wait("HostCallbackCallTest.tearDown")
|
|
super().tearDown()
|
|
|
|
def call_log_testing_stream(self, func, arg, *, result_shape, name=""):
|
|
"""Call `func` and log inputs and outputs to the testing stream"""
|
|
|
|
def call_log(arg):
|
|
def val2str(v):
|
|
return np.array2string(np.array(arg))
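# Added note: val2str formats the enclosing `arg`, not its `v` parameter, so
# both writes below show the call argument; the expected strings in
# test_call_no_result rely on this behavior.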
|
|
testing_stream.write(f"Call {name}({val2str(arg)})\n")
|
|
res = func(arg)
|
|
testing_stream.write(f" = {val2str(res)}\n")
|
|
return res
|
|
return hcb.call(call_log, arg, result_shape=result_shape)
|
|
|
|
def test_call_simple(self):
|
|
|
|
def f_outside(x):
|
|
return 2 * x
|
|
|
|
def fun(x):
|
|
y = hcb.call(f_outside, x + 1, result_shape=x)
|
|
return 3 * (1 + y)
|
|
|
|
arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
|
|
self.assertAllClose(3 * (1 + 2 * (arg + 1)), fun(arg))
|
|
|
|
def test_primitive_compilation(self):
|
|
|
|
def f_outside(x):
|
|
return 2 * x
|
|
|
|
def fun(x):
|
|
return hcb.call(f_outside, x, result_shape=x)
|
|
|
|
arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
|
|
with jtu.count_primitive_compiles() as count:
|
|
for _ in range(3):
|
|
self.assertAllClose(2 * arg, fun(arg))
|
|
self.assertEqual(count[0], 1)
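# Added note: all three calls above apply the same host-callback primitive
# with identical parameters, so only the first triggers a compilation; the
# later calls reuse the cached executable, hence the expected count of 1.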
|
|
|
|
@jtu.sample_product(
|
|
dtype=[dtype for dtype in jtu.dtypes.all if dtype != np.bool_],
|
|
)
|
|
def test_call_types(self, dtype=np.float64):
|
|
|
|
def f_outside(x):
|
|
# Use x + x to ensure that the result type is the same
|
|
return x + x
|
|
|
|
def fun(x):
|
|
return hcb.call(f_outside, x + x, result_shape=x)
|
|
|
|
arg = np.arange(24, dtype=dtype).reshape((2, 3, 4))
|
|
self.assertAllClose(arg + arg + arg + arg, fun(arg), check_dtypes=True)
|
|
|
|
def test_call_types_bool(self, dtype=np.float64):
|
|
|
|
def f_outside(x):
|
|
return np.invert(x)
|
|
|
|
def fun(x):
|
|
return hcb.call(f_outside, x, result_shape=x)
|
|
|
|
arg = self.rng().choice(a=[True, False], size=(2, 3, 4))
|
|
self.assertAllClose(np.invert(arg), fun(arg))
|
|
|
|
def test_call_tuples(self):
|
|
|
|
def f_outside(args):
|
|
x, y = args
|
|
return y, x # Swap the tuple
|
|
|
|
def fun(x):
|
|
xy = hcb.call(f_outside, (x, x + 1), result_shape=(x, x))
|
|
return 2 * xy[0] + 3 * xy[1]
|
|
|
|
arg = np.arange(24, dtype=np.int32).reshape((2, 3, 4))
|
|
self.assertAllClose(2 * (arg + 1) + 3 * arg, fun(arg))
|
|
|
|
def test_call_no_arg(self):
|
|
"""Call with no arguments."""
|
|
result = np.ones((2,), dtype=np.float32)
|
|
def f_outside(in_tuple):
|
|
assert len(in_tuple) == 0
|
|
return result
|
|
def fun(x):
|
|
return x + hcb.call(f_outside, (),
|
|
result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype))
|
|
self.assertAllClose(2. + result, fun(2.))
|
|
|
|
def test_call_empty_arg(self):
|
|
"""Call with empty array."""
|
|
result = np.full((2,), 3., dtype=np.float32)
|
|
def f_outside(x0): # x0: f32[2, 0]
|
|
return result
|
|
x0 = np.ones((2, 0), dtype=np.float32)
|
|
def fun(x):
|
|
return x + hcb.call(f_outside, x0,
|
|
result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype))
|
|
self.assertAllClose(2. + result, fun(2.))
|
|
|
|
def test_call_empty_arg_inside_pytree(self):
|
|
"""Call taking tuple with an empty array and a non-empty one."""
|
|
x0 = np.ones((2, 0), dtype=np.float32)
|
|
x1 = np.full((2,), 3., dtype=np.float32)
|
|
result = x1
|
|
def f_outside(in_tuple): # x0: f32[2, 0] x1: f32[2]
|
|
return in_tuple[1]
|
|
|
|
def fun(x):
|
|
res = hcb.call(f_outside, (x0, x1),
|
|
result_shape=jax.ShapeDtypeStruct(result.shape, result.dtype))
|
|
return x + res
|
|
self.assertAllClose(2. + result, fun(2.))
|
|
|
|
def test_call_empty_result(self):
|
|
"""Call returning empty array."""
|
|
result_shape = (2, 0)
|
|
def f_outside(_):
|
|
return np.ones(result_shape, dtype=np.float32)
|
|
def fun(x):
|
|
return x + hcb.call(f_outside, 1.,
|
|
result_shape=jax.ShapeDtypeStruct(result_shape, np.float32))
|
|
self.assertAllClose(f_outside(0.), fun(2.))
|
|
|
|
def test_call_empty_result_inside_pytree(self):
|
|
"""Call returning a tuple with an empty array and a non-empty one."""
|
|
result_shape_0 = (2, 0)
|
|
result_shape_2 = (0,)
|
|
def f_outside(_):
|
|
return (np.ones(result_shape_0, dtype=np.float32),
|
|
np.ones((1,), dtype=np.float32),
|
|
np.ones(result_shape_2, dtype=np.float32))
|
|
def fun(x):
|
|
res = hcb.call(f_outside, 1.,
|
|
result_shape=(jax.ShapeDtypeStruct(result_shape_0, np.float32),
|
|
jax.ShapeDtypeStruct((1,), np.float32),
|
|
jax.ShapeDtypeStruct(result_shape_2, np.float32)))
|
|
self.assertEqual(result_shape_0, res[0].shape)
|
|
self.assertEqual(result_shape_2, res[2].shape)
|
|
return x + res[1]
|
|
self.assertAllClose(2 + np.ones((1,), dtype=np.float32), fun(2.))
|
|
|
|
def test_call_empty_result_all_pytree(self):
|
|
"""Call returning a tuple of empty arrays."""
|
|
result_shape = (2, 0)
|
|
def f_outside(_):
|
|
return (np.ones(result_shape, dtype=np.float32),
|
|
np.ones(result_shape, dtype=np.float32))
|
|
def fun(x):
|
|
res = hcb.call(f_outside, 1.,
|
|
result_shape=(jax.ShapeDtypeStruct(result_shape, np.float32),
|
|
jax.ShapeDtypeStruct(result_shape, np.float32)))
|
|
return x + res[0] + res[1]
|
|
self.assertAllClose(np.ones(result_shape, dtype=np.float32),
|
|
fun(2.))
|
|
|
|
def test_call_no_result(self):
|
|
def f_outside(arg):
|
|
self.call_log_testing_stream(lambda x: None, arg,
|
|
result_shape=None,
|
|
name="outside")
|
|
return arg
|
|
|
|
self.assertAllClose((3., 4.), f_outside((3., 4.)))
|
|
hcb.barrier_wait()
|
|
expected = """
|
|
Call outside([3. 4.])
|
|
= [3. 4.]"""
|
|
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
def test_call_cond(self):
|
|
def f_outside(args):
|
|
x, y = args
|
|
return x * y
|
|
|
|
def loop(x, use_outside=True):
|
|
def body(i, acc):
|
|
return lax.cond(i % 2 == 1,
|
|
lambda _: (hcb.call(f_outside, (acc, i),
|
|
result_shape=acc)
|
|
if use_outside else f_outside((acc, i))),
|
|
lambda _: acc,
|
|
None)
|
|
|
|
return lax.fori_loop(0, 18, body, x)
|
|
|
|
res_inside = loop(1.2, use_outside=False)
|
|
self.assertAllClose(res_inside, jax.jit(loop)(1.2))
|
|
|
|
def test_call_jit_scan_call(self):
|
|
def f_outside(x):
|
|
return x
|
|
|
|
def loop(x, use_outside=True):
|
|
def body(carry, i):
|
|
if use_outside:
|
|
return carry + hcb.call(f_outside, i,
|
|
result_shape=i), None
|
|
else:
|
|
return carry + i, None
|
|
|
|
return lax.scan(body, 0, x)
|
|
|
|
x = np.arange(5, dtype=np.int32)
|
|
|
|
res_outside = jax.jit(partial(loop, use_outside=True))(x)
|
|
self.assertAllClose(res_outside, loop(x, use_outside=False))
|
|
|
|
def test_call_doc_example1(self):
|
|
"""Examples from the documentation: simplest, call a function"""
|
|
|
|
def host_eig(x):
|
|
return np.linalg.eigvals(x)
|
|
|
|
shape = (2, 5, 4, 4)
|
|
|
|
m = np.ones(shape, dtype=np.float32)
|
|
|
|
def fun(m):
|
|
eig_m = hcb.call(host_eig, m,
|
|
result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
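# Added note: `result_shape` only needs objects carrying shape/dtype metadata
# (here a ShapeDtypeStruct); np.linalg.eigvals of a (..., 4, 4) batch has
# shape m.shape[:-1], which is exactly what this describes.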
|
|
return eig_m
|
|
|
|
expected_res = np.linalg.eigvals(m)
|
|
self.assertAllClose(expected_res, fun(m))
|
|
@jtu.skip_on_devices("gpu")
|
|
def test_call_doc_example_hlo(self):
|
|
"""Examples from the documentation: simplest, call a function."""
|
|
|
|
def fun1(m):
|
|
return jnp.sin(hcb.call(lambda x: np.cos(x),
|
|
jnp.cos(m),
|
|
result_shape=m))
|
|
|
|
m = np.ones((2,), np.float32)
|
|
helper_print_optimized_hlo(fun1, m)
|
|
|
|
def fun2(m):
|
|
x = hcb.call(lambda x: None, 2, result_shape=())
|
|
return x
|
|
|
|
m = np.ones((2,), np.float32)
|
|
helper_print_optimized_hlo(fun2, m)
|
|
|
|
def test_call_with_device(self):
|
|
def callback_func(x, device=None):
|
|
testing_stream.write(f"device: {device}\n Called with {x}")
|
|
return x
|
|
|
|
def func(x):
|
|
return hcb.call(callback_func, x,
|
|
result_shape=x,
|
|
call_with_device=True)
|
|
|
|
self.assertEqual(3., func(3.))
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0
|
|
Called with 3.00""")
|
|
|
|
def test_call_pmap(self):
|
|
# Works for 1 or 2 devices
|
|
def callback_func(x, device=None):
|
|
testing_stream.write(f"device: {device}\n Called with {x}")
|
|
return x * np.array(3, np.int32)
|
|
|
|
def fun(x): # x: i32
|
|
return hcb.call(callback_func, x * 2,
|
|
result_shape=x,
|
|
call_with_device=True)
|
|
|
|
xv = jnp.arange(len(local_devices()), dtype=jnp.int32)
|
|
res = jax.pmap(fun)(xv)
|
|
self.assertAllClose(jax.pmap(lambda x: x * 6)(xv), res)
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(self, """
|
|
device: cpu:0
|
|
Called with 0
|
|
device: cpu:1
|
|
Called with 2""")
|
|
|
|
def test_call_vmap(self):
|
|
def f_outside(x): return x
|
|
|
|
def fun(x):
|
|
return hcb.call(f_outside, x, result_shape=x)
|
|
|
|
with self.assertRaisesRegex(NotImplementedError,
|
|
"batching rules are implemented only for id_tap, not for call"):
|
|
jax.vmap(fun)(np.ones((2, 3)))
|
|
|
|
@jtu.sample_product(device_index=[0, 1])
|
|
@jtu.skip_on_devices("cpu") # TODO: RET_CHECK failure
|
|
def test_call_pjit(self, device_index=0):
|
|
devices = np.array(local_devices())
|
|
nr_devices = len(devices)
|
|
if nr_devices < 2:
|
|
raise SkipTest("test requires at least 2 devices")
|
|
|
|
logging.info(f"test_call_pjit is running on devices {devices}.")
|
|
# x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...]
|
|
# y: i32[3, 4]
|
|
x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3]
|
|
y = jnp.ones((3, 4), np.int32)
|
|
|
|
def callback_x5_func(x, device=None):
|
|
testing_stream.write(f"device: {device}\n Called with {x}")
|
|
return x * np.array(5, np.int32)
|
|
|
|
def fun(x):
|
|
xy = jnp.dot(x, y)
|
|
return hcb.call(
|
|
callback_x5_func, xy, result_shape=xy, call_with_device=True,
|
|
device_index=device_index)
|
|
|
|
pjit_fun = pjit.pjit(fun, in_shardings=(P("d"),), out_shardings=P("d"))
|
|
with jax.sharding.Mesh(devices, ["d"]):
|
|
# Print the internal IR
|
|
helper_log_ir(
|
|
f"{self._testMethodName}.pjit",
|
|
pjit_fun,
|
|
x,
|
|
num_partitions=nr_devices)
|
|
|
|
res = pjit_fun(x)
|
|
|
|
expected_res = jnp.dot(x, y) * np.array(5, np.int32)
|
|
self.assertAllClose(expected_res, res, check_dtypes=False)
|
|
|
|
hcb.barrier_wait("before assertion")
|
|
# Assertion text is for 2 devices (also works for 1 device)
|
|
assertMultiDeviceOutputEqual(
|
|
self, f"""
|
|
device: cpu:{device_index}
|
|
Called with [[ 3 3 3 3]
|
|
[33 33 33 33]]""")
|
|
|
|
def test_call_error_bad_result_shape(self):
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"):
|
|
hcb.call(lambda x: x, 3., result_shape="string")
|
|
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
"The values must be either numeric scalars, or must have 'shape' and 'dtype' attributes"):
|
|
hcb.call(lambda x: x, 3., result_shape=lambda x: x)
|
|
hcb.barrier_wait("wait for error")
|
|
|
|
def helper_check_callback_errors(self, thunk: Callable,
|
|
expected_exc_txt: str):
|
|
"""Calls thunk() and checks for expected exceptions.
|
|
"""
|
|
if jtu.test_device_matches(["cpu"]):
|
|
# On CPU the runtime crashes, and the tests are all aborted
|
|
raise SkipTest("TODO: CPU runtime crashes on unexpected infeed")
|
|
elif jtu.test_device_matches(["gpu"]):
|
|
# On GPU we get a nice error back to Python
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"(.* Mismatch between infeed source buffer shape s8.12345."
|
|
"|.*The destination shape does not match the source shape.)"):
|
|
thunk()
|
|
elif jtu.test_device_matches(["tpu"]):
|
|
# On TPU we get no error!!!
|
|
raise SkipTest("TODO: TPU runtime does not check infeed, and just computes with garbage")
|
|
|
|
# Both on GPU and TPU we also get an error during the barrier_wait at the
|
|
# end of the test. Run a barrier_wait now, to consume that error.
|
|
with self.assertRaisesRegex(
|
|
hcb.CallbackException,
|
|
re.compile(
|
|
"There were exceptions during callback processing.*Last one was:.*" +
|
|
expected_exc_txt,
|
|
re.DOTALL)):
|
|
hcb.barrier_wait("Waiting for error")
|
|
|
|
def test_call_error_callback_throws_exception(self):
|
|
def f_outside(x):
|
|
raise ValueError("user exception")
|
|
def fun(x):
|
|
return hcb.call(f_outside, x, result_shape=x)
|
|
|
|
self.helper_check_callback_errors(lambda: fun(3.),
|
|
"ValueError: user exception")
|
|
|
|
def test_call_error_callback_returns_unexpected_shape(self):
|
|
def fun(x):
|
|
return hcb.call(lambda x: (x, x), x, result_shape=x)
|
|
|
|
self.helper_check_callback_errors(lambda: fun(3.),
|
|
"Callback func .* should have returned a result with pytree")
|
|
|
|
def test_call_error_then_compute(self):
|
|
# Continue computation on device after error
|
|
def f_outside(x):
|
|
raise ValueError("user exception")
|
|
def fun(x):
|
|
x1 = hcb.call(f_outside, x, result_shape=x)
|
|
return x1
|
|
arg = np.arange(3, dtype=np.int32)
|
|
self.helper_check_callback_errors(lambda: self.assertAllClose(arg, fun(arg)),
|
|
"ValueError: user exception")
|
|
|
|
|
|
def call_jax_other_device(jax_outside_fun, arg, *, device):
|
|
"""Calls a JAX function on a specific device with simple support for reverse AD.
|
|
|
|
Functions whose name starts with "jax_outside" are called on another device,
|
|
by way of hcb.call.
|
|
"""
|
|
|
|
def run_jax_outside_fun(arg):
|
|
return jax.jit(jax_outside_fun)(jax.device_put(arg, device))
|
|
|
|
@jax.custom_vjp
|
|
def make_call(arg):
|
|
return hcb.call(run_jax_outside_fun, arg,
|
|
result_shape=jax.eval_shape(jax_outside_fun, arg))
|
|
|
|
# Define the fwd and bwd custom_vjp functions
|
|
def make_call_vjp_fwd(arg):
|
|
# Return the primal argument as the residual. Use `make_call` for the
|
|
# primal computation to enable higher-order AD.
|
|
return make_call(arg), arg # Return the primal argument as the residual
|
|
|
|
def make_call_vjp_bwd(res, ct_res):
|
|
arg = res # residual is the primal argument
|
|
|
|
def jax_outside_vjp_fun(arg_and_ct):
|
|
arg, ct = arg_and_ct
|
|
_, f_vjp = jax.vjp(jax_outside_fun, arg)
|
|
ct_in, = f_vjp(ct)
|
|
return ct_in
|
|
|
|
return (call_jax_other_device(jax_outside_vjp_fun, (arg, ct_res), device=device),)
|
|
|
|
make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
|
|
return make_call(arg)
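# Added, illustrative sketch (not part of the original test suite): a minimal
# way one might use `call_jax_other_device`, assuming at least one CPU device
# is available. The function below is hypothetical and is not invoked by any
# test; CallJaxTest exercises the same pattern with its own device selection.
def _example_call_jax_other_device(x):
  # Run jnp.sin on the first CPU device; reverse-mode AD still works because
  # make_call's custom_vjp re-enters call_jax_other_device for the cotangent.
  outside_cpu = jax.devices("cpu")[0]
  return 2. * call_jax_other_device(jnp.sin, x, device=outside_cpu)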
|
|
|
|
|
|
class CallJaxTest(jtu.JaxTestCase):
|
|
"""Tests using `call_jax_other_device`."""
|
|
|
|
def setUp(self):
|
|
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
|
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
|
if xla_bridge.using_pjrt_c_api():
|
|
raise SkipTest("host_callback not implemented in PJRT C API")
|
|
|
|
if not jtu.test_device_matches(["cpu"]):
|
|
assert jax.devices("cpu")
|
|
self.outside_device = jax.devices("cpu")[0]
|
|
else:
|
|
if len(jax.devices("cpu")) == 1:
|
|
raise SkipTest("Test needs at least two devices. On CPU use XLA_FLAGS=--xla_force_host_platform_device_count=2")
|
|
self.outside_device = jax.devices("cpu")[1]
|
|
super().setUp()
|
|
|
|
def test_jax_impl(self):
|
|
def f_jax(x):
|
|
return jnp.sin(x)
|
|
|
|
def f_outside(x):
|
|
return call_jax_other_device(f_jax, x, device=self.outside_device)
|
|
|
|
self.assertAllClose(f_jax(3.), f_outside(3.))
|
|
self.assertAllClose(f_jax(3.), jax.jit(f_outside)(3.))
|
|
|
|
def test_jax_impl_pytree(self):
|
|
def f_jax(x):
|
|
# x : dict(a=..., b=...) and output is a list of two elements
|
|
return [jnp.sin(x["a"]), jnp.sin(x["b"])]
|
|
|
|
def f_outside(x):
|
|
return call_jax_other_device(f_jax, x, device=self.outside_device)
|
|
|
|
x = dict(a=3., b=4.)
|
|
res_jax = f_jax(x)
|
|
# print(f"outside_jaxpr = {jax.make_jaxpr(f_outside)(x)}")
|
|
res_outside = f_outside(x)
|
|
self.assertAllClose(res_jax, res_outside)
|
|
|
|
def test_jax_grad(self):
|
|
def f_jax(x):
|
|
return 2. * jnp.sin(x)
|
|
|
|
def f_outside(x):
|
|
return 2. * call_jax_other_device(jnp.sin, x, device=self.outside_device)
|
|
|
|
res_jax = jax.grad(f_jax)(3.)
|
|
self.assertAllClose(res_jax, jax.grad(f_outside)(3.))
|
|
|
|
def test_jax_grad_pytree(self):
|
|
def f_jax(x):
|
|
# x : dict(a=..., b=...) and output is a float
|
|
return 3. * jnp.sin(x["a"]) + jnp.sin(x["b"])
|
|
|
|
def f_outside(x):
|
|
return call_jax_other_device(f_jax, x, device=self.outside_device)
|
|
|
|
x = dict(a=3., b=4.)
|
|
res_jax = jax.grad(f_jax)(x)
|
|
self.assertAllClose(res_jax, jax.grad(f_outside)(x))
|
|
|
|
def test_jax_grad_of_grad(self):
|
|
def f_jax(x):
|
|
return 2. * x * x * x
|
|
|
|
def f_outside(x):
|
|
return 2. * call_jax_other_device(lambda x: x * x * x, x, device=self.outside_device)
|
|
|
|
res_jax = jax.grad(jax.grad(f_jax))(5.)
|
|
res_outside = jax.grad(jax.grad(f_outside))(5.)
|
|
self.assertAllClose(res_jax, res_outside)
|
|
|
|
|
|
class OutfeedRewriterTest(jtu.JaxTestCase):
|
|
|
|
def setUp(self):
|
|
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
|
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
|
if xla_bridge.using_pjrt_c_api():
|
|
raise SkipTest("host_callback not implemented in PJRT C API")
|
|
super().setUp()
|
|
|
|
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
|
|
has_input_token=True, has_output_token=True):
|
|
"""Check that the rewrite of func(*args) matches expected."""
|
|
jaxpr = jax.make_jaxpr(func)(*args)
|
|
rewritten = hcb._rewrite_closed_jaxpr(jaxpr, # noqa: F841
|
|
has_input_token, has_output_token)
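# Added note: the rewrite threads explicit input/output token variables
# through the jaxpr (visible as the extra invars and outvars in the expected
# jaxprs below); the tokens keep the outside_call ops ordered relative to each
# other and to the surrounding control flow.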
|
|
# Since it is somewhat annoying to update the Jaxpr assertions when we change
|
|
# the Jaxpr printing, we do not check these by default. It is recommended that,
|
|
# before making changes to the code generation and Jaxpr rewriting, you turn on
|
|
# the checking, update the expected Jaxpr, and then make the changes.
|
|
# assertMultiLineStrippedEqual(self, expected, str(rewritten))
|
|
del rewritten
|
|
|
|
def test_no_outfeed(self):
|
|
self.assertRewrite("""
|
|
{ lambda ; a.
|
|
let b = mul a a
|
|
c = add a b
|
|
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False,
|
|
has_output_token=False)
|
|
self.assertRewrite("""
|
|
{ lambda ; a d e.
|
|
let b = mul a a
|
|
c = add a b
|
|
in (c,) }""", lambda x: x + x * x, [0], has_output_token=False)
|
|
self.assertRewrite("""
|
|
{ lambda ; a d e.
|
|
let b = mul a a
|
|
c = add a b
|
|
in (c, d, e) }""", lambda x: x + x * x, [0])
|
|
|
|
def test_simple_outfeed(self):
|
|
self.assertRewrite("""
|
|
{ lambda ; a d e.
|
|
let b = add a a
|
|
c f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] b d e
|
|
in (c, f, g) }""", lambda x: hcb.id_print(x + x), [0])
|
|
|
|
def test_simple_outfeed_without_input_token(self):
|
|
self.assertRewrite("""
|
|
{ lambda ; a b.
|
|
let e = create_token a b
|
|
f = create_token a b
|
|
c = add a b
|
|
d g h = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] c e f
|
|
in (d,) }""", lambda x1, x2: hcb.id_print(x1 + x2), [1, 2],
|
|
has_input_token=False, has_output_token=False)
|
|
|
|
def test_simple_outfeed_without_input_token_nor_invars(self):
|
|
self.assertRewrite("""
|
|
{ lambda ; .
|
|
let b = create_token
|
|
c = create_token
|
|
a d e = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] 42 b c
|
|
in (a,) }""", lambda: hcb.id_print(42), [],
|
|
has_input_token=False, has_output_token=False)
|
|
|
|
def test_multiple_tap_without_dependencies(self):
|
|
def f(x):
|
|
hcb.id_print(x, what="x")
|
|
hcb.id_print(x + 1, what="x + 1")
|
|
return 2
|
|
|
|
self.assertRewrite("""
|
|
{ lambda ; a c d.
|
|
let _ e f = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a c d
|
|
b = add a 1
|
|
_ g h = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] b e f
|
|
in (2, g, h) }""", f, [1])
|
|
|
|
def test_cond(self):
|
|
y = jnp.ones(5) # captured const
|
|
|
|
def func(x, z):
|
|
return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)),
|
|
z, lambda a: (hcb.id_print(a), y))
|
|
|
|
self.assertRewrite("""
|
|
{ lambda a ; b c h i.
|
|
let d = gt c 0
|
|
e = convert_element_type[ new_dtype=int32 ] d
|
|
f g j k =
|
|
cond[ branches=( { lambda ; a b c d f g.
|
|
let e h i = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] d f g
|
|
in (e, a, h, i) }
|
|
{ lambda ; f_ a b c g h.
|
|
let d = broadcast_in_dim[ broadcast_dimensions=( )
|
|
shape=(5,) ] 0.00
|
|
in (a, d, g, h) } )
|
|
linear=(False, False, False, False, False, False) ] e a 1 2 c h i
|
|
in (f, g, j, k) }""", func, [y, 5])
|
|
|
|
def test_while(self):
|
|
ct_body = jnp.ones(5, np.float32) # captured const for the body
|
|
ct_cond = jnp.ones(5, np.float32) # captured const for the conditional
|
|
|
|
def func(x):
|
|
# x: f32[5]
|
|
# c: (f32[5], f32)
|
|
return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
|
|
lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
|
|
(x, np.float32(1.)))
|
|
|
|
self.assertRewrite("""
|
|
{ lambda a b ; c f g.
|
|
let d e h i =
|
|
while[ body_jaxpr={ lambda ; a b c f g.
|
|
let d h i = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] c f g
|
|
e = add d 1.00
|
|
in (a, e, h, i) }
|
|
body_nconsts=1
|
|
cond_jaxpr={ lambda ; a b c g h.
|
|
let d = add b a
|
|
e = reduce_sum[ axes=(0,) ] d
|
|
f = lt c e
|
|
in (f,) }
|
|
cond_nconsts=1 ] a b c 1.00 f g
|
|
in (d, e, h, i) }""", func, [ct_body])
|
|
|
|
def test_while_pred_outfeed(self):
|
|
"""A while with outfeed in the pred."""
|
|
ct_body = jnp.ones(5) # captured const for the body
|
|
ct_cond = jnp.ones(2) # captured const for the conditional
|
|
|
|
def func(x):
|
|
return lax.while_loop(lambda c: hcb.id_print(ct_cond, result=c[1]) < 5,
|
|
lambda c: (ct_body, hcb.id_print(c[1]) + 1),
|
|
(x, 1))
|
|
|
|
self.assertRewrite("""
|
|
{ lambda a b ; c f g.
|
|
let j k l = xla_call[ call_jaxpr={ lambda ; a b c g h.
|
|
let d i j = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a g h
|
|
e = id_tap_dep c d
|
|
f = lt e 5
|
|
in (f, i, j) }
|
|
donated_invars=(False, False, False, False, False)
|
|
name=cond_before ] a c 1 f g
|
|
bf d e h i =
|
|
while[ body_jaxpr={ lambda ; r s t u v w x.
|
|
let y z ba bb =
|
|
xla_call[ call_jaxpr={ lambda ; a b c f g.
|
|
let d h i = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] c f g
|
|
e = add d 1
|
|
in (a, e, h, i) }
|
|
donated_invars=(False, False, False, False, False)
|
|
name=body ] s u v w x
|
|
bc bd be =
|
|
xla_call[ call_jaxpr={ lambda ; a b c g h.
|
|
let d i j = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a g h
|
|
e = id_tap_dep c d
|
|
f = lt e 5
|
|
in (f, i, j) }
|
|
donated_invars=(False, False, False, False, False)
|
|
name=cond_body ] r y z ba bb
|
|
in (bc, y, z, bd, be) }
|
|
body_nconsts=2
|
|
cond_jaxpr={ lambda ; m n o p q.
|
|
let
|
|
in (m,) }
|
|
cond_nconsts=0 ] a b j c 1 k l
|
|
in (d, e, h, i) }""", func, [ct_body])
|
|
|
|
def test_scan(self):
|
|
y = jnp.ones(5) # captured const
|
|
|
|
def func(x):
|
|
return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
|
|
|
|
self.assertRewrite("""
|
|
{ lambda a ; b f g.
|
|
let c d h i e =
|
|
scan[ jaxpr={ lambda ; a b c g h d.
|
|
let e f i j =
|
|
outside_call[ arg_treedef=PyTreeDef(tuple, [*,*])
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] b c g h
|
|
in (e, f, i, j, a) }
|
|
length=5
|
|
linear=(False, False, False, False, False, False)
|
|
num_carry=4
|
|
num_consts=1
|
|
reverse=False
|
|
unroll=1 ] a 1 2 f g b
|
|
in (c, d, e, h, i) }""", func, [y])
|
|
|
|
def test_scan_custom_jvp(self):
|
|
"""custom JVP, inside scan.
|
|
This exercises the custom_jvp_call_jaxpr primitives."""
|
|
|
|
@jax.custom_jvp
|
|
def f(x):
|
|
return x * hcb.id_print(x)
|
|
|
|
@f.defjvp
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
x_dot, = tangents
|
|
primal_out = f(x)
|
|
tangent_out = 3. * x * hcb.id_print(x_dot)
|
|
return primal_out, tangent_out
|
|
|
|
def g(x):
|
|
# Sum f(x_i)
|
|
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
|
|
np.full(x.shape[1:], 0.), # Like x w/o leading dim
|
|
x)[0]
|
|
|
|
arg = np.full((5,), 0.7)
|
|
self.assertRewrite("""
|
|
{ lambda ; a c d.
|
|
let b e f _ =
|
|
scan[ jaxpr={ lambda ; a e f b.
|
|
let c g h = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e.
|
|
let b f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a d e
|
|
c = mul a b
|
|
in (c, f, g) }
|
|
num_consts=0 ] b e f
|
|
d = add a c
|
|
in (d, g, h, 0.00) }
|
|
length=5
|
|
linear=(False, False, False, False)
|
|
num_carry=3
|
|
num_consts=0
|
|
reverse=False
|
|
unroll=1 ] 0.00 c d a
|
|
in (b, e, f) }""", g, [arg])
|
|
self.assertRewrite("""
|
|
{ lambda ; a d e.
|
|
let _ _ f g _ b =
|
|
scan[ jaxpr={ lambda ; a b h i c d.
|
|
let e j k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e.
|
|
let b f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a d e
|
|
c = mul a b
|
|
in (c, f, g) }
|
|
num_consts=0 ] c h i
|
|
f = add a e
|
|
g = mul c 3.00
|
|
in (f, *, j, k, 0.00, g) }
|
|
length=5
|
|
linear=(False, True, False, False, False, True)
|
|
num_carry=4
|
|
num_consts=0
|
|
reverse=False
|
|
unroll=1 ] 0.00 * d e a *
|
|
_ _ h i _ c =
|
|
scan[ jaxpr={ lambda ; a b g h c d.
|
|
let e = mul b d
|
|
f i j = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True
|
|
transforms=(('transpose',),) ] e g h
|
|
in (*, b, i, j, *, f) }
|
|
length=5
|
|
linear=(True, True, False, False, True, False)
|
|
num_carry=4
|
|
num_consts=0
|
|
reverse=True
|
|
unroll=1 ] * 1.00 f g * b
|
|
in (c, h, i) }""", jax.grad(g), [arg])
|
|
|
|
def test_scan_custom_vjp(self):
|
|
"""custom VJP, inside scan.
|
|
This exercises the custom_vjp_call_jaxpr primitives."""
|
|
|
|
@jax.custom_vjp
|
|
def f(x):
|
|
return x * hcb.id_print(x)
|
|
|
|
# f_fwd: a -> (b, residual)
|
|
def f_fwd(x):
|
|
return f(x), 3. * x
|
|
|
|
# f_bwd: (residual, CT b) -> [CT a]
|
|
def f_bwd(residual, ct_b):
|
|
return residual * hcb.id_print(ct_b),
|
|
|
|
f.defvjp(f_fwd, f_bwd)
|
|
|
|
def g(x):
|
|
# Sum f(x_i)
|
|
return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
|
|
np.full(x.shape[1:], 0.), # Like x w/o leading dim
|
|
x)[0]
|
|
|
|
arg = np.full((2,), 0.7)
|
|
self.assertRewrite("""
|
|
{ lambda ; a c d.
|
|
let b e f _ =
|
|
scan[ jaxpr={ lambda ; a e f b.
|
|
let c g h = custom_vjp_call_jaxpr[
|
|
fun_jaxpr={ lambda ; a d e.
|
|
let b f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a d e
|
|
c = mul a b
|
|
in (c, f, g) }
|
|
num_consts=0
|
|
] b e f
|
|
d = add a c
|
|
in (d, g, h, 0.00) }
|
|
length=2
|
|
linear=(False, False, False, False)
|
|
num_carry=3
|
|
num_consts=0
|
|
reverse=False
|
|
unroll=1 ] 0.00 c d a
|
|
in (b, e, f) }""", g, [arg])
|
|
self.assertRewrite("""
|
|
{ lambda ; a d e.
|
|
let _ _ f g _ b =
|
|
scan[ jaxpr={ lambda ; a b h i c d.
|
|
let e j k = custom_vjp_call_jaxpr[
|
|
fun_jaxpr={ lambda ; a d e.
|
|
let b f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a d e
|
|
c = mul a b
|
|
in (c, f, g) }
|
|
num_consts=0
|
|
] c h i
|
|
f = add a e
|
|
g = mul c 3.00
|
|
in (f, *, j, k, 0.00, g) }
|
|
length=2
|
|
linear=(False, True, False, False, False, True)
|
|
num_carry=4
|
|
num_consts=0
|
|
reverse=False
|
|
unroll=1 ] 0.00 * d e a *
|
|
_ _ h i _ c =
|
|
scan[ jaxpr={ lambda ; a b g h c d.
|
|
let e i j = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] b g h
|
|
f = mul d e
|
|
in (*, b, i, j, *, f) }
|
|
length=2
|
|
linear=(True, True, False, False, True, False)
|
|
num_carry=4
|
|
num_consts=0
|
|
reverse=True
|
|
unroll=1 ] * 1.00 f g * b
|
|
in (c, h, i) }""", jax.grad(g), [arg])
|
|
|
|
def test_remat_loop(self):
|
|
def f(k, x):
|
|
x = hcb.id_print(k + x)
|
|
return -k * x
|
|
|
|
def loss(k):
|
|
return lax.fori_loop(0, 1, jax.remat(f), k)
|
|
|
|
self.assertRewrite("""
|
|
{ lambda ; a c d.
|
|
let _ _ b e f =
|
|
while[ body_jaxpr={ lambda ; a b c f g.
|
|
let d = add a 1
|
|
e h i = remat_call[ call_jaxpr={ lambda ; a b g h.
|
|
let c = add a b
|
|
d i j = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] c g h
|
|
e = neg a
|
|
f = mul e d
|
|
in (f, i, j) }
|
|
concrete=False
|
|
name=f ] a c f g
|
|
in (d, b, e, h, i) }
|
|
body_nconsts=0
|
|
cond_jaxpr={ lambda ; a b c e f.
|
|
let d = lt a b
|
|
in (d,) }
|
|
cond_nconsts=0 ] 0 1 a c d
|
|
in (b, e, f) }""", loss, [2])
|
|
|
|
def test_named_call(self):
|
|
def tap_scalar(init, do_print=False):
|
|
@partial(jax.named_call, name="step")
|
|
def step(acc, step_nr):
|
|
acc = acc + step_nr
|
|
maybe_print(do_print, step_nr, what="step_nr")
|
|
return acc, None
|
|
|
|
return lax.scan(step, init, np.arange(2, dtype=np.int32))
|
|
|
|
self.assertRewrite("""
|
|
{ lambda a ; b d e.
|
|
let c = scan[ jaxpr={ lambda ; a b.
|
|
let c = named_call[ call_jaxpr={ lambda ; a b.
|
|
let c = add a b
|
|
in (c,) }
|
|
name=step ] a b
|
|
in (c,) }
|
|
length=2
|
|
linear=(False, False)
|
|
num_carry=1
|
|
num_consts=0
|
|
reverse=False
|
|
unroll=1 ] b a
|
|
in (c, d, e) }""", tap_scalar, [np.int32(3)])
|
|
|
|
def test_pmap(self):
|
|
def f(xv):
|
|
jax.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
|
|
axis_name="i")(xv)
|
|
|
|
self.assertRewrite("""
|
|
{ lambda ; a b c.
|
|
let _ d e = xla_pmap[ axis_name=i
|
|
axis_size=1
|
|
backend=None
|
|
call_jaxpr={ lambda ; a d e.
|
|
let b f g = outside_call[ arg_treedef=*
|
|
callback=...
|
|
has_token=True
|
|
identity=True ] a d e
|
|
c = sin b
|
|
in (c, f, g) }
|
|
devices=None
|
|
donated_invars=(False, False, False)
|
|
global_axis_size=None
|
|
in_axes=(0, 0, 0)
|
|
name=<lambda>
|
|
out_axes=(0, 0, 0) ] a b c
|
|
in (d, e) }""", f, [np.array([2.], dtype=np.float32)])
|
|
|
|
|
|
if __name__ == "__main__":
|
|
absltest.main(testLoader=jtu.JaxTestLoader())
|