2020-05-08 17:18:11 +03:00
|
|
|
# Copyright 2020 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
import functools
|
2020-12-13 10:44:20 +02:00
|
|
|
import itertools
|
2020-05-08 17:18:11 +03:00
|
|
|
import logging
|
|
|
|
import os
|
|
|
|
import re
|
2020-07-04 18:12:58 +03:00
|
|
|
import threading
|
|
|
|
import time
|
2020-12-13 10:44:20 +02:00
|
|
|
from typing import Callable, Optional, Sequence
|
2020-05-08 17:18:11 +03:00
|
|
|
from unittest import SkipTest
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
from jax import api
|
|
|
|
from jax import lax
|
2020-05-10 19:54:46 +03:00
|
|
|
from jax import numpy as jnp
|
2020-05-08 17:18:11 +03:00
|
|
|
from jax import test_util as jtu
|
|
|
|
from jax.config import config
|
|
|
|
from jax.experimental import host_callback as hcb
|
|
|
|
from jax.lib import xla_bridge
|
2020-08-18 10:17:38 -07:00
|
|
|
from jax.util import prod
|
2020-07-04 18:12:58 +03:00
|
|
|
import numpy as np
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
# Executed at module import time, before any test in this file runs;
# presumably lets absl command-line flags populate `config` (verify in jax.config).
config.parse_flags_with_absl()
|
|
|
|
# Module-level alias for the configuration flags object.
FLAGS = config.FLAGS
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class _TestingOutputStream(object):
|
|
|
|
"""Use as `output_stream` for tests."""
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
self._output = []
|
2020-07-04 18:12:58 +03:00
|
|
|
self.test_method_name = None
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def write(self, what: str) -> None:
|
2020-07-04 18:12:58 +03:00
|
|
|
print(f"output_stream[{self.test_method_name}]: {what}", end="")
|
2020-05-08 17:18:11 +03:00
|
|
|
self._output.append(what)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def output(self):
|
|
|
|
return "".join(self._output)
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
@property
|
|
|
|
def output_sorted_by_device(self):
|
|
|
|
# Assume that the output is a sequence of strings including metadata
|
|
|
|
# and data, with metadata containing `device: xxx`
|
|
|
|
by_device = [] # each element is a pair (device, str_list)
|
|
|
|
for s in self._output:
|
|
|
|
m = re.match(r'.*device: (\S+)', s)
|
|
|
|
if m:
|
|
|
|
by_device.append((m.group(1), []))
|
|
|
|
by_device[-1][1].append(s)
|
|
|
|
|
|
|
|
sorted_by_device = sorted(by_device, key=lambda x: x[0])
|
|
|
|
return "\n".join(itertools.chain(*[s[1] for s in sorted_by_device]))
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def __str__(self):
|
|
|
|
return "TestingOutputStream"
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
self._output = []
|
|
|
|
|
|
|
|
|
|
|
|
# Single shared stream that the tests in this file pass as `output_stream`;
# HostCallbackTest.setUp resets it before every test.
testing_stream = _TestingOutputStream()
|
|
|
|
|
|
|
|
|
|
|
|
def fun1(a):
  """Taps `a * 2` and `y * 3` onto `testing_stream`; returns `(a * 2) ** 2`.

  The second id_print taps `y * 3` but, through `result=`, passes the
  un-tripled value along.
  """
  doubled = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
  doubled = hcb.id_print(doubled * 3., what="y * 3",
                         output_stream=testing_stream, result=doubled)
  # Square it so that the gradient is not trivial.
  return doubled**2
|
|
|
|
|
|
|
|
|
|
|
|
def fun1_equiv(a):
  """Numerical equivalent of `fun1`, without any id_print side effects."""
  return (a * 2.)**2
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def maybe_print(do_print: bool, arg, what: str, tap_with_device: Optional[bool] = False):
  """Conditionally prints `arg` on `testing_stream` with `hcb.id_print`.

  Args:
    do_print: whether to print at all.
    arg: the value to print.
    what: label forwarded to `id_print` as `what`.
    tap_with_device: forwarded to `id_print`.

  Returns:
    The result of `hcb.id_print` when `do_print` is true; otherwise `arg`
    itself, untouched.
  """
  if not do_print:
    return arg
  return hcb.id_print(arg, what=what,
                      output_stream=testing_stream,
                      tap_with_device=tap_with_device)
|
|
|
|
|
|
|
|
# `jtu.ignore_warning` pre-bound to a message filter for "jit-of-pmap"
# warnings; the caller still invokes it (e.g. as a decorator factory).
ignore_jit_of_pmap_warning = functools.partial(jtu.ignore_warning,
                                               message=".*jit-of-pmap.*")
|
2020-07-04 18:12:58 +03:00
|
|
|
|
|
|
|
def assertMultiLineStrippedEqual(tst: jtu.JaxTestCase,
                                 expected: str, what: str):
  """A variant that preprocesses the string to eliminate non-determinism in
  floating point values, and several uninteresting id_tap primitive params.

  Args:
    tst: the test case whose own `assertMultiLineStrippedEqual` performs the
      final comparison.
    expected: the golden text.
    what: the actual text. Before comparing, float literals are rounded to 2
      decimals; the `output_stream=`, `threshold=`, `bwd=`, `out_trees=`,
      `fwd_jaxpr_thunk=` and `jvp_jaxpr_thunk=` params and empty lines are
      removed; and `tap_func_=` params become `tap_func_=_print` (for the
      print consumer) or `...`.
  """
  # Sometimes we get floating points in the output; we round them
  def repl_floats(match_group):
    matched = match_group.group(0)
    if matched == ".": return matched
    try:
      value = float(matched)
    except ValueError:
      # The pattern below also matches text that is not a float literal,
      # e.g. ".def" inside "x.default", or a bare "-."; keep such text
      # unchanged instead of aborting the comparison with a ValueError.
      return matched
    x = np.around(value, decimals=2)
    return f"{x:.2f}"

  what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
  what = re.sub(r"output_stream=[^\]\n,]*,?", "", what)
  what = re.sub(r"threshold=[^\]\n,]*,?", "", what)
  what = re.sub(r"bwd=[^\]\n]*", "", what)
  what = re.sub(r"out_trees=[^\]\n]*", "", what)
  what = re.sub(r"fwd_jaxpr_thunk=[^\]\n]*", "", what)
  what = re.sub(r"jvp_jaxpr_thunk=[^\]\n]*", "", what)
  # Empty lines
  what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)

  def repl_func(match_group):
    matched = match_group.group(0)
    if "function _print_consumer" in matched:
      return "tap_func_=_print"
    else:
      return "..."

  what = re.sub(r"tap_func_=([^\]\n,]*),?", repl_func, what)
  tst.assertMultiLineStrippedEqual(expected, what)
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
# Holds what jtu.set_host_platform_device_count returned in setUpModule;
# tearDownModule calls it to restore the previous configuration.
prev_xla_flags = None


def setUpModule():
  """Asks for 2 host-platform devices for the tests in this module."""
  global prev_xla_flags
  # This will control the CPU devices. On TPU we always have 2 devices
  requested_devices = 2
  prev_xla_flags = jtu.set_host_platform_device_count(requested_devices)
|
|
|
|
|
|
|
|
|
|
|
|
def tearDownModule():
  """Reset to previous configuration in case other test modules will be run.

  Calls the restore callable that setUpModule stored in `prev_xla_flags`.
  """
  restore_previous_flags = prev_xla_flags
  restore_previous_flags()
|
|
|
|
|
|
|
|
|
|
|
|
def assertMultiDeviceOutputEqual(tst: jtu.JaxTestCase,
                                 expected_2CPUs: str):
  """Check that the multi-device output is equal to the expected.

  The tests run with 2 CPU devices on CPU (due to the flag), also
  on TPU (due to how the TPU tests are set up), but only 1 device on
  GPU. We adjust the expected output here for 1 device.

  Args:
    tst: the test case used for the comparison.
    expected_2CPUs: the expected output for 2 CPUs. If there is only one
      device, everything from the `device: cpu:1` header onward is dropped.
      Then every `cpu:N` in the expected text is replaced by the name of
      `api.devices()[N]`, so the same golden text works on non-CPU devices.
  """
  nr_devices = api.device_count()
  assert nr_devices in (1, 2)

  expected = expected_2CPUs
  if nr_devices == 1:
    cut = expected.find('device: cpu:1')
    if cut >= 0:
      expected = expected[:cut]

  expected = re.sub(r'cpu:(\d+)',
                    lambda m: str(api.devices()[int(m.group(1))]),
                    expected)
  return assertMultiLineStrippedEqual(tst, expected,
                                      testing_stream.output_sorted_by_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class HostCallbackTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def setUp(self):
  # Every test starts from an empty shared stream, tagged with its own name.
  stream = testing_stream
  stream.reset()
  stream.test_method_name = self._testMethodName
  # Remembered so that tearDown can undo changes to XLA_FLAGS.
  self.old_flags = os.getenv("XLA_FLAGS", "")
|
|
|
|
|
|
|
|
def tearDown(self) -> None:
  """Restores XLA_FLAGS if the test changed them, then waits for all taps."""
  # Read with the same "" default that setUp used when recording old_flags;
  # otherwise an unset XLA_FLAGS (None) never equals the recorded "", and
  # every test would spuriously set XLA_FLAGS="" and clear the backend cache.
  if os.getenv("XLA_FLAGS", "") != self.old_flags:
    os.environ["XLA_FLAGS"] = self.old_flags
    # Clear cached backends so the restored flags are picked up.
    xla_bridge.get_backend.cache_clear()
  hcb.barrier_wait("HostCallbackTest.tearDown")
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def helper_set_hlo_dump(self):
  """Appends `--xla_dump_to=/tmp/xla_dump` to the XLA_FLAGS env var."""
  current_flags = os.getenv("XLA_FLAGS", "")
  os.environ["XLA_FLAGS"] = f"{current_flags} --xla_dump_to=/tmp/xla_dump"
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
|
|
|
|
|
|
|
|
def test_eval(self):
  # TODO: re-enable jaxpr golden tests when changing host_callback
  #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.)))

  expected_value = (5. * 2.) ** 2
  self.assertAllClose(expected_value, fun1(5.))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_tuple_results(self):
  def func2(x):
    doubled, tripled = hcb.id_print((x * 2., x * 3.),
                                    output_stream=testing_stream)
    return doubled + tripled

  self.assertEqual(3. * (2. + 3.), func2(3.))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_dict_results(self):
  def func2(x):
    printed = hcb.id_print(dict(a=x * 2., b=x * 3.),
                           output_stream=testing_stream)
    return printed["a"] + printed["b"]

  self.assertEqual(3. * (2. + 3.), func2(3.))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
{ a=6.00
b=9.00 }""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_result(self):
  def func2(x):
    # Prints the pair but returns `x * 4.` via `result=`.
    return hcb.id_print((x * 2., x * 3.), result=x * 4.,
                        output_stream=testing_stream)

  self.assertEqual(3. * 4., func2(3.))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
[ 6.00
9.00 ]""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def test_print_with_device(self):
  def func2(x):
    return hcb.id_print((x * 2., x * 3.), result=x * 4.,
                        output_stream=testing_stream,
                        tap_with_device=True)

  self.assertEqual(3. * 4., func2(3.))
  hcb.barrier_wait()

  assertMultiDeviceOutputEqual(self, """
device: cpu:0
[ 6.00
9.00 ]""")
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_eval_tap_exception(self):
  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise NotImplementedError

  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    second = hcb.id_tap(tap_err, first + 1)
    return hcb.id_print(second + 1, what="x3", output_stream=testing_stream)

  with self.assertRaises(hcb.TapFunctionException):
    func(0)
    hcb.barrier_wait()

  # We should have received everything before the error
  assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_simple(self):
  def triple_of_printed_double(x):
    return 3. * hcb.id_print(2. * x, what="here",
                             output_stream=testing_stream)

  jit_fun1 = api.jit(triple_of_printed_double)
  self.assertAllClose(6. * 5., jit_fun1(5.))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-09-24 14:24:02 +03:00
|
|
|
def test_jit_no_invars(self):
  def func():  # jitted function does not take arguments
    return hcb.id_print(42, output_stream=testing_stream)

  jitted = api.jit(func)
  self.assertAllClose(42, jitted())
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_multiple_invars(self):
  def func(x1, x2):
    total = x1 + x2
    return hcb.id_print(total, output_stream=testing_stream)

  jitted = api.jit(func)
  self.assertAllClose(42, jitted(40, 2))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_jit_constant(self):
  def func(x):
    # Prints a constant while passing the traced argument through.
    return hcb.id_print(42, result=x, output_stream=testing_stream)

  jitted = api.jit(func)
  self.assertAllClose(5, jitted(5))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_jit_sequence1(self):
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

  test_name = self._testMethodName
  logging.info("%s: %s", test_name, api.make_jaxpr(func)(1))
  logging.info("%s: %s", test_name,
               api.xla_computation(func)(1).as_hlo_text())
  self.assertEqual(2, api.jit(func)(1))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit2(self):
  """A sequence of JIT."""
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

  for arg, want in ((1, 2), (10, 11)):
    self.assertEqual(want, api.jit(func)(arg))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-09-24 14:24:02 +03:00
|
|
|
def test_jit_result_unused(self):
  """We can id_print even if we don't use the result."""
  if not config.omnistaging_enabled:
    raise SkipTest("Test requires omnistaging")

  def func(x):
    # Both id_print results are deliberately discarded.
    hcb.id_print(x, where="1", output_stream=testing_stream)
    hcb.id_print(x + 1, where="2", output_stream=testing_stream)
    return x + 1

  for arg, want in ((1, 2), (10, 11)):
    self.assertEqual(want, api.jit(func)(arg))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_jit_nested(self):
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

    def func_nested(y):
      return hcb.id_print(y + 1, where="nested", output_stream=testing_stream)

    x3 = api.jit(func_nested)(x1)
    return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream)

  self.assertEqual(3, api.jit(func)(1))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_devices(self):
  """Running on multiple devices."""
  devices = api.local_devices()
  logging.info(f"{self._testMethodName}: has devices {devices}")

  def func(x, device_id):
    tag = str(device_id)
    x1 = hcb.id_print(x, dev=tag, output_stream=testing_stream)
    return hcb.id_print(x1 + 1, dev=tag, output_stream=testing_stream)

  for d in devices:
    self.assertEqual(112, api.jit(func, device=d, static_argnums=1)(111, d.id))
  hcb.barrier_wait()

  logging.info(f"{self._testMethodName}: found output {testing_stream.output}")
  # Each value should be printed exactly once per device.
  for needle in (r"111", r"112"):
    self.assertEqual(len(devices),
                     len(re.findall(needle, testing_stream.output)))
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_pytree(self, with_jit=False):
  def func(x, what=""):
    """Returns some pytrees depending on x"""
    if what == "pair_1_x":
      return (1, x)
    if what == "pair_x_2x":
      return (x, 2 * x)
    if what == "dict":
      return dict(a=2 * x, b=3 * x)
    assert False

  tap_count = 0

  def tap_func(a, _, *, what=""):
    nonlocal tap_count
    tap_count += 1
    self.assertEqual(func(5, what), a)

  transform = api.jit if with_jit else (lambda f: f)
  for what in ("pair_1_x", "pair_x_2x", "dict"):
    def tapped(x):
      # Taps func(x) while returning func(2 * x).
      return hcb.id_tap(functools.partial(tap_func, what=what),
                        func(x, what),
                        result=func(x * 2, what))

    transformed = transform(tapped)(5)
    self.assertEqual(func(10, what), transformed)
  hcb.barrier_wait()  # Wait for receivers to be done
  self.assertEqual(3, tap_count)
|
|
|
|
|
2020-07-08 16:08:54 +03:00
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_concurrent_{concurrent}",
            concurrent=concurrent)
        for concurrent in [True, False]))
def test_multiple_tap(self, concurrent=False):
  """Call id_tap multiple times, concurrently or in sequence. """
  if concurrent and jtu.device_under_test() == "gpu":
    # TODO(necula): it seems that on GPU if multiple host threads run
    # a jit computation, the multiple computations are interleaved on the
    # GPU. This can result in the outfeed trains being interleaved, which
    # will trigger an error. The solution is to fix on GPU the receiving
    # logic so that we can outfeed the train as one tuple, and receive it
    # one piece as a time. Then the trains should be atomic.
    # See also b/160692602.
    raise SkipTest("concurrent id_tap not supported on GPU")
  received = set()
  count = 5

  def pause_tap(idx, _):
    received.add(int(idx))
    # The log message now matches the actual sleep duration below.
    logging.info(f"Starting do_tap {idx}. Sleeping 0.3sec ...")
    time.sleep(0.3)
    logging.info(f"Finish do_tap {idx}")

  def do_tap(idx):
    api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)

  if concurrent:
    threads = [
        threading.Thread(
            name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,))
        for idx in range(count)
    ]
    # Plain loops: these calls are made only for their side effects.
    for t in threads:
      t.start()
    for t in threads:
      t.join()
  else:
    for idx in range(count):
      do_tap(idx)

  hcb.barrier_wait()
  self.assertEqual(received, set(range(count)))
|
|
|
|
|
|
|
|
# TODO(necula): see comment for test_multiple_tap.
@jtu.skip_on_devices("gpu")
def test_multiple_barriers(self):
  """Call barrier_wait concurrently."""

  def pause_tap(*args, **kwargs):
    logging.info("pause_tap waiting")
    time.sleep(0.3)
    logging.info("pause_tap done")

  def long_run(x):
    return hcb.id_tap(pause_tap, x)

  api.jit(long_run)(5.)

  def try_barrier(idx):
    logging.info(f"Starting test barrier {idx}")
    hcb.barrier_wait()
    logging.info(f"Finished test barrier {idx}")

  threads = [
      threading.Thread(
          name=f"barrier_{idx}", target=try_barrier, args=(idx,))
      for idx in range(3)
  ]
  # Plain loops: these calls are made only for their side effects.
  for t in threads:
    t.start()
  for t in threads:
    t.join()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_cond(self, with_jit=False):
  """A conditional"""
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def on_true(v):
      return hcb.id_print(v, where="cond_t", output_stream=testing_stream)

    def on_false(v):
      return hcb.id_print(-1, where="cond_f", result=v,
                          output_stream=testing_stream)

    x4 = lax.cond(x % 2 == 0, on_true, on_false, x2 + 1)
    return hcb.id_print(x4 + 1, where="end", output_stream=testing_stream)

  transform = api.jit if with_jit else (lambda f: f)
  self.assertEqual(4, transform(func)(1))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(testcase_name=f"_with_jit_{with_jit}",
             with_jit=with_jit)
        for with_jit in [True, False]))
def test_while_cond(self, with_jit=False):
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def on_true(v):
      return hcb.id_print(v, where="w_b_t", output_stream=testing_stream)

    def on_false(v):
      return hcb.id_print(-1, where="w_b_f",
                          result=v, output_stream=testing_stream)

    def body(v):
      x3 = hcb.id_print(v, where="w_b_1", output_stream=testing_stream)
      x4 = lax.cond(v % 2 == 0, on_true, on_false, x3 + 1)
      return hcb.id_print(x4, where="w_b_2", output_stream=testing_stream)

    x10 = lax.while_loop(lambda v: v <= 3, body, x2)
    return hcb.id_print(x10, where="end", output_stream=testing_stream)

  transform = api.jit if with_jit else (lambda f: f)
  self.assertEqual(4, transform(func)(1))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_jit_while_pred_tap(self):
  """While with printing in the conditional."""
  def cond_fun(v):
    return hcb.id_print(v < 3, where="w_p", output_stream=testing_stream)

  def body_fun(v):
    return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

  def func(x):
    # NOTE(review): this id_print uses the default output stream, so it does
    # not appear in the golden output below.
    x1 = hcb.id_print(x, where="1")
    x10 = lax.while_loop(cond_fun, body_fun, x1)
    return hcb.id_print(x10, where="3", output_stream=testing_stream)

  self.assertEqual(3, api.jit(func)(1))
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self,
                               """
where: w_p
True
where: w_b
2
where: w_p
True
where: w_b
3
where: w_p
False
where: 3
3""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_scan_cond(self, with_jit=False):
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    def on_true(v):
      return hcb.id_print(v, where="s_t", output_stream=testing_stream)

    def on_false(v):
      return hcb.id_print(-1, where="s_f", result=v, output_stream=testing_stream)

    def body(carry, elem):
      x3 = hcb.id_print(elem, where="s_1", output_stream=testing_stream)
      x4 = lax.cond(elem % 2 == 0, on_true, on_false, x3 + 1)
      return (carry, hcb.id_print(x4, where="s_2", output_stream=testing_stream))

    _, x10 = lax.scan(body, x2, jnp.arange(3))
    return hcb.id_print(x10, where="10", output_stream=testing_stream)

  run = api.jit(func) if with_jit else func
  res = run(1)
  self.assertAllClose(jnp.array([1, 2, 3]), res)
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
            shape=shape,
            dtype=dtype,
            nr_args=nr_args) for nr_args in [1, 2]
        for shape in [(), (2,), (2, 3), (2, 3, 4)]
        for dtype in jtu.dtypes.all))
def test_jit_types(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
  if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
    raise SkipTest(f"id_print jit not implemented for {dtype}.")
  if jtu.device_under_test() == "tpu" and dtype in (jnp.int16,):
    raise SkipTest(f"transfering {dtype} not supported on TPU")

  one_arg = jnp.arange(prod(shape), dtype=dtype).reshape(shape)
  args = [one_arg] * nr_args if nr_args > 1 else [one_arg]

  def print_all(xs):
    return hcb.id_print(
        xs,
        a_new_test="************",
        testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")

  res = api.jit(print_all)(args)
  self.assertAllClose(args, res)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_jit_large(self):
  """id_print under jit on a 10000-element int32 array."""
  big = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
  jitted_print = api.jit(hcb.id_print)
  jitted_print(big)
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_jit_several_together(self):
  """Prints a tuple of several arrays with a single id_print under jit."""
  def print_triple(x, y):
    return hcb.id_print((x, y, x * 2.))

  matrix = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
  api.jit(print_triple)(matrix, jnp.ones(100, dtype=jnp.int32))
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_jit_interleaving(self):
  # Several jit's without data dependencies; they may interfere
  count = 0  # Count tap invocations
  nr_arrays = 5

  def tap_func(arg, _):
    nonlocal count
    assert len(arg) == nr_arrays
    count += 1

  # This is the function that we'll run multiple times
  def func(x, nr_taps):
    for _ in range(nr_taps):
      # Tap nr_arrays values x, x+1, ..., and continue with the last one.
      x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)])[-1]
    return x

  start = jnp.array(1, dtype=np.int32)
  res = 0
  for _ in range(10):
    # No dependencies between the jit invocations
    res += api.jit(lambda v: func(v, 10))(start)
  hcb.barrier_wait()
  self.assertEqual(100, count)
|
|
|
|
|
|
|
|
def test_jit_tap_exception(self):
  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise NotImplementedError

  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    second = hcb.id_tap(tap_err, first + 1)
    return hcb.id_print(second + 1, what="x3", output_stream=testing_stream)

  res = api.jit(func)(0)  # No error yet
  with self.assertRaises(hcb.TapFunctionException):
    hcb.barrier_wait()

  # Even though the receiver thread raised, the main thread should still
  # return 3.
  self.assertEqual(3, res)
  # We should have received all others
  assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_while(self):
  """Executing while, even without JIT uses compiled code"""
  y = jnp.ones(5)  # captured const

  def cond_fun(carry):
    return carry[1] < 5

  def body_fun(carry):
    return (y, hcb.id_print(carry[1], output_stream=testing_stream) + 1)

  def func(x):
    return lax.while_loop(cond_fun, body_fun, (x, 1))

  func(y)
  hcb.barrier_wait()

  assertMultiLineStrippedEqual(self, """
1
2
3
4""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jvp(self):
  """jvp of a tapped function prints both primal and tangent values."""
  primal_in = jnp.float32(5.)
  tangent_in = jnp.float32(0.1)
  res_primals, res_tangents = api.jvp(fun1, (primal_in,), (tangent_in,))
  self.assertAllClose(100., res_primals, check_dtypes=False)
  self.assertAllClose(4., res_tangents, check_dtypes=False)
  hcb.barrier_wait()
  expected = """
what: a * 2
10.00
transforms: ['jvp'] what: a * 2
0.20
what: y * 3
30.00
transforms: ['jvp'] what: y * 3
0.60"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_primal_unused(self):
  """grad where the tapped primal output is not needed by the backward pass.

  Building the jaxpr must not print anything; evaluating the gradient
  prints the forward value and the transposed tangent.
  """
  if not config.omnistaging_enabled:
    raise SkipTest("Test requires omnistaging")

  # The output of id_print is not needed for the backwards pass.
  def func(x):
    tapped = hcb.id_print(x * 3., what="x * 3",
                          output_stream=testing_stream)
    return 2. * tapped

  grad_func = api.grad(func)
  jaxpr_text = str(api.make_jaxpr(grad_func)(5.))
  # Tracing to a jaxpr must not have triggered any tap.
  hcb.barrier_wait()
  expected_jaxpr = """
{ lambda ; a.
let b = mul a 3.00
c = id_tap[ arg_treedef_=*
nr_tapped_args_=1
tap_func_=_print what='x * 3') ] b
_ = mul c 2.00
d = mul 1.00 2.00
e _ = id_tap[ arg_treedef_=*
nr_tapped_args_=1
tap_func_=_print what='x * 3')
transforms=(('jvp',), ('transpose',)) ] d 0.00
f = mul e 3.00
in (f,) }"""
  assertMultiLineStrippedEqual(self, expected_jaxpr, jaxpr_text)
  assertMultiLineStrippedEqual(self, "", testing_stream.output)
  testing_stream.reset()

  res_grad = grad_func(jnp.float32(5.))
  hcb.barrier_wait()

  self.assertAllClose(6., res_grad, check_dtypes=False)
  expected_output = """
what: x * 3
15.00
transforms: ['jvp', 'transpose'] what: x * 3
2.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_simple(self):
  """grad through two chained taps prints forward and cotangent values."""
  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    tripled = hcb.id_print(doubled * 3., what="y * 3",
                           output_stream=testing_stream)
    return x * tripled

  res_grad = api.grad(func)(jnp.float32(5.))
  self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  expected = """
what: x * 2
10.00
what: y * 3
30.00
transforms: ['jvp', 'transpose'] what: y * 3
5.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_double(self):
  """Second-order grad: taps are repeated under nested jvp/transpose."""
  if not config.omnistaging_enabled:
    raise SkipTest("Test requires omnistaging")

  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    return x * (doubled * 3.)

  second_derivative = api.grad(api.grad(func))
  # Only tracing to a jaxpr: no tap may fire.
  _ = api.make_jaxpr(second_derivative)(5.)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, "", testing_stream.output)

  res_grad = second_derivative(jnp.float32(5.))

  self.assertAllClose(12., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  expected = """
what: x * 2
10.00
transforms: ['jvp', 'transpose'] what: x * 2
15.00
transforms: ['jvp', 'transpose', 'jvp', 'transpose'] what: x * 2
2.00
transforms: ['jvp', 'transpose'] what: x * 2
3.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap(self):
  """vmap of a tapped function prints once per tap with batch transforms."""
  batch = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  api.vmap(fun1)(batch)
  hcb.barrier_wait()
  expected = """
transforms: [('batch', {'batch_dims': (0,)})] what: a * 2
[ 8.00 10.00]
transforms: [('batch', {'batch_dims': (0, 0)})] what: y * 3
[24.00 30.00]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap_not_batched(self):
  """A tuple tap where only some components are batched (batch_dims None)."""
  unmapped = 3.

  def func(y):
    # `unmapped` is not mapped, `y` is mapped.
    _, tapped_y = hcb.id_print((unmapped, y), output_stream=testing_stream)
    return unmapped + tapped_y

  batch = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  _ = api.vmap(func)(batch)
  hcb.barrier_wait()
  expected = """
transforms: [('batch', {'batch_dims': (None, 0)})]
[ 3.00
[4.00 5.00] ]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_double_vmap(self):
  """Nested vmap: a single tap prints the 2D tensor x[i, j] = i + j.

  The tap sits under two batching levels, so both 'batch' transforms are
  recorded in the printed header.
  """
  def add_and_print(x, y):
    # Named so it does not shadow the builtin ``sum``.
    return hcb.id_print(x + y, output_stream=testing_stream)

  def sum_rows(xv, y):
    return api.vmap(add_and_print, in_axes=(0, None))(xv, y)

  def sum_all(xv, yv):
    return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

  xv = jnp.arange(5, dtype=np.int32)
  yv = jnp.arange(3, dtype=np.int32)
  _ = sum_all(xv, yv)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
transforms: [('batch', {'batch_dims': (0,)}), ('batch', {'batch_dims': (0,)})]
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_vmap_while(self):
  """Vmap of while."""

  def func(x):
    # Computes max(x, 2) while tapping the start, every body step and the end.
    def below_two(v):
      return v < 2

    def increment(v):
      return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

    start = hcb.id_print(x, where="1", output_stream=testing_stream)
    final = lax.while_loop(below_two, increment, start)
    return hcb.id_print(final, where="3", output_stream=testing_stream)

  inputs = np.arange(5, dtype=np.int32)
  outputs = api.jit(api.vmap(func))(inputs)
  self.assertAllClose(np.array([2, 2, 2, 3, 4]), outputs,
                      check_dtypes=False)
  hcb.barrier_wait()
  expected = """
transforms: [('batch', {'batch_dims': (0,)})] where: 1
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap_while_tap_cond(self):
  """Vmap of while, with a tap in the conditional."""

  def func(x):
    # Computes max(x, 2); the loop predicate itself is tapped.
    def below_two(v):
      return hcb.id_print(v < 2, where="w_c", output_stream=testing_stream)

    def increment(v):
      return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

    start = hcb.id_print(x, where="1", output_stream=testing_stream)
    final = lax.while_loop(below_two, increment, start)
    return hcb.id_print(final, where="3", output_stream=testing_stream)

  inputs = np.arange(5, dtype=np.int32)
  outputs = api.jit(api.vmap(func))(inputs)
  hcb.barrier_wait()
  self.assertAllClose(np.array([2, 2, 2, 3, 4]), outputs, check_dtypes=False)
  expected = """
transforms: [('batch', {'batch_dims': (0,)})] where: 1
[0 1 2 3 4]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True True False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[1 2 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[ True False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: w_b
[2 3 3 4 5]
transforms: [('batch', {'batch_dims': (0,)})] where: w_c
[False False False False False]
transforms: [('batch', {'batch_dims': (0,)})] where: 3
[2 2 2 3 4]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_pmap(self):
  """pmap: each device taps its own shard; output is tagged by device."""
  xv = jnp.arange(api.device_count(), dtype=jnp.int32)

  def fun1(x, do_print=False):  # x: i32
    return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)

  printing_fn = api.pmap(functools.partial(fun1, do_print=True))
  silent_fn = api.pmap(functools.partial(fun1, do_print=False))

  res = printing_fn(xv)
  hcb.barrier_wait()
  self.assertAllClose(silent_fn(xv), res, check_dtypes=False)
  # The expected text is written for 2 devices; the helper also accepts the
  # 1-device prefix.
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: x * 2
0
device: cpu:1 what: x * 2
2""")
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_pmap_vmap(self):
  """pmap of vmap: each device prints its whole (batched) row once."""
  # A matrix M[ij] = i * 10 + j
  nr_devices = api.device_count()
  shape = (nr_devices, 3)
  # np.fromfunction's `dtype` only types the index arrays; with int32 indices
  # `10. * i` would silently promote the matrix to float64. Use float32 so the
  # data really is f32, as in the sibling pmap tests.
  matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
                           dtype=np.float32)

  def fun1(x, do_print=False):  # x: f32
    return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True)

  pmap_vmap_fun1 = api.pmap(api.vmap(functools.partial(fun1, do_print=True)))

  res = pmap_vmap_fun1(matrix)
  hcb.barrier_wait()
  expected_res = api.pmap(api.vmap(functools.partial(fun1, do_print=False)))(matrix)
  self.assertAllClose(expected_res, res, check_dtypes=False)
  # Assertion text is for 2 devices (also works for 1 device)
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[0.00 2.00 4.00]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[20.00 22.00 24.00]""")
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_pmap_pmap_vmap(self):
  """pmap of pmap of vmap over a 3-D array split across all devices."""
  # A tensor M[ijk] = i * 100 + j * 10 + k
  nr_devices = api.local_device_count()
  if nr_devices % 2 != 0:
    raise SkipTest("test works only on even number of devices")

  shape = (2, nr_devices // 2, 3)
  matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
                           dtype=np.float32)

  def fun1(x, do_print=False):  # x: f32
    doubled = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)
    return doubled ** 2

  def nested(do_print):
    return api.pmap(api.pmap(api.vmap(functools.partial(fun1, do_print=do_print))))

  res = nested(True)(matrix)
  hcb.barrier_wait()
  self.assertAllClose(nested(False)(matrix), res, check_dtypes=False)
  # Expected text is for 2 devices; any odd device count is skipped above.
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[0.00 2.00 4.00]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[200.00 202.00 204.00]""")
  testing_stream.reset()
|
|
|
|
|
|
|
|
@ignore_jit_of_pmap_warning()
def test_pmap_pmap_extra(self):
  """pmap of a pmap surrounded by extra code."""
  # A tensor M[ijk] = i * 100 + j * 10 + k, of shape (2, 1, 3).
  nr_devices = api.local_device_count()
  if nr_devices != 2:
    raise SkipTest("test works only on 2 devices")
  shape = (2, 1, 3)
  matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
                           dtype=np.float32)

  def fun(xv, do_print=False):
    # This will be printed on all devices, with shape [1, 3]
    xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
    res = api.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
    # This will be printed on all devices, with shape [1, 3]
    return maybe_print(do_print, res + 1., "after", tap_with_device=True)

  res = api.pmap(functools.partial(fun, do_print=True))(matrix)
  self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
  hcb.barrier_wait()
  # Expected text is for exactly 2 devices (other counts are skipped above).
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[1.00 2.00 3.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[3.00 5.00 7.00]]
device: cpu:1 what: before
[[101.00 102.00 103.00]]
device: cpu:1 what: inside
[202.00 204.00 206.00]
device: cpu:1 what: after
[[203.00 205.00 207.00]]""")

  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jvp_pmap_vmap(self):
  """jvp of pmap of vmap: each device prints batched primals and tangents."""
  # A tensor M[ijk] = i * 100 + j * 10 + k
  nr_devices = api.local_device_count()
  shape = (nr_devices, 2, 3)
  matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
                           dtype=np.float32)

  def fun(xv, do_print=False):
    # The innermost x: f32[3] (pmap and vmap each strip one leading axis).
    return api.jvp(api.pmap(api.vmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True))),
                   (xv,), (.1 * jnp.ones_like(xv),))

  res = fun(matrix, do_print=True)
  hcb.barrier_wait()
  expected_res = fun(matrix, do_print=False)
  self.assertAllClose(expected_res, res, check_dtypes=False)
  # Assertion text is for 2 devices (also works for 1 device)
  # Device 0 will get to execute api.jvp(api.vmap(...)) for matrix[0, :, :]
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 0.00 2.00 4.00]
[20.00 22.00 24.00]]
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
[[0.20 0.20 0.20]
[0.20 0.20 0.20]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[200.00 202.00 204.00]
[220.00 222.00 224.00]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)}), 'jvp'] what: x * 2
[[0.20 0.20 0.20]
[0.20 0.20 0.20]]""")
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap_pmap(self):
  """vmap of pmap: each device prints the column slice it was given."""
  # A tensor M[ijk] = i * 100 + j * 10 + k
  nr_devices = api.local_device_count()
  shape = (2, nr_devices, 3)
  matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape,
                           dtype=np.float32)

  def fun(xv, do_print=False):
    # The innermost x: f32[3] (vmap and pmap each strip one leading axis).
    return api.vmap(api.pmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)))(xv)

  res = fun(matrix, do_print=True)
  hcb.barrier_wait()
  expected_res = fun(matrix, do_print=False)
  self.assertAllClose(expected_res, res, check_dtypes=False)
  # Assertion text is for 2 devices (also works for 1 device)
  # Device j executes the vmapped computation for matrix[:, j, :]
  assertMultiDeviceOutputEqual(self, """
device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 0.00 2.00 4.00]
[200.00 202.00 204.00]]
device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2
[[ 20.00 22.00 24.00]
[220.00 222.00 224.00]]""")
  testing_stream.reset()
|
|
|
|
|
|
|
|
@ignore_jit_of_pmap_warning()
def test_jit_pmap_extra(self):
  """jit of a pmap surrounded by extra code."""
  # A matrix M[ij] = i * 10 + j
  nr_devices = api.local_device_count()
  shape = (nr_devices, 3)
  matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
                           dtype=np.float32)

  def fun(xv, do_print=False):
    # This will be printed on all devices with shape (nr_devices, 3)
    xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True)
    res = api.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv)
    # This will be printed on all devices with shape (nr_devices, 3)
    return maybe_print(do_print, res + 1., "after", tap_with_device=True)

  res = api.jit(functools.partial(fun, do_print=True))(matrix)
  self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False)
  hcb.barrier_wait()
  if api.device_count() == 2:
    assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[ 1.00 2.00 3.00]
[11.00 12.00 13.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[ 3.00 5.00 7.00]
[23.00 25.00 27.00]]
device: cpu:1 what: before
[[ 1.00 2.00 3.00]
[11.00 12.00 13.00]]
device: cpu:1 what: inside
[22.00 24.00 26.00]
device: cpu:1 what: after
[[ 3.00 5.00 7.00]
[23.00 25.00 27.00]]""")
  else:
    # Use a test assertion rather than a bare `assert`, which `python -O`
    # strips and which would hide an unsupported device count.
    self.assertEqual(1, api.device_count())
    assertMultiDeviceOutputEqual(self, """
device: cpu:0 what: before
[[1.00 2.00 3.00]]
device: cpu:0 what: inside
[2.00 4.00 6.00]
device: cpu:0 what: after
[[3.00 5.00 7.00]]""")

  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_cond_pmap(self):
  """cond whose true branch is a pmap (currently unsupported, skipped)."""
  raise SkipTest("cond of pmap does not work in JAX. Issue #5178.")
  # A matrix M[ij] = i * 10 + j
  nr_devices = api.local_device_count()
  shape = (nr_devices, 3)
  matrix = np.fromfunction(lambda i, j: 10. * i + j, shape,
                           dtype=np.float32)

  def fun1(x, do_print=False):
    return maybe_print(do_print, x * 2., "x * 2")

  def fun2(cond, xv, do_print=False):
    return lax.cond(cond, api.pmap(functools.partial(fun1, do_print=do_print)),
                    lambda xv: xv, xv)

  # The run under test must actually print; the reference run below does not.
  res = fun2(True, matrix, do_print=True)
  self.assertAllClose(fun2(True, matrix, do_print=False), res, check_dtypes=False)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
TBD""", testing_stream.output)
  testing_stream.reset()
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-08-12 09:20:26 +03:00
|
|
|
def test_scan_custom_jvp(self):
  """custom JVP, inside scan.
  This exercises the custom_jvp_call_jaxpr primitives."""
  @api.custom_jvp
  def f(x):
    return x * hcb.id_print(x, output_stream=testing_stream, what="x")

  @f.defjvp
  def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    primal_out = f(x)
    tangent_out = 3. * x * hcb.id_print(x_dot, output_stream=testing_stream, what="x_dot")
    return primal_out, tangent_out

  def g(x):
    # Sum f(x_i)
    return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                    np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                    x)[0]

  arg = np.full((2,), 0.7)
  self.assertAllClose(0.7 * 0.7 * 2, g(arg))
  hcb.barrier_wait()
  self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7""", testing_stream.output)
  testing_stream.reset()

  self.assertAllClose(np.array([2.1, 2.1]), api.grad(g)(arg), check_dtypes=False)
  hcb.barrier_wait()
  self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7
transforms: ['transpose'] what: x_dot
2.1
transforms: ['transpose'] what: x_dot
2.1""", testing_stream.output)
  # Leave the shared stream empty for the next test, like the sibling tests.
  testing_stream.reset()
|
2020-08-12 09:20:26 +03:00
|
|
|
|
|
|
|
def test_scan_custom_vjp(self):
  """custom VJP, inside scan.
  This exercises the custom_vjp_call_jaxpr primitives."""
  @api.custom_vjp
  def f(x):
    return x * hcb.id_print(x, output_stream=testing_stream, what="x")

  # f_fwd: a -> (b, residual)
  def f_fwd(x):
    return f(x), 3. * x

  # f_bwd: (residual, CT b) -> [CT a]
  def f_bwd(residual, ct_b):
    return residual * hcb.id_print(ct_b, output_stream=testing_stream, what="ct_b"),

  f.defvjp(f_fwd, f_bwd)

  def g(x):
    # Sum f(x_i)
    return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                    np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                    x)[0]

  arg = np.full((2,), 0.7)

  self.assertAllClose(0.7 * 0.7 * 2, g(arg))
  hcb.barrier_wait()
  self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7""", testing_stream.output)
  testing_stream.reset()

  self.assertAllClose(np.array([2.1, 2.1]), api.grad(g)(arg), check_dtypes=False)
  hcb.barrier_wait()
  self.assertMultiLineStrippedEqual("""
what: x
0.7
what: x
0.7
what: ct_b
1.
what: ct_b
1.""", testing_stream.output)
  # Leave the shared stream empty for the next test, like the sibling tests.
  testing_stream.reset()
|
2020-08-12 09:20:26 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_mask(self):
  """Tapping under api.mask (skipped while masking is regressed)."""
  # TODO(necula)
  raise SkipTest("masking has regressed")

  def summed(x):
    return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))

  padded_sum = functools.partial(api.mask, in_shapes=['n'], out_shape='')(summed)
  args = [jnp.arange(4)], dict(n=np.int64(2))
  expected_jaxpr = """
{ lambda c f ; a b.
let d = lt c b
e = id_tap[ func=_print
logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
transforms=('mask',)
what=x ] a
g = select d e f
h = reduce_sum[ axes=(0,) ] g
in (h,) }"""
  assertMultiLineStrippedEqual(self, expected_jaxpr,
                               str(api.make_jaxpr(padded_sum)(*args)))

  _ = padded_sum(*args)
  expected_output = """
logical_shapes: [(2,)] transforms: ['mask',) what: x
[0 1 2 3]
"""
  self.assertMultiLineStrippedEqual(expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-07-04 18:12:58 +03:00
|
|
|
def test_callback_delay(self):
  """Taps keep working when every host callback is delayed by 1 second."""
  hcb.callback_extra = lambda dev: time.sleep(1)

  def func(x):
    acc = x
    for factor in range(5):
      acc = hcb.id_print(acc * factor, what="x times i")
    return acc

  operand = np.arange(6, dtype=np.float32).reshape((2, 3))
  api.jit(func)(operand)
|
|
|
|
|
|
|
|
def test_callback_delay_barrier(self):
  """barrier_wait waits for delayed callbacks, across repeated invocations."""
  hcb.callback_extra = lambda dev: time.sleep(2)

  def func(x):
    acc = x
    for factor in range(1, 4):
      acc = hcb.id_print(acc * factor, what="x times i",
                         output_stream=testing_stream)
    return acc

  jitted = api.jit(func)
  jitted(np.arange(6, dtype=np.float32).reshape((2, 3)))
  # Wait for the results.
  hcb.barrier_wait()
  expected = """
what: x times i
[[0. 1. 2.]
[3. 4. 5.]]
what: x times i
[[ 0. 2. 4.]
[ 6. 8. 10.]]
what: x times i
[[ 0. 6. 12.]
[18. 24. 30.]]"""
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
  testing_stream.reset()

  # A second call must produce the same output again.
  jitted(np.arange(6, dtype=np.float32).reshape((2, 3)))
  hcb.barrier_wait()
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
|
|
|
|
|
|
|
|
|
|
|
def test_error_bad_consumer_id(self):
  """Try to use reserved consumer ID 0.

  Check that we get the proper error from the runtime."""
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
  payload = [xla_bridge.constant(builder, np.zeros((2, 3), dtype=np.float32))]
  reserved_id = 0
  with self.assertRaisesRegex(RuntimeError,
                              "Consumer ID cannot be a reserved value: 0"):
    hcb._outfeed_receiver.receiver.add_outfeed(builder, token, reserved_id,
                                               payload)
|
|
|
|
|
|
|
|
def test_error_different_shapes(self):
  """Try to register different shapes for the same consumer ID."""
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  token = hcb.xops.CreateToken(builder)
  hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
  receiver = hcb._outfeed_receiver.receiver
  consumer_id = 123

  def register(array):
    receiver.add_outfeed(builder, token, consumer_id,
                         [xla_bridge.constant(builder, array)])

  # First registration fixes the shape for this consumer ID.
  register(np.zeros((2, 3), dtype=np.float32))
  # Same shape with a different dtype is rejected.
  with self.assertRaisesRegex(
      RuntimeError, ".*does not match previous shape element_type.*"):
    register(np.zeros((2, 3), dtype=np.int32))
  # Same dtype with a different shape is rejected too.
  with self.assertRaisesRegex(
      RuntimeError, ".*does not match previous shape element_type.*"):
    register(np.zeros((2,), dtype=np.float32))
|
|
|
|
|
2020-09-14 02:47:28 -07:00
|
|
|
def test_id_tap_deprecated_kwargs(self):
  """Passing extra **kwargs to id_tap emits a FutureWarning."""
  # The parameter names matter: `y` is received as a keyword argument.
  def tap_fn(x, transforms, y):
    pass

  expected_warning = r"Support for \*\*kwargs in ``id_tap``"
  with self.assertWarnsRegex(FutureWarning, expected_warning):
    hcb.id_tap(tap_fn, 1, y=2)
|
|
|
|
|
2020-08-13 13:02:22 +03:00
|
|
|
def test_odeint(self):
  """grad through odeint with a tapped dynamics function must not fail."""
  # TODO: find a smaller repro for bug #4015
  # Seems to be xla_call(scan(xla_call)), all under grad.
  from jax.experimental.ode import odeint

  def f(x, t, k):
    tapped = hcb.id_print(x)
    return -k * tapped

  def loss(k=1.0):
    times = jnp.linspace(0, 0.001, num=2)
    trajectory = odeint(f, 1.0, times, k)
    return trajectory[-1]

  api.grad(loss)(1.0)  # should not fail
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-10-16 10:52:56 +03:00
|
|
|
def test_remat(self):
  """id_print inside a rematerialized fori_loop body prints once per step."""
  def f(i, k):
    x = hcb.id_print(k + i, output_stream=testing_stream)
    return k * x

  def loss(k):
    return lax.fori_loop(0, 2, api.remat(f), k)

  # Assert the result instead of printing it to stdout:
  # i=0 taps 3 and gives 3 * 3 = 9; i=1 taps 10 and gives 9 * 10 = 90.
  self.assertAllClose(90, loss(3), check_dtypes=False)
  hcb.barrier_wait()
  expected = """
3
10"""
  self.assertMultiLineStrippedEqual(expected, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-09-14 02:47:28 -07:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class OutfeedRewriterTest(jtu.JaxTestCase):
|
2020-07-04 18:12:58 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                  has_input_token=True, has_output_token=True):
  """Run the outfeed rewriter on the jaxpr of ``func(*args)``.

  The rewrite is always performed, so it must not raise. Comparing its
  printed form against ``expected`` is currently disabled (see below).
  """
  closed_jaxpr = api.make_jaxpr(func)(*args)
  rewritten = hcb._rewrite_closed_jaxpr(  # noqa: F841
      closed_jaxpr, has_input_token, has_output_token)
  # Updating the expected Jaxprs whenever Jaxpr printing changes is tedious,
  # so they are not checked by default. Before changing code generation or
  # the Jaxpr rewriting, turn the check on, refresh the expected Jaxprs, and
  # only then make the changes.
  # assertMultiLineStrippedEqual(self, expected, str(rewritten))
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_no_outfeed(self):
  """Without taps the rewrite only adds/forwards the token variable."""
  def poly(x):
    return x + x * x

  no_tokens = """
{ lambda ; a.
let b = mul a a
c = add a b
in (c,) }"""
  self.assertRewrite(no_tokens, poly, [0], has_input_token=False,
                     has_output_token=False)

  input_token_only = """
{ lambda ; a d.
let b = mul a a
c = add a b
in (c,) }"""
  self.assertRewrite(input_token_only, poly, [0], has_output_token=False)

  both_tokens = """
{ lambda ; a d.
let b = mul a a
c = add a b
in (c, d) }"""
  self.assertRewrite(both_tokens, poly, [0])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_simple_outfeed(self):
  """A single tap threads the input token through id_tap to the output."""
  def print_double(x):
    return hcb.id_print(x + x)

  expected = """
{ lambda ; a d.
let b = add a a
c e = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print ] b d
in (c, e) }"""
  self.assertRewrite(expected, print_double, [0])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-09-24 14:24:02 +03:00
|
|
|
def test_simple_outfeed_without_input_token(self):
  """Without an input token the rewrite creates one from the invars."""
  def print_sum(x1, x2):
    return hcb.id_print(x1 + x2)

  expected = """
{ lambda ; a b.
let e = create_token a b
c = add a b
d f = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print ] c e
in (d,) }"""
  self.assertRewrite(expected, print_sum, [1, 2],
                     has_input_token=False, has_output_token=False)
|
|
|
|
|
|
|
|
def test_simple_outfeed_without_input_token_nor_invars(self):
  """With neither an input token nor invars, a token is created from nothing."""
  def print_constant():
    return hcb.id_print(42)

  expected = """
{ lambda ; .
let b = create_token
a c = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print ] 42 b
in (a,) }"""
  self.assertRewrite(expected, print_constant, [],
                     has_input_token=False, has_output_token=False)
|
|
|
|
|
|
|
|
def test_multiple_tap_without_dependencies(self):
  """Taps whose results are unused are still chained through the token."""
  def f(x):
    # Results are deliberately discarded; only the side effects matter.
    hcb.id_print(x, what="x")
    hcb.id_print(x + 1, what="x + 1")
    return 2

  expected = """
{ lambda ; a c.
let _ d = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print what='x') ] a c
b = add a 1
_ e = id_tap[ arg_treedef_=*
has_token_=True
nr_tapped_args_=1
tap_func_=_print what='x + 1') ] b d
in (2, e) }"""
  self.assertRewrite(expected, f, [1])
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_cond(self):
  """Tokens are threaded through both branches of `lax.cond`.

  Each branch jaxpr gets an extra token input and an extra token output.
  Branch 0 taps its operand (`z`, bound as `d`) and returns the new token
  `g`. Branch 1 has no tap and forwards its token input `g` unchanged. The
  captured constant `y` appears as the outer constvar `a`.
  """
  y = jnp.ones(5)  # captured const

  def func(x, z):
    # Operand (1, 2) pairs with the first lambda; in the expected jaxpr that
    # is branch 1, selected when `z > 0`. Operand `z` pairs with the tapping
    # lambda, which is branch 0.
    return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)),
                    z, lambda a: (hcb.id_print(a), y))

  self.assertRewrite("""
      { lambda a ; b c h.
          let d = gt c 0
              e = convert_element_type[ new_dtype=int32
                                        weak_type=False ] d
              f g i = cond[ branches=( { lambda ; a b c d f.
                                         let e g = id_tap[ arg_treedef_=*
                                                           has_token_=True
                                                           nr_tapped_args_=1
                                                           tap_func_=_print ] d f
                                         in (e, a, g) }
                                       { lambda ; f_ a b c g.
                                         let d = broadcast_in_dim[ broadcast_dimensions=( )
                                                                   shape=(5,) ] 0.00
                                         in (a, d, g) } )
                            linear=(False, False, False, False, False) ] e a 1 2 c h
          in (f, g, i) }""", func, [y, 5])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_while(self):
  """Tokens are added to the `lax.while_loop` carry.

  The body's `id_tap` consumes the carried token `f` and returns the new
  token `g`. The cond jaxpr receives the token as its last input (`g`) but
  does not use it. The two captured constants are passed to `while` as one
  cond const and one body const (`cond_nconsts=1`, `body_nconsts=1`).
  """
  ct_body = jnp.ones(5, np.float32)  # captured const for the body
  ct_cond = jnp.ones(5, np.float32)  # captured const for the conditional

  def func(x):
    # x: f32[5]
    # c: (f32[5], f32)
    return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
                          lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
                          (x, np.float32(1.)))

  self.assertRewrite("""
      { lambda a b ; c f.
          let d e g = while[ body_jaxpr={ lambda ; a b c f.
                                          let d g = id_tap[ arg_treedef_=*
                                                            has_token_=True
                                                            nr_tapped_args_=1
                                                            tap_func_=_print ] c f
                                              e = add d 1.00
                                          in (a, e, g) }
                             body_nconsts=1
                             cond_jaxpr={ lambda ; a b c g.
                                          let d = add b a
                                              e = reduce_sum[ axes=(0,) ] d
                                              f = lt c e
                                          in (f,) }
                             cond_nconsts=1 ] a b c 1.00 f
          in (d, e, g) }""", func, [ct_body])
|
2020-05-24 10:50:07 +03:00
|
|
|
|
|
|
|
def test_while_pred_outfeed(self):
  """A while with outfeed in the pred.

  In the expected rewrite, the tapping predicate is moved out of the loop's
  cond jaxpr. Presumably this is because the cond jaxpr has no way to
  return the updated token.

  - The predicate is computed once before the loop (`name=cond_before`).
  - The result is carried as the first loop value.
  - It is recomputed at the end of each body iteration (`name=cond_body`).
  - The new `cond_jaxpr` only returns the carried predicate `j`.
  """
  ct_body = jnp.ones(5)  # captured const for the body
  ct_cond = jnp.ones(2)  # captured const for the conditional

  def func(x):
    # The predicate taps `ct_cond` but uses `c[1]` (passed via `result=`)
    # for the comparison.
    return lax.while_loop(lambda c: hcb.id_print(ct_cond, result=c[1]) < 5,
                          lambda c: (ct_body, hcb.id_print(c[1]) + 1),
                          (x, 1))

  self.assertRewrite("""
      { lambda a b ; c f.
          let h i = xla_call[ call_jaxpr={ lambda ; a b c f.
                                           let _ d g = id_tap[ arg_treedef_=*
                                                               has_token_=True
                                                               nr_tapped_args_=1
                                                               tap_func_=_print ] a c f
                                               e = lt d 5
                                           in (e, g) }
                              donated_invars=(False, False, False, False)
                              name=cond_before ] a c 1 f
              y d e g =
                while[ body_jaxpr={ lambda ; n o p q r s.
                                    let t u v = xla_call[ call_jaxpr={ lambda ; a b c f.
                                                                       let d g = id_tap[ arg_treedef_=*
                                                                                         has_token_=True
                                                                                         nr_tapped_args_=1
                                                                                         tap_func_=_print ] c f
                                                                           e = add d 1
                                                                       in (a, e, g) }
                                                          donated_invars=(False, False, False, False)
                                                          name=body ] o q r s
                                        w x = xla_call[ call_jaxpr={ lambda ; a b c f.
                                                                     let _ d g = id_tap[ arg_treedef_=*
                                                                                         has_token_=True
                                                                                         nr_tapped_args_=1
                                                                                         tap_func_=_print ] a c f
                                                                         e = lt d 5
                                                                     in (e, g) }
                                                        donated_invars=(False, False, False, False)
                                                        name=cond_body ] n t u v
                                    in (w, t, u, x) }
                       body_nconsts=2
                       cond_jaxpr={ lambda ; j k l m.
                                    let
                                    in (j,) }
                       cond_nconsts=0 ] a b h c 1 i
          in (d, e, g) }""", func, [ct_body])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_scan(self):
  """Tokens are added as an extra carry of `lax.scan`.

  The scan carry `(1, 2)` is tapped as a single tuple (`nr_tapped_args_=2`,
  `arg_treedef_=PyTreeDef(tuple, [*,*])`). The token becomes a third carry
  (`num_carry=3`). The final token `g` is returned last, after the two carry
  results and the stacked per-iteration output `e`.
  """
  y = jnp.ones(5)  # captured const

  def func(x):
    return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)

  self.assertRewrite("""
      { lambda a ; b f.
          let c d g e =
                scan[ jaxpr={ lambda ; a b c g d.
                              let e f h = id_tap[ arg_treedef_=PyTreeDef(tuple, [*,*])
                                                  has_token_=True
                                                  nr_tapped_args_=2
                                                  tap_func_=_print ] b c g
                              in (e, f, h, a) }
                      length=5
                      linear=(False, False, False, False, False)
                      num_carry=3
                      num_consts=1
                      reverse=False
                      unroll=1 ] a 1 2 f b
          in (c, d, e, g) }""", func, [y])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-08-12 09:20:26 +03:00
|
|
|
def test_scan_custom_jvp(self):
  """custom JVP, inside scan.

  This exercises the custom_jvp_call_jaxpr primitives. There are two checks.

  1. The forward computation of `g`: the token is carried through the scan
     and passed into the `custom_jvp_call_jaxpr` fun_jaxpr, which taps the
     primal.
  2. `api.grad(g)`: the forward scan still taps the primal inside
     `custom_jvp_call_jaxpr`. The reverse scan (`reverse=True`) taps the
     product of cotangent and residual with
     `transforms=(('transpose',),)`.
  """
  @api.custom_jvp
  def f(x):
    return x * hcb.id_print(x)

  @f.defjvp
  def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    primal_out = f(x)
    tangent_out = 3. * x * hcb.id_print(x_dot)
    return primal_out, tangent_out

  def g(x):
    # Sum f(x_i)
    return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                    np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                    x)[0]

  arg = np.full((5,), 0.7)
  # Forward computation of g.
  self.assertRewrite("""
      { lambda ; a c.
          let b d _ = scan[ jaxpr={ lambda ; a e b.
                                    let c f = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d.
                                                                                 let b e = id_tap[ arg_treedef_=*
                                                                                                   has_token_=True
                                                                                                   nr_tapped_args_=1
                                                                                                   tap_func_=_print ] a d
                                                                                     c = mul a b
                                                                                 in (c, e) }
                                                                     num_consts=0 ] b e
                                        d = add a c
                                    in (d, f, 0.00) }
                            length=5
                            linear=(False, False, False)
                            num_carry=2
                            num_consts=0
                            reverse=False
                            unroll=1 ] 0.00 c a
          in (b, d) }""", g, [arg])

  # Gradient of g: a forward scan followed by a reverse (transposed) scan.
  self.assertRewrite("""
      { lambda ; a d.
          let _ _ e _ b =
                scan[ jaxpr={ lambda ; a b h c d.
                              let e i = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d.
                                                                           let b e = id_tap[ arg_treedef_=*
                                                                                             has_token_=True
                                                                                             nr_tapped_args_=1
                                                                                             tap_func_=_print ] a d
                                                                               c = mul a b
                                                                           in (c, e) }
                                                               num_consts=0 ] c h
                                  f = add a e
                                  g = mul c 3.00
                              in (f, *, i, 0.00, g) }
                      length=5
                      linear=(False, True, False, True, False)
                      num_carry=3
                      num_consts=0
                      reverse=False
                      unroll=1 ] 0.00 * d a *
              _ _ f _ c =
                scan[ jaxpr={ lambda ; a b g c d.
                              let e = mul b d
                                  f h = id_tap[ arg_treedef_=*
                                                has_token_=True
                                                nr_tapped_args_=1
                                                tap_func_=_print
                                                transforms=(('transpose',),) ] e g
                              in (*, b, h, *, f) }
                      length=5
                      linear=(True, True, True, False, False)
                      num_carry=3
                      num_consts=0
                      reverse=True
                      unroll=1 ] * 1.00 e * b
          in (c, f) }""", api.grad(g), [arg])
|
2020-08-12 09:20:26 +03:00
|
|
|
|
|
|
|
def test_scan_custom_vjp(self):
  """custom VJP, inside scan.

  This exercises the custom_vjp_call_jaxpr primitives. There are two checks.

  1. The forward computation of `g`: the token is carried through the scan
     and passed into the `custom_vjp_call_jaxpr` fun_jaxpr, which taps the
     primal.
  2. `api.grad(g)`: the forward scan taps the primal. The reverse scan
     (`reverse=True`) taps the carried cotangent `b` (no `transforms`
     parameter), then multiplies it by the residual `d`, mirroring `f_bwd`.
  """
  @api.custom_vjp
  def f(x):
    return x * hcb.id_print(x)

  # f_fwd: a -> (b, residual)
  def f_fwd(x):
    return f(x), 3. * x

  # f_bwd: (residual, CT b) -> [CT a]
  def f_bwd(residual, ct_b):
    return residual * hcb.id_print(ct_b),

  f.defvjp(f_fwd, f_bwd)

  def g(x):
    # Sum f(x_i)
    return lax.scan(lambda carry, inp: (carry + f(inp), 0.),
                    np.full(x.shape[1:], 0.),  # Like x w/o leading dim
                    x)[0]

  arg = np.full((2,), 0.7)
  # Forward computation of g.
  self.assertRewrite("""
      { lambda ; a c.
          let b d _ = scan[ jaxpr={ lambda ; a e b.
                                    let c f = custom_vjp_call_jaxpr[
                                                fun_jaxpr={ lambda ; a d.
                                                            let b e = id_tap[ arg_treedef_=*
                                                                              has_token_=True
                                                                              nr_tapped_args_=1
                                                                              tap_func_=_print ] a d
                                                                c = mul a b
                                                            in (c, e) }
                                                num_consts=0
                                              ] b e
                                        d = add a c
                                    in (d, f, 0.00) }
                            length=2
                            linear=(False, False, False)
                            num_carry=2
                            num_consts=0
                            reverse=False
                            unroll=1 ] 0.00 c a
          in (b, d) }""", g, [arg])

  # Gradient of g: a forward scan followed by a reverse scan running f_bwd.
  self.assertRewrite("""
      { lambda ; a d.
          let _ _ e _ b =
                scan[ jaxpr={ lambda ; a b h c d.
                              let e i = custom_vjp_call_jaxpr[
                                          fun_jaxpr={ lambda ; a d.
                                                      let b e = id_tap[ arg_treedef_=*
                                                                        has_token_=True
                                                                        nr_tapped_args_=1
                                                                        tap_func_=_print ] a d
                                                          c = mul a b
                                                      in (c, e) }
                                          num_consts=0
                                        ] c h
                                  f = add a e
                                  g = mul c 3.00
                              in (f, *, i, 0.00, g) }
                      length=2
                      linear=(False, True, False, True, False)
                      num_carry=3
                      num_consts=0
                      reverse=False
                      unroll=1 ] 0.00 * d a *
              _ _ f _ c =
                scan[ jaxpr={ lambda ; a b g c d.
                              let e h = id_tap[ arg_treedef_=*
                                                has_token_=True
                                                nr_tapped_args_=1
                                                tap_func_=_print ] b g
                                  f = mul d e
                              in (*, b, h, *, f) }
                      length=2
                      linear=(True, True, True, False, False)
                      num_carry=3
                      num_consts=0
                      reverse=True
                      unroll=1 ] * 1.00 e * b
          in (c, f) }""", api.grad(g), [arg])
|
2020-05-08 17:18:11 +03:00
|
|
|
|
2020-10-16 10:52:56 +03:00
|
|
|
def test_remat_loop(self):
  """`id_print` inside a `remat`-ed body of `lax.fori_loop`.

  The expected jaxpr has `fori_loop` as a `while` (counter compared with
  `lt a b`, incremented with `add a 1`). The token is an extra loop carry
  (`f` in the body). It is passed into the `remat_call` jaxpr, which returns
  the new token `h`.
  """
  # NOTE(review): the expected jaxpr contains `name=f`, presumably taken from
  # this function's name. Renaming `f` would likely break the comparison.
  def f(k, x):
    x = hcb.id_print(k + x)
    return -k * x

  def loss(k):
    return lax.fori_loop(0, 1, api.remat(f), k)

  self.assertRewrite("""
      { lambda ; a c.
          let _ _ b d =
                while[ body_jaxpr={ lambda ; a b c f.
                                    let d = add a 1
                                        e g = remat_call[ call_jaxpr={ lambda ; a b g.
                                                                       let c = add a b
                                                                           d h = id_tap[ arg_treedef_=*
                                                                                         has_token_=True
                                                                                         nr_tapped_args_=1
                                                                                         tap_func_=_print ] c g
                                                                           e = neg a
                                                                           f = mul e d
                                                                       in (f, h) }
                                                          concrete=False
                                                          name=f ] a c f
                                    in (d, b, e, g) }
                       body_nconsts=0
                       cond_jaxpr={ lambda ; a b c e.
                                    let d = lt a b
                                    in (d,) }
                       cond_nconsts=0 ] 0 1 a c
          in (b, d) }""", loss, [2])
|
|
|
|
|
2020-12-13 10:44:20 +02:00
|
|
|
def test_pmap(self):
  """`id_print(..., tap_with_device=True)` inside `api.pmap`.

  The token is passed to `xla_pmap` as a second mapped argument
  (`in_axes=(0, 0)`) and comes back as a second mapped output
  (`out_axes=(0, 0)`). The tap records `tap_with_device_=True`.
  """
  def f(xv):
    # The pmap result is not returned. The expected jaxpr binds it to `_` and
    # returns only the token `c`.
    # NOTE(review): presumably intentional; confirm.
    api.pmap(lambda x: jnp.sin(hcb.id_print(x, tap_with_device=True)),
             axis_name="i")(xv)

  self.assertRewrite("""
      { lambda ; a b.
          let _ c = xla_pmap[ axis_name=i
                              axis_size=1
                              backend=None
                              call_jaxpr={ lambda ; a e.
                                           let b f = id_tap[ arg_treedef_=*
                                                             has_token_=True
                                                             nr_tapped_args_=1
                                                             tap_func_=_print
                                                             tap_with_device_=True ] a e
                                               c = convert_element_type[ new_dtype=float32
                                                                         weak_type=False ] b
                                               d = sin c
                                           in (d, f) }
                              devices=None
                              donated_invars=(False, False)
                              global_arg_shapes=(None,)
                              global_axis_size=None
                              in_axes=(0, 0)
                              name=<lambda>
                              out_axes=(0, 0) ] a b
          in (c,) }""", f, [np.array([2])])
|
2020-10-16 10:52:56 +03:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
if __name__ == "__main__":
  # When run as a script, hand control to absltest, loading the tests with
  # JAX's test loader (jtu.JaxTestLoader).
  absltest.main(testLoader=jtu.JaxTestLoader())
|