2020-05-08 17:18:11 +03:00
|
|
|
# Copyright 2020 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from functools import partial
|
|
|
|
import logging
|
2020-05-10 19:54:46 +03:00
|
|
|
import numpy as np
|
2020-05-08 17:18:11 +03:00
|
|
|
import os
|
|
|
|
import re
|
2020-06-01 11:49:35 -07:00
|
|
|
from typing import Callable, Sequence
|
2020-05-08 17:18:11 +03:00
|
|
|
from unittest import SkipTest
|
|
|
|
|
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
|
|
|
from jax import api
|
|
|
|
from jax import lax
|
2020-05-10 19:54:46 +03:00
|
|
|
from jax import numpy as jnp
|
2020-05-08 17:18:11 +03:00
|
|
|
from jax import test_util as jtu
|
|
|
|
from jax.config import config
|
|
|
|
from jax.experimental import host_callback as hcb
|
|
|
|
from jax.lib import xla_bridge
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level side effect at import time; presumably hooks JAX's flags into
# absl's flag parsing (confirm against jax.config).
config.parse_flags_with_absl()
|
|
|
|
# Module-level alias for the flag container exposed by `config`.
FLAGS = config.FLAGS
|
|
|
|
|
|
|
|
def skip_if_jit_not_enabled():
  """Raise SkipTest unless JAX_ENABLE_JIT_PRINT is set to a value other than "false".

  An unset variable is treated the same as "false".
  """
  jit_print_setting = os.environ.get("JAX_ENABLE_JIT_PRINT", "false")
  if jit_print_setting != "false":
    return
  raise SkipTest("print jit not enabled yet; use JAX_ENABLE_JIT_PRINT env.")
|
|
|
|
|
2020-05-08 17:58:25 -07:00
|
|
|
def supported_dtypes():
  """Return `jtu.supported_dtypes()` as a list ordered by NumPy dtype name."""
  def dtype_name(dtype):
    return np.dtype(dtype).name

  return sorted(jtu.supported_dtypes(), key=dtype_name)
|
2020-05-08 17:58:25 -07:00
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
class _TestingOutputStream(object):
|
|
|
|
"""Use as `output_stream` for tests."""
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
self._output = []
|
|
|
|
self.testMethodName = None
|
|
|
|
|
|
|
|
def write(self, what: str) -> None:
|
|
|
|
print(f"output_stream[{self.testMethodName}]: {what}", end="")
|
|
|
|
self._output.append(what)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def output(self):
|
|
|
|
return "".join(self._output)
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
return "TestingOutputStream"
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
self._output = []
|
|
|
|
|
|
|
|
|
|
|
|
# Single shared stream: tests pass it as `output_stream=` and then read
# `testing_stream.output`; HostCallbackTest.setUp resets it before each test.
testing_stream = _TestingOutputStream()
|
|
|
|
|
|
|
|
|
|
|
|
def fun1(a):
  """Tap `a * 2` and then `(a * 2) * 3`, returning `(a * 2) ** 2`.

  The second tap passes `result=` so that it hands back the doubled value
  rather than the tripled value it prints.
  """
  doubled = hcb.id_print(a * 2., what="a * 2", output_stream=testing_stream)
  passthrough = hcb.id_print(doubled * 3., what="y * 3",
                             output_stream=testing_stream, result=doubled)
  # Squaring makes the gradient depend on the input.
  return passthrough**2
|
|
|
|
|
|
|
|
|
|
|
|
def fun1_equiv(a):  # Numerical equivalent of fun1
  """Return `(a * 2.) ** 2`, i.e. what `fun1` computes without the taps."""
  doubled = a * 2.
  return doubled**2
|
|
|
|
|
|
|
|
def assertMultiLineStrippedEqual(tst: "jtu.JaxTestCase", expected: str, what: str):
  """A variant that preprocesses the string to eliminate non-determinism in
  floating point values, and several uninteresting id_tap primitive params.

  Args:
    tst: the test case; its `assertMultiLineStrippedEqual(expected, what)`
      method performs the final comparison. The annotation is quoted so that
      it is not evaluated when this function is defined.
    expected: the expected text, passed through unchanged.
    what: the actual text. Before comparison, float-looking tokens are
      rewritten with two decimals, `output_stream=...` and `threshold=...`
      params are removed, blank lines are dropped, and `func=...` is
      collapsed to `func=_print` (for the print consumer) or `...`.
  """
  def repl_floats(match_group):
    matched = match_group.group(0)
    if matched == ".": return matched
    try:
      value = float(matched)
    except ValueError:
      # The pattern is loose and also matches non-numbers such as ".f" in
      # "self.foo" or "-."; leave those untouched instead of crashing.
      return matched
    x = np.around(value, decimals=2)
    return f"{x:.2f}"

  # Sometimes we get floating points in the output; we round them
  what = re.sub(r"\-?\d*\.[\-\def]*", repl_floats, what)
  what = re.sub(r"output_stream=[^\]\n]*", "", what)
  what = re.sub(r"threshold=[^\]\n]*", "", what)
  # Empty lines
  what = re.sub(r"^\s*\n", "", what, flags=re.MULTILINE)

  def repl_func(match_group):
    matched = match_group.group(0)
    if "function _print_consumer" in matched:
      return "func=_print"
    else:
      return "..."

  what = re.sub(r"func=(.*)", repl_func, what)
  tst.assertMultiLineStrippedEqual(expected, what)
|
|
|
|
|
|
|
|
class HostCallbackTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def setUp(self):
  # Remember XLA_FLAGS ("" when unset) so tearDown can undo changes made by
  # the helper_set_* methods.
  self.old_flags = os.environ.get("XLA_FLAGS", "")
  # Start each test with an empty shared stream labelled with this test.
  testing_stream.testMethodName = self._testMethodName
  testing_stream.reset()
|
|
|
|
|
|
|
|
def tearDown(self) -> None:
  """Restore XLA_FLAGS to the value captured in setUp, if a test changed it.

  setUp records `os.getenv("XLA_FLAGS", "")`, so the comparison here uses the
  same `""` default. Without it an unset variable reads as None, never equals
  the saved `""`, and every test would set XLA_FLAGS and clear the backend
  cache even though nothing had changed.
  """
  if os.getenv("XLA_FLAGS", "") != self.old_flags:
    os.environ["XLA_FLAGS"] = self.old_flags
    # Drop cached backends, mirroring what the helper_set_* methods do after
    # they modify the flags.
    xla_bridge.get_backend.cache_clear()
|
|
|
|
|
|
|
|
def helper_set_devices(self, nr_devices):
  """Append a forced host device count to XLA_FLAGS and return `api.devices()`."""
  device_flag = " --xla_force_host_platform_device_count={}".format(nr_devices)
  os.environ["XLA_FLAGS"] = os.getenv("XLA_FLAGS", "") + device_flag
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
  return api.devices()
|
|
|
|
|
|
|
|
def helper_set_hlo_dump(self):
  """Append `--xla_dump_to=/tmp/xla_dump` to XLA_FLAGS."""
  current_flags = os.getenv("XLA_FLAGS", "")
  os.environ["XLA_FLAGS"] = f"{current_flags} --xla_dump_to=/tmp/xla_dump"
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
|
|
|
|
|
|
|
|
def test_eval(self):
  # First check the jaxpr of fun1; building it must not print anything.
  expected_jaxpr = """
{ lambda ; a.
let b = mul a 2.00
c = id_tap[ arg_treedef=*
func=_print
what=a * 2 ] b
d = mul c 3.00
e f = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
what=y * 3 ] d c
g = integer_pow[ y=2 ] f
in (g,) }"""
  actual_jaxpr = str(api.make_jaxpr(fun1)(5.))
  assertMultiLineStrippedEqual(self, expected_jaxpr, actual_jaxpr)
  self.assertEqual("", testing_stream.output)

  # Then run fun1 eagerly and check both the result and the printed taps.
  with hcb.outfeed_receiver():
    self.assertAllClose((5. * 2.) ** 2, fun1(5.))
  expected_output = """
what: a * 2
10.00
what: y * 3
30.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_tuple_results(self):
  # Tap a pair and use both returned elements.
  def func2(x):
    doubled, tripled = hcb.id_print((x * 2., x * 3.),
                                    output_stream=testing_stream)
    return doubled + tripled

  expected_jaxpr = """
{ lambda ; a.
let b = mul a 2.00
c = mul a 3.00
d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
] b c
f = add d e
in (f,) }"""
  assertMultiLineStrippedEqual(self, expected_jaxpr,
                               str(api.make_jaxpr(func2)(3.)))
  with hcb.outfeed_receiver():
    self.assertEqual(3. * (2. + 3.), func2(3.))
  expected_output = """
[ 6.00
9.00 ]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_dict_results(self):
  # Tap a dict and use both returned entries.
  def func2(x):
    tapped = hcb.id_print(dict(a=x * 2., b=x * 3.),
                          output_stream=testing_stream)
    return tapped["a"] + tapped["b"]

  with hcb.outfeed_receiver():
    self.assertEqual(3. * (2. + 3.), func2(3.))
  expected_output = """
{ a=6.00
b=9.00 }"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_with_result(self):
  # The tapped pair is printed, but `result=` is what id_print returns.
  def func2(x):
    returned = hcb.id_print((x * 2., x * 3.), result=x * 4.,
                            output_stream=testing_stream)
    return returned

  with hcb.outfeed_receiver():
    self.assertEqual(3. * 4., func2(3.))
  expected_output = """
[ 6.00
9.00 ]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_eval_tap_exception(self):
  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise NotImplementedError

  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    failing = hcb.id_tap(tap_err, first + 1, what="err")
    last = hcb.id_print(failing + 1, what="x3", output_stream=testing_stream)
    return last

  with self.assertRaises(hcb.TapFunctionException):
    with hcb.outfeed_receiver():
      _ = func(0)

  # We should have received everything before the error
  expected_output = """
what: x1
1
what: x3
3"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_simple(self):
  # A single tap inside a jitted lambda.
  jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
      2. * x, what="here", output_stream=testing_stream))

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    result = jit_fun1(5.)

  self.assertAllClose(6. * 5., result)
  expected_output = """
what: here
10.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_jit_constant(self):
  # Tap a Python constant while passing the argument through via `result=`.
  # The inner function must stay named `func`: the expected jaxpr below
  # contains `name=func`.
  def func(x):
    return hcb.id_print(42, result=x, output_stream=testing_stream)

  expected_jaxpr = """
{ lambda ; a.
let b = xla_call[ backend=None
call_jaxpr={ lambda ; a.
let b c = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
] 42 a
in (c,) }
device=None
donated_invars=(False,)
name=func ] a
in (b,) }"""
  actual_jaxpr = str(api.make_jaxpr(api.jit(func))(5))
  assertMultiLineStrippedEqual(self, expected_jaxpr, actual_jaxpr)
  self.assertEqual("", testing_stream.output)

  with hcb.outfeed_receiver():
    self.assertAllClose(5, api.jit(func)(5))
  expected_output = """
42"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_jit_sequence1(self):
  # Two taps in sequence inside one jitted function.
  def func(x):
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

  # Log the jaxpr and the HLO text for debugging.
  jaxpr = api.make_jaxpr(func)(1)
  logging.info("%s: %s", self._testMethodName, jaxpr)
  hlo_text = api.xla_computation(func)(1).as_hlo_text()
  logging.info("%s: %s", self._testMethodName, hlo_text)

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(2, api.jit(func)(1))
  expected_output = """
where: 1
1
where: 2
2"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit2(self):
  """A sequence of JIT."""
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    second = hcb.id_print(first + 1, where="2", output_stream=testing_stream)
    return second

  # Two separate jit invocations under the same receiver.
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(2, api.jit(func)(1))
    self.assertEqual(11, api.jit(func)(10))

  expected_output = """
where: 1
1
where: 2
2
where: 1
10
where: 2
11"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_nested(self):
  # A jitted function that calls another jitted function; both tap.
  def func(x):
    outer_in = hcb.id_print(x, where="1", output_stream=testing_stream)

    def func_nested(x):
      inner = hcb.id_print(x + 1, where="nested",
                           output_stream=testing_stream)
      return inner

    nested_out = api.jit(func_nested)(outer_in)
    return hcb.id_print(nested_out + 1, where="3",
                        output_stream=testing_stream)

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(3, api.jit(func)(1))
  expected_output = """
where: 1
1
where: nested
2
where: 3
3"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_devices(self):
  """Running on multiple devices."""
  devices = api.local_devices()
  logging.info(f"{self._testMethodName}: has devices {devices}")

  def func(x, device_id):
    dev_label = str(device_id)
    x1 = hcb.id_print(x, dev=dev_label, output_stream=testing_stream)
    x2 = hcb.id_print(x1 + 1, dev=dev_label, output_stream=testing_stream)
    return x2

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    for device in devices:
      on_device = api.jit(func, device=device, static_argnums=1)
      self.assertEqual(112, on_device(111, device.id))
  logging.info(f"{self._testMethodName}: found output {testing_stream.output}")
  # Each device must have printed both the input and the incremented value.
  nr_devices = len(devices)
  self.assertEqual(nr_devices, len(re.findall(r"111", testing_stream.output)))
  self.assertEqual(nr_devices, len(re.findall(r"112", testing_stream.output)))
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_pytree(self, with_jit=False):
  def func(x, what=""):
    """Returns some pytrees depending on x"""
    if what == "pair_1_x":
      return (1, x)
    if what == "pair_x_2x":
      return (x, 2 * x)
    if what == "dict":
      return dict(a=2 * x, b=3 * x)
    assert False

  tap_count = 0

  def tap_func(a, what=""):
    # The tapped value must be the pytree built from the argument 5.
    nonlocal tap_count
    tap_count += 1
    self.assertEqual(func(5, what), a)

  if with_jit:
    transform = api.jit
  else:
    transform = lambda f: f
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    for what in ("pair_1_x", "pair_x_2x", "dict"):
      self.assertEqual(func(10, what),
                       transform(lambda x: hcb.id_tap(tap_func, func(x, what),
                                                      result=func(x * 2, what),
                                                      what=what))(5))
  # Wait for receivers to be done
  self.assertEqual(3, tap_count)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_cond(self, with_jit=False):
  """A conditional"""
  def tap(value, **kwargs):
    # Same as hcb.id_print with the shared test stream appended last.
    return hcb.id_print(value, **kwargs, output_stream=testing_stream)

  def func(x):
    x1 = tap(x, where="1")
    x2 = tap(x1 + 1, where="2")

    x4 = lax.cond(x % 2 == 0,
                  lambda x: tap(x, where="cond_t"),
                  lambda x: tap(-1, where="cond_f", result=x),
                  x2 + 1)
    x5 = tap(x4 + 1, where="end")
    return x5

  if with_jit:
    transform = api.jit
  else:
    transform = lambda f: f
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(4, transform(func)(1))
  expected_output = """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_while_cond(self, with_jit=False):
  # A while loop whose body contains a conditional; every stage taps.
  def tap(value, **kwargs):
    # Same as hcb.id_print with the shared test stream appended last.
    return hcb.id_print(value, **kwargs, output_stream=testing_stream)

  def func(x):
    x1 = tap(x, where="1")
    x2 = tap(x1 + 1, where="2")

    def body(x):
      x3 = tap(x, where="w_b_1")
      x4 = lax.cond(x % 2 == 0,
                    lambda x: tap(x, where="w_b_t"),
                    lambda x: tap(-1, where="w_b_f", result=x),
                    x3 + 1)
      return tap(x4, where="w_b_2")

    x10 = lax.while_loop(lambda x: x <= 3, body, x2)
    res = tap(x10, where="end")
    return res

  if with_jit:
    transform = api.jit
  else:
    transform = lambda f: f
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(4, transform(func)(1))
  expected_output = """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_jit_while_pred_tap(self):
  """While with printing in the conditional."""
  def tap(value, **kwargs):
    # Same as hcb.id_print with the shared test stream appended last.
    return hcb.id_print(value, **kwargs, output_stream=testing_stream)

  def func(x):
    # This first tap has no output_stream, so it is absent from the
    # expected text below.
    x1 = hcb.id_print(x, where="1")
    x10 = lax.while_loop(lambda x: tap(x < 3, where="w_p"),
                         lambda x: tap(x + 1, where="w_b"),
                         x1)
    res = tap(x10, where="3")
    return res

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    self.assertEqual(3, api.jit(func)(1))
  expected_output = """
where: w_p
True
where: w_b
2
where: w_p
True
where: w_b
3
where: w_p
False
where: 3
3"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_with_jit_{with_jit}",
            with_jit=with_jit)
        for with_jit in [True, False]))
def test_scan_cond(self, with_jit=False):
  # A scan whose body contains a conditional; every stage taps.
  def tap(value, **kwargs):
    # Same as hcb.id_print with the shared test stream appended last.
    return hcb.id_print(value, **kwargs, output_stream=testing_stream)

  def func(x):
    x1 = tap(x, where="1")
    x2 = tap(x1 + 1, where="2")

    def body(c, x):
      x3 = tap(x, where="s_1")
      x4 = lax.cond(x % 2 == 0,
                    lambda x: tap(x, where="s_t"),
                    lambda x: tap(-1, where="s_f", result=x),
                    x3 + 1)
      return (c, tap(x4, where="s_2"))

    _, x10 = lax.scan(body, x2, jnp.arange(3))
    res = tap(x10, where="10")
    return res

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    runner = api.jit(func) if with_jit else func
    res = runner(1)
    self.assertAllClose(jnp.array([1, 2, 3]), res)
  expected_output = """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
            shape=shape,
            dtype=dtype,
            nr_args=nr_args) for nr_args in [1, 2]
        for shape in [(), (2,), (2, 3), (2, 3, 4)]
        for dtype in supported_dtypes()))
def test_jit_types(self, nr_args=2, dtype=jnp.int16, shape=(2,)):
  # Smoke test: print arrays of each supported dtype and shape under jit.
  if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
    raise SkipTest(f"id_print jit not implemented for {dtype}.")
  if jtu.device_under_test() == "tpu" and dtype in (jnp.int16,):
    raise SkipTest(f"transfering {dtype} not supported on TPU")

  array = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)
  args = [array]
  if nr_args > 1:
    args = args * nr_args

  def print_all(xs):
    return hcb.id_print(
        xs,
        a_new_test="************",
        testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")

  jit_fun1 = api.jit(print_all)
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    _ = jit_fun1(args)
  # The returned value is deliberately not compared against `args` here.
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_jit_large(self):
  # Smoke test: print a 10000-element int32 array under jit.
  large = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    api.jit(hcb.id_print)(large)
|
|
|
|
|
|
|
|
def test_jit_several_together(self):
  # Smoke test: one tap that carries three arrays at once.
  first = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(
        first, jnp.ones(100, dtype=jnp.int32))
|
2020-05-08 17:18:11 +03:00
|
|
|
|
|
|
|
def test_jit_interleaving(self):
  # Several jit's without data dependencies; they may interfere
  count = 0  # Count tap invocations
  nr_arrays = 5

  def tap_func(arg, **kwargs):
    nonlocal count
    assert len(arg) == nr_arrays
    count += 1

  # This is the function that we'll run multiple times
  def func(x, count):
    for step in range(count):
      tapped = hcb.id_tap(tap_func,
                          [x + offset for offset in range(nr_arrays)],
                          i=step)
      x = tapped[-1]
    return x

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    x = jnp.array(1, dtype=np.int32)
    res = 0
    for _ in range(10):
      # No dependencies between the jit invocations
      res += api.jit(lambda x: func(x, 10))(x)
  # 10 jit invocations, each tapping 10 times.
  self.assertEqual(100, count)
|
|
|
|
|
|
|
|
def test_jit_tap_exception(self):
  # Simulate a tap error
  def tap_err(*args, **kwargs):
    raise NotImplementedError

  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    failing = hcb.id_tap(tap_err, first + 1, what="err")
    last = hcb.id_print(failing + 1, what="x3", output_stream=testing_stream)
    return last

  with self.assertRaises(hcb.TapFunctionException):
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      res = api.jit(func)(0)
  # Even though the receiver thread raised, the main thread should still
  # return 3.
  self.assertEqual(3, res)
  # We should have received all others
  expected_output = """
what: x1
1
what: x3
3"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_jit_unknown_tap(self):
  # Simulate an unknown tap function
  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    unknown = hcb.id_tap(hcb._unknown_testing_consumer, first + 1, what="err")
    last = hcb.id_print(unknown + 1, what="x3", output_stream=testing_stream)
    return last

  with self.assertRaises(hcb.TapFunctionException):
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      res = api.jit(func)(0)
  # Even though the receiver thread raised, the main thread should still
  # return 3.
  self.assertEqual(3, res)
  # We should have received all others
  expected_output = """
what: x1
1
what: x3
3"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
# On CPU and GPU the device code blocks
# On GPU it seems that there is a 5 min timeout?
# On TPU the client does not block, but messes up the rest somehow
@jtu.skip_on_devices("cpu", "gpu", "tpu")
def test_jit_receiver_ends_prematurely(self):
  # The middle tap uses hcb._end_consumer, which ends the consumer loop
  # before the last tap is received.
  def func(x):
    first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    ended = hcb.id_tap(hcb._end_consumer, result=first + 1)  # Will end the consumer loop
    last = hcb.id_print(ended + 1, what="x3", output_stream=testing_stream)
    return last

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    _ = api.jit(func)(0)

  assert False  # It seems that the previous jit blocks above
|
|
|
|
|
|
|
|
def test_jit_error_no_consumer(self):
  # Check for errors if starting jit without a consumer active
  jitted_print = api.jit(lambda x: hcb.id_print(x))
  with self.assertRaisesRegex(ValueError, "outfeed_receiver is not started"):
    jitted_print(0)
|
|
|
|
|
|
|
|
def test_jit_nested_cond_no_print(self):
  """A nested conditional, without any prints"""
  # The test is skipped unconditionally; the code below never runs.
  raise SkipTest("skip this")

  @api.jit
  def cfun(x):
    return lax.cond(
        lax.lt(x, 2),
        lambda v: v,
        lambda v: lax.cond(v < 5,
                           3, lambda w: w,
                           4, lambda y: y),
        x)

  hlo_text = api.xla_computation(cfun)(1).as_hlo_text()
  print(self._testMethodName, hlo_text)
  cfun(1)
|
|
|
|
|
|
|
|
def test_while(self):
  """Executing while, even without JIT uses compiled code"""
  y = jnp.ones(5)  # captured const

  def func(x):
    def keep_going(carry):
      return carry[1] < 5

    def step(carry):
      # Tap the counter, then advance it; the first element becomes `y`.
      return (y, hcb.id_print(carry[1], output_stream=testing_stream) + 1)

    return lax.while_loop(keep_going, step, (x, 1))

  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    func(y)
  expected_output = """
1
2
3
4"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_while_error_no_receiver(self):
  """Executing while needs the receiver"""
  y = jnp.ones(5)  # captured const

  def func(x):
    def keep_going(carry):
      return carry[1] < 5

    def step(carry):
      return (y, hcb.id_print(carry[1], output_stream=testing_stream) + 1)

    return lax.while_loop(keep_going, step, (x, 1))

  # No outfeed_receiver is active here, so running the loop must fail.
  with self.assertRaisesRegex(ValueError, ".*outfeed_receiver.*not started"):
    func(y).block_until_ready()
|
|
|
|
|
|
|
|
|
|
|
|
def test_jvp(self):
  # Forward-mode differentiation of fun1: taps fire for primals and tangents.
  def jvp_fun1(x, xt):
    return api.jvp(fun1, (x,), (xt,))

  expected_jaxpr = """
{ lambda ; a b.
let c = mul a 2.00
d = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
what=a * 2 ] c
e = mul d 3.00
f g = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
what=y * 3 ] e d
h = integer_pow[ y=2 ] g
i = mul b 2.00
j k = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
transforms=(('jvp',),)
what=a * 2 ] i d
l = mul j 3.00
m n o = id_tap[ arg_treedef=*
func=_print
nr_untapped=2
transforms=(('jvp',),)
what=y * 3 ] l j f
p = mul 2.00 g
q = mul n p
in (h, q) }"""
  actual_jaxpr = str(
      api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1)))
  assertMultiLineStrippedEqual(self, expected_jaxpr, actual_jaxpr)

  with hcb.outfeed_receiver():
    res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
  self.assertAllClose(100., res_primals, check_dtypes=False)
  self.assertAllClose(4., res_tangents, check_dtypes=False)
  expected_output = """
what: a * 2
10.00
transforms: ({'name': 'jvp'},) what: a * 2
0.20
what: y * 3
30.00
transforms: ({'name': 'jvp'},) what: y * 3
0.60"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_primal_unused(self):
  # The output of id_print is not needed for backwards pass
  def func(x):
    return 2. * hcb.id_print(x * 3., what="x * 3",
                             output_stream=testing_stream)

  grad_func = api.grad(func)
  transposed_tap_output = """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00"""

  with hcb.outfeed_receiver():
    expected_jaxpr = """
{ lambda ; a.
let
in (6.00,) }"""
    assertMultiLineStrippedEqual(self, expected_jaxpr,
                                 str(api.make_jaxpr(grad_func)(5.)))

  # Just making the Jaxpr invokes the id_print once
  assertMultiLineStrippedEqual(self, transposed_tap_output,
                               testing_stream.output)
  testing_stream.reset()

  with hcb.outfeed_receiver():
    res_grad = grad_func(jnp.float32(5.))

  self.assertAllClose(6., res_grad, check_dtypes=False)
  expected_output = """
what: x * 3
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_simple(self):
  """Gradient of a function with two chained taps.

  Checks the jaxpr of the gradient, then the gradient value and the text
  written to the testing stream when the gradient is evaluated.
  """
  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2",
                           output_stream=testing_stream)
    tripled = hcb.id_print(doubled * 3., what="y * 3",
                           output_stream=testing_stream)
    return x * tripled

  grad_func = api.grad(func)

  expected_jaxpr = """
{ lambda ; a.
let b = mul 1.00 a
c d = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
transforms=(('jvp',), ('transpose',))
what=y * 3 ] b 0.00
e = mul c 3.00
f g = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
transforms=(('jvp',), ('transpose',))
what=x * 2 ] e 0.00
h = mul f 2.00
i = mul a 2.00
j = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
what=x * 2 ] i
k = mul j 3.00
l = id_tap[ arg_treedef=*
func=_print
nr_untapped=0
what=y * 3 ] k
m = mul 1.00 l
n = add_any h m
in (n,) }"""
  jaxpr_text = str(api.make_jaxpr(grad_func)(5.))
  assertMultiLineStrippedEqual(self, expected_jaxpr, jaxpr_text)

  with hcb.outfeed_receiver():
    gradient = grad_func(jnp.float32(5.))
  self.assertAllClose(2. * 5. * 6., gradient, check_dtypes=False)

  # Primal taps come first; the transposed taps follow in reverse order.
  expected_output = """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_grad_double(self):
  """Check the taps produced when a tapped function is differentiated twice.

  Verifies the jaxpr of `grad(grad(func))`, the text written to the testing
  stream while that jaxpr is built, and then the value and stream output of
  actually evaluating the second derivative.
  """
  def func(x):
    y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    return x * (y * 3.)

  grad_func = api.grad(api.grad(func))
  # NOTE(review): indentation was not recoverable from this view; the
  # statements up to and including `res_grad = ...` are assumed to run inside
  # the receiver context because no other receiver is opened for that call.
  with hcb.outfeed_receiver():
    assertMultiLineStrippedEqual(self, """
{ lambda ; a.
let
in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
    # Just making the Jaxpr invokes the id_print twice
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
    # Drop what tracing wrote so the check below sees only the real run.
    testing_stream.reset()
    res_grad = grad_func(jnp.float32(5.))

  self.assertAllClose(12., res_grad, check_dtypes=False)
  assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap(self):
  """Check the jaxpr and the tapped output of `fun1` under `api.vmap`."""
  batched_fun1 = api.vmap(fun1)
  batch = jnp.array([jnp.float32(4.), jnp.float32(5.)])

  expected_jaxpr = """
{ lambda ; a.
let b = mul a 2.00
c = id_tap[ arg_treedef=*
func=_print
transforms=(('batch', (0,)),)
what=a * 2 ] b
d = mul c 3.00
e f = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
transforms=(('batch', (0, 0)),)
what=y * 3 ] d c
g = integer_pow[ y=2 ] f
in (g,) }"""
  jaxpr_text = str(api.make_jaxpr(batched_fun1)(batch))
  assertMultiLineStrippedEqual(self, expected_jaxpr, jaxpr_text)

  # Only the tapped output is inspected; the result itself is discarded.
  with hcb.outfeed_receiver():
    batched_fun1(batch)

  expected_output = """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap_not_batched(self):
  """Tap a pair in which only the second element is mapped by vmap.

  The expected batch dims for the tap are `(None, 0)`.
  """
  x = 3.

  def func(y):
    # x is not mapped, y is mapped
    tapped = hcb.id_print((x, y), output_stream=testing_stream)
    return x + tapped[1]

  batched_func = api.vmap(func)
  batch = jnp.array([jnp.float32(4.), jnp.float32(5.)])

  expected_jaxpr = """
{ lambda ; a.
let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
transforms=(('batch', (None, 0)),) ] 3.00 a
d = add c 3.00
in (d,) }"""
  jaxpr_text = str(api.make_jaxpr(batched_func)(batch))
  assertMultiLineStrippedEqual(self, expected_jaxpr, jaxpr_text)

  # Only the tapped output is inspected; the result itself is discarded.
  with hcb.outfeed_receiver():
    batched_func(batch)

  expected_output = """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
[4.00 5.00] ]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_double_vmap(self):
  """Tap an elementwise sum that sits inside two nested vmaps.

  Checks the jaxpr (a single tap carrying two 'batch' transforms) and the
  2D array written to the testing stream when the function runs.
  """
  # A 2D tensor with x[i, j] = i + j using 2 vmap
  # The innermost helper is deliberately not called `sum`, so the builtin is
  # not shadowed inside this test.
  def tapped_sum(x, y):
    return hcb.id_print(x + y, output_stream=testing_stream)

  def sum_rows(xv, y):
    return api.vmap(tapped_sum, in_axes=(0, None))(xv, y)

  def sum_all(xv, yv):
    return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

  xv = jnp.arange(5, dtype=np.int32)
  yv = jnp.arange(3, dtype=np.int32)
  assertMultiLineStrippedEqual(self, """
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=(1,)
shape=(3, 5) ] a
d = reshape[ dimensions=None
new_sizes=(3, 1) ] b
e = add c d
f = id_tap[ arg_treedef=*
func=_print
transforms=(('batch', (0,)), ('batch', (0,))) ] e
in (f,) }""", str(api.make_jaxpr(sum_all)(xv, yv)))
  # Only the tapped output is inspected; the result itself is discarded.
  with hcb.outfeed_receiver():
    _ = sum_all(xv, yv)
  assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]]""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-24 10:50:07 +03:00
|
|
|
def test_vmap_while(self):
  """Vmap of a while loop with taps before, inside the body, and after."""

  def keep_going(v):
    return v < 2

  def step(v):
    return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

  def func(x):
    # like max(x, 2)
    before = hcb.id_print(x, where="1", output_stream=testing_stream)
    after = lax.while_loop(keep_going, step, before)
    return hcb.id_print(after, where="3", output_stream=testing_stream)

  inputs = np.arange(5, dtype=np.int32)
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    expected = np.array([2, 2, 2, 3, 4])
    actual = api.jit(api.vmap(func))(inputs)
    self.assertAllClose(expected, actual, check_dtypes=False)

  expected_output = """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[1 2 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[2 3 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
[2 2 2 3 4]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
def test_vmap_while_tap_cond(self):
  """Vmap of a while loop that taps in both the predicate and the body."""

  def keep_going(v):
    return hcb.id_print(v < 2, where="w_c", output_stream=testing_stream)

  def step(v):
    return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

  def func(x):
    # like max(x, 2)
    before = hcb.id_print(x, where="1", output_stream=testing_stream)
    after = lax.while_loop(keep_going, step, before)
    return hcb.id_print(after, where="3", output_stream=testing_stream)

  inputs = np.arange(5, dtype=np.int32)
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    expected = np.array([2, 2, 2, 3, 4])
    actual = api.jit(api.vmap(func))(inputs)
    self.assertAllClose(expected, actual, check_dtypes=False)

  # Predicate ("w_c") and body ("w_b") taps alternate until the predicate
  # tap reports all False.
  expected_output = """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True True False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[1 2 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[2 3 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[False False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
[2 2 2 3 4]"""
  assertMultiLineStrippedEqual(self, expected_output, testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
2020-05-08 17:18:11 +03:00
|
|
|
def test_pmap(self):
  """Run `fun1` under `api.pmap` and compare with `fun1_equiv` per device."""
  num_devices = api.local_device_count()
  per_device_args = 2. + jnp.arange(num_devices, dtype=jnp.float32)

  mapped_fun1 = api.pmap(fun1, axis_name="i")
  with hcb.outfeed_receiver(receiver_name=self._testMethodName):
    actual = mapped_fun1(per_device_args)

  # One reference value per local device, built from the same inputs.
  expected = jnp.stack([fun1_equiv(2. + a) for a in range(num_devices)])
  self.assertAllClose(expected, actual, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_pmap_error_no_receiver(self):
  """A pmapped tap with no active receiver must raise ValueError."""
  # Check for errors if starting jit without a consumer active
  num_devices = api.local_device_count()
  per_device_args = 2. + jnp.arange(num_devices, dtype=jnp.float32)
  expected_message = "outfeed_receiver is not started"
  with self.assertRaisesRegex(ValueError, expected_message):
    api.pmap(lambda v: hcb.id_print(v))(per_device_args)
|
|
|
|
|
|
|
|
def test_mask(self):
  """Tap inside a function transformed with `api.mask` (currently skipped).

  The test raises SkipTest unconditionally, so everything after the `raise`
  is unreachable and is kept only as the intended check for when masking
  works again.
  """
  # TODO(necula)
  raise SkipTest("masking has regressed")
  @partial(api.mask, in_shapes=['n'], out_shape='')
  def padded_sum(x):
    return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
  # Positional inputs plus the binding for the shape variable 'n'.
  args = [jnp.arange(4)], dict(n=np.int64(2))
  assertMultiLineStrippedEqual(self, """
{ lambda c f ; a b.
let d = lt c b
e = id_tap[ func=_print
logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
transforms=('mask',)
what=x ] a
g = select d e f
h = reduce_sum[ axes=(0,) ] g
in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

  _ = padded_sum(*args)
  # NOTE(review): this calls the method on `self`, whereas the neighbouring
  # tests call the module-level assertMultiLineStrippedEqual(self, ...) —
  # confirm which one is intended when re-enabling this test.
  self.assertMultiLineStrippedEqual("""
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
""", testing_stream.output)
  testing_stream.reset()
|
|
|
|
|
|
|
|
class OutfeedRewriterTest(jtu.JaxTestCase):
  """Tests comparing the text of jaxprs rewritten by `hcb._rewrite_typed_jaxpr`.

  Each test builds a jaxpr with `api.make_jaxpr`, rewrites it, and compares
  its string form with an expected listing.
  """

  def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                    has_input_token: bool = True,
                    has_output_token: bool = True) -> None:
    """Check that the rewrite of func(*args) matches expected.

    Args:
      expected: the expected string form of the rewritten jaxpr.
      func: the function to trace with `api.make_jaxpr`.
      args: positional arguments used for tracing `func`.
      has_input_token: forwarded to `hcb._rewrite_typed_jaxpr`.
      has_output_token: forwarded to `hcb._rewrite_typed_jaxpr`.
    """
    jaxpr = api.make_jaxpr(func)(*args)
    # Only the first element of the rewrite result is compared.
    assertMultiLineStrippedEqual(self, expected,
                                 str(hcb._rewrite_typed_jaxpr(jaxpr, has_input_token, has_output_token)[0]))

  def test_no_outfeed(self):
    """Rewrite a function without taps under three token configurations."""
    # No tokens requested: the expected jaxpr has no extra variables.
    self.assertRewrite("""
{ lambda ; a.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, has_output_token=False)
    # Input token only: one extra input variable, outputs unchanged.
    self.assertRewrite("""
{ lambda ; a d.
let b = mul a a
c = add a b
in (c,) }""", lambda x: x + x * x, [0], has_output_token=False)
    # Input and output token: the extra input is also returned.
    self.assertRewrite("""
{ lambda ; a d.
let b = mul a a
c = add a b
in (c, d) }""", lambda x: x + x * x, [0])

  def test_simple_outfeed(self):
    """Rewrite a function with a single tap; the tap takes and returns the extra variable."""
    self.assertRewrite("""
{ lambda ; a d.
let b = add a a
c e = id_tap[ arg_treedef=*
func=_print
] b d
in (c, e) }""", lambda x: hcb.id_print(x + x), [0])

  def test_cond(self):
    """Rewrite a `lax.cond` in which the second branch function contains a tap."""
    y = jnp.ones(5)  # captured const
    def func(x, z):
      return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)),
                      z, lambda a: (hcb.id_print(a), y))
    self.assertRewrite("""
{ lambda e f ; a b i.
let c = gt b 0
d = convert_element_type[ new_dtype=int32
old_dtype=bool ] c
g h j = cond[ branches=( { lambda ; f_ e a b c g.
let d h = id_tap[ arg_treedef=*
func=_print
] c g
in (d, e, h) }
{ lambda ; d g_ a b c h.
let
in (a, d, h) } )
linear=(False, False, False, False, False, False) ] d e f 1 2 b i
in (g, h, j) }""", func, [y, 5])

  def test_while(self):
    """Rewrite a `lax.while_loop` whose body contains a tap."""
    ct_body = jnp.ones(5, np.float32)  # captured const for the body
    ct_cond = jnp.ones(5, np.float32)  # captured const for the conditional

    def func(x):
      return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
                            lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
                            (x, np.float32(1.)))
    # TODO: we should not need to start a receiver here!!! I believe this is
    # because of the partial evaluation of while, which calls impl, which
    # uses JIT.
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      self.assertRewrite("""
{ lambda b c ; a f.
let d e g = while[ body_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef=*
func=_print
] b f
e = add d 1.00
in (c, e, g) }
body_nconsts=1
cond_jaxpr={ lambda ; c a b g.
let d = add a c
e = reduce_sum[ axes=(0,) ] d
f = lt b e
in (f,) }
cond_nconsts=1 ] b c a 1.00 f
in (d, e, g) }""", func, [ct_body])

  def test_while_pred_outfeed(self):
    """A while with outfeed in the pred."""
    ct_body = jnp.ones(5)  # captured const for the body
    ct_cond = jnp.ones(2)  # captured const for the conditional

    def func(x):
      return lax.while_loop(lambda c: hcb.id_print(ct_cond, result=c[1]) < 5,
                            lambda c: (ct_body, hcb.id_print(c[1]) + 1),
                            (x, 1))

    # TODO: we should not need to start a receiver here!!! I believe this is
    # because of the partial evaluation of while, which calls impl, which
    # uses JIT.
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      # The expected listing evaluates the predicate once before the loop
      # (cond_before) and again at the end of each body (cond_body).
      self.assertRewrite("""
{ lambda b c ; a f.
let h i = xla_call[ call_jaxpr={ lambda ; c a b g.
let d e h = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
] c b g
f = lt e 5
in (f, h) }
name=cond_before ] b a 1 f
y d e g = while[ body_jaxpr={ lambda ; n o p q r s.
let t u v = xla_call[ call_jaxpr={ lambda ; c a b f.
let d g = id_tap[ arg_treedef=*
func=_print
] b f
e = add d 1
in (c, e, g) }
name=body ] o q r s
w x = xla_call[ call_jaxpr={ lambda ; c a b g.
let d e h = id_tap[ arg_treedef=*
func=_print
nr_untapped=1
] c b g
f = lt e 5
in (f, h) }
name=cond_body ] n t u v
in (w, t, u, x) }
body_nconsts=2
cond_jaxpr={ lambda ; j k l m.
let
in (j,) }
cond_nconsts=0 ] b c h a 1 i
in (d, 5, g) }""", func, [ct_body])

  def test_scan(self):
    """Rewrite a `lax.scan` whose step function taps the carry."""
    y = jnp.ones(5)  # captured const
    def func(x):
      return lax.scan(lambda c, a: (hcb.id_print(c), y), (1, 2), x)
    self.assertRewrite("""
{ lambda b ; a f.
let c d g e = scan[ jaxpr={ lambda ; f a b g c.
let d e h = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
func=_print
] a b g
in (d, e, h, f) }
length=5
linear=(False, False, False, False, False)
num_carry=3
num_consts=1
reverse=False ] b 1 2 f a
in (c, d, e, g) }""", func, [y])
|
|
|
|
|
|
|
|
|
|
|
|
# Run the tests through absl's test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()
|