mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

By wrapping common operators in `jit`, we get a number of benefits:

* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.

One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround is to use explicitly typed constants instead of Python scalars.

On my laptop, this benchmark improves from 95us to 4us:

```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

PiperOrigin-RevId: 389871450
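A minimal sketch of the workaround the message describes, assuming the default 32-bit integer types (x64 disabled); the variable names and the wraparound printout are illustrative, not taken from the change itself:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array(1, dtype=jnp.uint32)

# Under the jit-based dispatch path, a bare Python int that does not fit the
# default 32-bit integer type is reported as an error at dispatch time:
#   jnp.add(x, 2**32 - 1)  # fails: scalar too large for int32
# (illustrative; the exact exception text depends on the jax/jaxlib version)

# Workaround: pass an explicitly typed constant instead of a Python scalar.
y = jnp.add(x, np.uint32(2**32 - 1))
print(y)  # 0, via uint32 wraparound
```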
5657 lines
170 KiB
Python
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
from contextlib import contextmanager
import copy
import enum
from functools import partial
import operator
import re
import subprocess
import sys
import types
import unittest
import warnings
import weakref
import functools
import itertools as it
import operator as op

from absl import logging
from absl.testing import absltest, parameterized
import numpy as np

import concurrent.futures

import jax
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, dtypes, lax
from jax._src import api
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
import jax._src.util

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))

class CPPJitTest(jtu.BufferDonationTestCase):
  """Shared tests between the Python and the C++ jax.jit implementations.

  Because the Python implementation supports more features, we need to have the
  Python tests that extend the C++ tests (and not the other way around).
  """

  @property
  def jit(self):
    # Right now, the CPP tests also test the Python code-path when jaxlib is
    # too old.
    # TODO(jblespiau,phawkins): Remove this when jaxlib has been released.
    # This is in the future, because we are making a breaking change to
    # Tensorflow.
    return api._cpp_jit

  def test_jit_of_noncallable(self):
    self.assertRaisesRegex(TypeError, "Expected a callable value.*",
                           lambda: self.jit(3))

  def test_jit_of_generator(self):

    def gen(x):
      yield x

    self.assertRaisesRegex(TypeError,
                           "Expected a function, got a generator function.*",
                           lambda: self.jit(gen))

  @parameterized.parameters([
      # Integer support
      (1, 2, 3, 4, 5),
      # Numpy array support
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
          np.asarray(4, np.int32),
          np.asarray(5, np.int32),
      ),
  ])
  def test_jit_static_args(self, one, two, three, four, five):
    side = []

    def f(x, y, z, flag=False, flag2=False):
      del flag2  # unused
      assert flag
      side.append(None)
      return 100 * x + 10 * y + z

    f1 = self.jit(f, static_argnums=(3, 4))
    assert f1(one, two, three, True, False) == 123
    assert len(side) == 1
    assert f1(one, two, three, True, False) == 123
    assert len(side) == 1  # Obvious cache hit.
    assert f1(two, one, three, True, False) == 213
    assert len(side) == 1  # Should cache hit because same signature.
    assert f1(two, one, three, True, True) == 213
    assert len(side) == 2

    side[:] = []
    f2 = self.jit(f, static_argnums=(0, 2, 3, 4))
    assert f2(1, 2, 3, True, False) == 123
    assert len(side) == 1
    assert f2(1, 3, 3, True, False) == 133
    assert len(side) == 1
    assert f2(2, 2, 3, True, False) == 223
    assert len(side) == 2
    assert f2(2, 4, 3, True, False) == 243
    assert len(side) == 2
    assert f2(2, 4, 3, True, True) == 243
    assert len(side) == 3
    assert f2(2, 5, 3, True, True) == 253
    assert len(side) == 3

  def test_static_args_equality(self):
    class A():

      def __hash__(self):
        return 1

      def __eq__(self, other):
        return isinstance(other, A)

    side = []
    def f(x, static_arg):
      del static_arg
      side.append(None)
      return x * 100

    f1 = self.jit(f, static_argnums=(1,))

    self.assertEqual(f1(1, A()), 100)
    self.assertLen(side, 1)
    self.assertEqual(f1(1, A()), 100)
    self.assertLen(side, 1)
    if self.jit == api._cpp_jit:
      f1_cpp = getattr(f1, "_cpp_jitted_f", f1)
      self.assertEqual(f1_cpp._cache_size(), 1)

  @parameterized.parameters([
      (1, 2, 3),
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
      ),
  ])
  def test_jit_kwargs(self, one, two, three):
    side = []
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z):
      side.append(None)
      return 100 * x + 10 * y + z

    f = self.jit(f)
    assert f(one, two, three) == 123
    assert len(side) == 1
    assert f(one, two, three) == 123
    assert len(side) == 1

    assert f(one, two, z=three) == 123
    assert len(side) == 2  # actually recompiles from kwarg
    assert f(one, two, z=three) == 123
    assert len(side) == 2  # but should still cache

    f(one, two, z=np.zeros(3))  # doesn't crash
    if config.x64_enabled:
      # In the above call, three is of a new type (int64), thus it should
      # trigger a new compilation.
      assert len(side) == 3

  def test_jit_device(self):
    device = xb.devices()[-1]
    x = self.jit(lambda x: x, device=device)(3.)
    self.assertIsInstance(x, xla.DeviceArray)
    self.assertEqual(x.device_buffer.device(), device)

  def test_complex_support(self):
    self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

  def test_jit_with_many_args_works(self):

    @self.jit
    def f(args_list):
      return sum(args_list)

    self.assertEqual(f(list(range(500))), sum(range(500)))

  # Jit and Donate arguments

  def test_jit_donate_argnums_warning_raised(self):
    x = jnp.array([1.0, 2.0], jnp.float32)
    y = jnp.array([1, 2], jnp.int32)
    f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      f(x, y)

    self.assertLen(w, 1)
    self.assertTrue(issubclass(w[-1].category, UserWarning))
    self.assertIn(
        "Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
        str(w[-1].message))

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_invalidates_input(self):
    # We can't just use `lambda x: x` because JAX simplifies this away to an
    # empty XLA computation.
    move = self.jit(lambda x: x + x - x, donate_argnums=0)
    x = jnp.ones([])
    y = move(x)
    self.assertDeleted(x)
    self.assertEqual(y, 1.)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_static_argnums(self):
    jit_fun = self.jit(
        lambda a, b, c, d: ((a + b + c), (a + b + d)),
        static_argnums=(0, 1),
        donate_argnums=(2, 3))

    c = jax.device_put(jnp.array([1., 1.]))
    d = jax.device_put(jnp.array([1., 1., 1.]))
    e, f = jit_fun(1, 2, c, d)
    np.testing.assert_allclose(e, jnp.array([4., 4.]))
    np.testing.assert_allclose(f, jnp.array([4., 4., 4.]))
    self.assertDeleted(c)
    self.assertDeleted(d)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jnp_array_copy(self):
    # https://github.com/google/jax/issues/3412

    @partial(self.jit, donate_argnums=(0,))
    def _test(array):
      return array.at[0].set(77)

    x = jnp.asarray([0, 1])
    x_copy = jnp.array(x, copy=True)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      _test(x)  # donation

    # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
    print(x_copy)  # doesn't crash

  def test_jit_global_cache(self):
    def f(x):
      assert python_should_be_executing
      return x

    python_should_be_executing = True
    self.jit(f)(2)
    python_should_be_executing = False
    self.jit(f)(3)

  def test_jit_shallow_copy(self):
    def f(x):
      return copy.copy(x)
    self.jit(f)(1)

  def test_jit_deep_copy(self):
    def f(x):
      return copy.deepcopy(x)
    self.jit(f)(1)

  def test_disable_jit(self):
    effects = []

    @self.jit
    def f(x):
      effects.append(1)
      return x

    with api.disable_jit():
      f(2)
      f(2)
      assert len(effects) == 2

    f(2)
    f(2)
    assert len(effects) == 3

  def test_static_argnum_on_method(self):

    class A:

      @functools.partial(self.jit, static_argnums=(0,))
      def my_func_jit(self, x):
        return x+2

    A().my_func_jit(3)

  def test_static_argnum_on_static_method_is_not_supported(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit, static_argnums=(0,))
        @classmethod
        def my_classmethod_jit(cls, x):
          return x+2

  def test_classmethod_is_not_supported(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit)
        @staticmethod
        def my_staticmethod_jit(x):
          return x + 2

  def test_concurrent_jit(self):
    @self.jit
    def f(x):
      return x + x - 3.

    xs = [np.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [f.result() for f in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x * 2 - 3., y)

  def test_trivial_computations(self):
    x = jnp.array([1, 2, 3])
    y = self.jit(lambda x: x)(x)
    self.assertIs(x, y)

    z1, z2 = self.jit(lambda x: (x, x))(x)
    self.assertIs(z1, z2)

    x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
    z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
    self.assertIs(z1, x2)
    self.assertIs(z3, x1)
    self.assertEqual(z2, 1)

  def test_trivial_computations_with_tokens(self):
    @self.jit
    def noop(arr, token):
      return arr, token

    arr = jax.numpy.ones(10)
    token = jax.lax.create_token()

    self.assertEqual(token, noop(arr, token)[1])

  def test_jit_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: self.jit(f)("foo"))

  def test_jit_on_all_devices(self):
    # Verifies we can run the same computation on every device present, even
    # if they are, for example, different models of GPU.
    data = np.random.rand(1000).astype(np.float32)
    f = self.jit(jnp.negative)
    for device in jax.local_devices():
      x = device_put(data, device=device)
      np.testing.assert_array_equal(-data, f(x))

  def test_jit_nested_donate_ignored(self):
    jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
    a = jax.device_put(jnp.array(1))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   jit_fun(a)

    jit_fun(a)  # doesn't crash

  def test_jit_reference_dropping(self):
    x = jnp.ones(10)
    f = (lambda x: lambda: x)(x)  # reference to x in f's closure
    g = self.jit(f)
    x = weakref.ref(x)  # no more strong ref to x in this scope
    assert x() is not None  # x is still around
    f()  # f runs
    g()  # g runs
    g()  # g runs a second time
    del f  # delete the raw callable
    assert x() is not None  # x is still around
    g()  # g still runs
    del g  # no more references to x
    assert x() is None  # x is gone

  def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self):
    if self.jit != api._python_jit:
      raise unittest.SkipTest("this test only applies to _python_jit")
    f = lambda x, y: x + 3
    jitted_f = self.jit(f, static_argnums=(1,))

    msg = ("Non-hashable static arguments are not supported, as this can lead "
           "to unexpected cache-misses. Static argument (index 1) of type "
           "<class 'numpy.ndarray'> for function <lambda> is non-hashable.")
    with self.assertRaisesRegex(ValueError, re.escape(msg)):
      jitted_f(1, np.asarray(1))

  def test_cpp_jit_raises_on_non_hashable_static_argnum(self):
    if self.jit != api._cpp_jit:
      raise unittest.SkipTest("this test only applies to _cpp_jit")

    f = lambda x, y: x + 3
    jitted_f = api._cpp_jit(f, static_argnums=[1])

    jitted_f(1, 1)

    msg = ("Non-hashable static arguments are not supported. An error occured "
           "while trying to hash an object of type <class 'numpy.ndarray'>, 1. "
           "The error was:\nTypeError: unhashable type: 'numpy.ndarray'")

    with self.assertRaisesRegex(ValueError, re.escape(msg)):
      jitted_f(1, np.asarray(1))

    class HashableWithoutEq:

      def __hash__(self):
        return 1

      def __eq__(self, other):
        raise NotImplementedError(
            "A Python error is as is, without stack trace")

    with self.assertRaisesRegex(
        ValueError,
        re.escape("static arguments should be comparable using __eq__")):
      jitted_f(1, HashableWithoutEq())

  def test_cpp_jitted_function_returns_PyBuffer(self):
    if self.jit != api._cpp_jit:
      raise unittest.SkipTest("this test only applies to _cpp_jit")

    jitted_f = self.jit(lambda a: a + 1)
    jitted_f(1)
    self.assertIsInstance(jitted_f(2), xla._CppDeviceArray)

  @jtu.skip_on_devices("cpu")
  def test_explicit_backend(self):
    f = lambda x: x + 1
    jitted_f = jit(f, backend=jtu.device_under_test())
    jitted_f_cpu = jit(f, backend="cpu")

    result = jitted_f(1.)
    result_cpu = jitted_f_cpu(1.)
    self.assertEqual(result.device_buffer.platform(), jtu.device_under_test())
    self.assertEqual(result_cpu.device_buffer.platform(), "cpu")

  @jtu.skip_on_devices("cpu")
  def test_device_to_device_copy_between_backends(self):
    # b/186624243
    f = lambda x: x + 1
    jitted_f = jit(f, backend=jtu.device_under_test())
    jitted_f_cpu = jit(f, backend="cpu")

    x = np.arange(30).reshape(1, 10, 3)
    result = jitted_f(x)
    result_cpu = jitted_f_cpu(result)
    result_2 = jitted_f(result_cpu)
    result_cpu_2 = jitted_f_cpu(result_2)
    self.assertAllClose(result_2, x + 3)
    self.assertAllClose(result_cpu_2, x + 4)

  @jtu.skip_on_devices("cpu")
  def test_mismatched_nested_backends(self):
    @partial(jit, backend=jtu.device_under_test())
    def f(x):
      return jit(lambda x: x + 1, backend="cpu")(x)

    with self.assertRaisesRegex(
        ValueError,
        f"Outer-jit backend specification {jtu.device_under_test()} must match "
        f"explicit inner-jit backend specification cpu."):
      f(1.)

  def test_omnistaging(self):
    # See https://github.com/google/jax/issues/5206
    key_list = [None]

    def init():
      key, subkey = jax.random.split(key_list[0])
      key_list[0] = key
      return jax.random.normal(subkey, ())

    key_list[0] = np.array([2384771982, 3928867769], dtype=np.uint32)
    init()
    self.jit(init)()
    self.assertIsInstance(key_list[0], core.Tracer)

  def test_jit_wrapped_attributes(self):
    def f(x: int) -> int:
      """docstring of f."""
      return x + 1
    f.some_value = 4
    jf = self.jit(f)
    for attr in ["doc", "name", "module", "qualname", "annotations"]:
      self.assertEqual(
          {attr: getattr(f, f"__{attr}__")},
          {attr: getattr(jf, f"__{attr}__")})
    self.assertEqual(f.some_value, jf.some_value)

  def test_jit_python_builtin(self):
    x = jnp.array([1, 2])
    expected = x + 1
    jit_add = self.jit(operator.add, static_argnums=(1,))
    actual = jit_add(x, 1)
    self.assertArraysEqual(expected, actual)

  def test__infer_argnums_and_argnames(self):
    def f(x, y=1):
      pass

    argnums, argnames = api._infer_argnums_and_argnames(
        f, argnums=None, argnames=None)
    assert argnums == ()
    assert argnames == ()

    argnums, argnames = api._infer_argnums_and_argnames(
        f, argnums=0, argnames=None)
    assert argnums == (0,)
    assert argnames == ('x',)

    argnums, argnames = api._infer_argnums_and_argnames(
        f, argnums=None, argnames='y')
    assert argnums == (1,)
    assert argnames == ('y',)

    argnums, argnames = api._infer_argnums_and_argnames(
        f, argnums=0, argnames='y')  # no validation
    assert argnums == (0,)
    assert argnames == ('y',)

    def g(x, y, *args):
      pass

    argnums, argnames = api._infer_argnums_and_argnames(
        g, argnums=(1, 2), argnames=None)
    assert argnums == (1, 2)
    assert argnames == ('y',)

    def h(x, y, **kwargs):
      pass

    argnums, argnames = api._infer_argnums_and_argnames(
        h, argnums=None, argnames=('foo', 'bar'))
    assert argnums == ()
    assert argnames == ('foo', 'bar')

  def test_jit_with_static_argnames(self):

    def f(x):
      assert x == 'foo'
      return 1

    f_nums = self.jit(f, static_argnums=0)
    assert f_nums('foo') == 1
    assert f_nums(x='foo') == 1

    f_names = self.jit(f, static_argnames='x')
    assert f_names('foo') == 1
    assert f_names(x='foo') == 1

  def test_new_static_argnum_on_keyword_arguments(self):
    f = self.jit(lambda x: x, static_argnums=0)
    y = f(x=4)
    assert y == 4

  def test_new_static_argnum_with_default_arguments(self):
    f = self.jit(lambda x=4: x, static_argnums=0)
    y = f()
    assert y == 4

  def test_jit_with_mismatched_static_argnames(self):
    x_is_tracer, y_is_tracer = False, False
    def f(x, y):
      assert isinstance(x, core.Tracer) == x_is_tracer
      assert isinstance(y, core.Tracer) == y_is_tracer
      return 1

    # If both static_argnums and static_argnames are provided, they are allowed
    # to disagree and `jit` will respect the user's choices.
    f_nums = self.jit(f, static_argnums=1, static_argnames=())
    x_is_tracer, y_is_tracer = True, False
    assert f_nums(2, 'foo') == 1
    x_is_tracer, y_is_tracer = True, True
    assert f_nums(1, y=2) == 1

    f_names = self.jit(f, static_argnums=(), static_argnames='y')
    x_is_tracer, y_is_tracer = True, True
    assert f_names(2, 3) == 1
    x_is_tracer, y_is_tracer = True, False
    assert f_names(1, y='foo') == 1

    f_mixed = self.jit(f, static_argnums=(1,), static_argnames='x')
    x_is_tracer, y_is_tracer = True, False
    assert f_mixed(2, 'foo') == 1
    x_is_tracer, y_is_tracer = True, True
    assert f_mixed(1, y=3) == 1
    x_is_tracer, y_is_tracer = False, True
    assert f_mixed(x='foo', y=3) == 1

  # TODO(zhangqiaorjc): Test pruning constants after DCE pass prunes primitive
  # applications.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_num_args={}".format(num_args),
       "num_args": num_args}
      for num_args in [2, 3, 4]))
  def test_jit_with_pruned_args(self, num_args):
    def f(*args):
      used = np.array(2)
      return args[1] + used
    f_pruned = self.jit(f)
    args = range(num_args)
    with jtu.count_device_put() as count:
      np.testing.assert_allclose(f_pruned(*args), 3)
    self.assertEqual(count[0], 1)


class PythonJitTest(CPPJitTest):

  @property
  def jit(self):
    return api._python_jit


class APITest(jtu.JaxTestCase):

  def test_grad_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: grad(f)("foo"))

  def test_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    assert grad(f)(1.0, 1.0, 1.0, flag=True) == 1.0
    assert grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == 2.0
    assert grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (3.0, 1.0)

  def test_value_and_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    y = f(1.0, 1.0, 1.0, flag=True)
    assert api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True) == (y, 1.0)
    assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
    assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))

  def test_grad_of_jit(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    assert grad(f)(1.0) == 2.0
    assert len(side) == 1
    assert grad(f)(2.0) == 4.0
    assert len(side) == 1

  def test_jit_of_grad(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    g = jit(grad(f))
    assert g(1.0) == 2.0
    assert len(side) == 1
    assert g(2.0) == 4.0
    assert len(side) == 1

  def test_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: grad(f)("foo"))

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: jit(f)("foo"))

  def test_grad_tuple_output(self):
    jtu.check_raises(lambda: grad(lambda x: (x, x))(1.0), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_unit_output(self):
    jtu.check_raises(lambda: grad(lambda x: ())(np.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_nonscalar_output(self):
    jtu.check_raises(lambda: grad(lambda x: x)(np.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_unwrapped_numpy(self):
    def f(x):
      return np.exp(x)

    with self.assertRaisesRegex(Exception, "The numpy.ndarray conversion .*"):
      grad(f)(np.zeros(3))

  def test_binop_mismatch(self):
    def f(x, y):
      return x + y

    jtu.check_raises(
        lambda: f(jnp.zeros(3), jnp.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

    jtu.check_raises(
        lambda: grad(f)(np.zeros(3), np.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

  def test_dot_mismatch(self):
    def f(x, y):
      return jnp.dot(x, y)

    self.assertRaisesRegex(
        TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
        lambda: grad(f)(np.zeros(3), np.zeros(4)))

  def test_abstract_error_message(self):
    for castfun in [float, complex, int]:
      def f(x):
        return castfun(x)

      self.assertRaisesRegex(
          TypeError,
          f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`",
          lambda: jit(f)(1.0))

  def test_switch_value_jit(self):
    def f(x):
      y = x > 0
      if y:
        return x
      else:
        return -x

    assert grad(f)(1.0) == 1.0
    assert grad(f)(-1.0) == -1.0
    with self.assertRaisesRegex(core.ConcretizationTypeError,
                                "Abstract tracer value"):
      jit(f)(1)

  def test_list_index_err(self):
    L = [1, 2, 3]
    def f(n):
      return L[n]

    assert jit(f, static_argnums=(0,))(0) == L[0]
    self.assertRaisesRegex(
        TypeError,
        r"The __index__\(\) method was called on the JAX Tracer object.*",
        lambda: jit(f)(0))

  def test_range_err(self):
    def f(x, n):
      for i in range(n):
        x = x + i
      return x

    assert jit(f, static_argnums=(1,))(0, 5) == 10
    self.assertRaisesRegex(
        TypeError,
        r"The __index__\(\) method was called on the JAX Tracer object.*",
        lambda: jit(f)(0, 5))

  def test_cast_int(self):
    f = lambda x: int(x)
    self.assertRaisesRegex(
        TypeError,
        "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
        "|Abstract tracer value encountered where concrete value is expected.*)", lambda: jit(f)(0))

  def test_casts(self):
    for castfun in [hex, oct]:
      f = lambda x: castfun(x)
      self.assertRaisesRegex(
          TypeError,
          r"The __index__\(\) method was called on the JAX Tracer object.*", lambda: jit(f)(0))

  def test_unimplemented_interpreter_rules(self):
    foo_p = Primitive('foo')
    def foo(x):
      return foo_p.bind(x)

    jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                     "Evaluation rule for 'foo' not implemented")

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "Abstract evaluation for 'foo' not implemented")

    jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                     "Differentiation rule for 'foo' not implemented")

    foo_p.def_abstract_eval(lambda x: x)

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "XLA translation rule for primitive 'foo' not found")

    foo_p.def_impl(lambda x: x)
    ad.defjvp(foo_p, lambda g, x: foo(g))

    jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                     "Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")

  def test_device_put_and_get(self):
    x = np.arange(12.).reshape((3, 4)).astype("float32")
    dx = api.device_put(x)
    self.assertIsInstance(dx, xla.DeviceArray)
    x2 = api.device_get(dx)
    self.assertIsInstance(x2, np.ndarray)
    assert np.all(x == x2)

    y = [x, (2 * x, 3 * x)]
    dy = api.device_put(y)
    y2 = api.device_get(dy)
    self.assertIsInstance(y2, list)
    self.assertIsInstance(y2[0], np.ndarray)
    assert np.all(y2[0] == x)
    self.assertIsInstance(y2[1], tuple)
    self.assertIsInstance(y2[1][0], np.ndarray)
    assert np.all(y2[1][0] == 2 * x)
    self.assertIsInstance(y2[1][1], np.ndarray)
    assert np.all(y2[1][1] == 3 * x)

  def test_device_get_scalar(self):
    x = np.arange(12.).reshape((3, 4)).astype("float32")
    x = api.device_put(x)
    self.assertIsInstance(x, xla.DeviceArray)
    y = [x, 2]
    y2 = api.device_get(y)
    self.assertIsInstance(y2, list)
    self.assertIsInstance(y2[0], np.ndarray)
    assert np.all(y2[0] == x)
    self.assertIsInstance(y2[1], int)
    self.assertEqual(y2[1], 2)

  @parameterized.parameters([(3,)], [(2, 0)])
  def test_device_put_across_devices(self, shape):
    if len(api.local_devices()) < 2:
      raise unittest.SkipTest("this test requires multiple devices")
    d1, d2 = api.local_devices()[:2]
    data = np.random.randn(*shape).astype(np.float32)
    x = api.device_put(data, device=d1)
    self.assertEqual(x.device_buffer.device(), d1)
    y = api.device_put(x, device=d2)
    self.assertEqual(y.device_buffer.device(), d2)
    np.testing.assert_array_equal(data, np.array(y))
    # Make sure these don't crash
    api.device_put(x)
    api.device_put(y)

  @jtu.skip_on_devices("cpu")
  def test_device_put_across_platforms(self):
    default_device = jax.devices()[0]
    cpu_device = jax.devices("cpu")[0]

    np_arr = np.array([1,2,3])
    scalar = 1
    device_arr = jnp.array([1,2,3])
    assert device_arr.device_buffer.device() is default_device

    for val in [np_arr, device_arr, scalar]:
      x = api.device_put(val, device=cpu_device)
      self.assertEqual(x.device_buffer.device(), cpu_device)

  @jtu.skip_on_devices("tpu")
  def test_jacobian(self):
    R = np.random.RandomState(0).randn
    A = R(4, 3)
    x = R(3)

    f = lambda x: jnp.dot(A, x)
    assert np.allclose(jacfwd(f)(x), A)
    assert np.allclose(jacrev(f)(x), A)

    f = lambda x: jnp.tanh(jnp.dot(A, x))
    assert np.allclose(jacfwd(f)(x), jacrev(f)(x))

  @jtu.skip_on_devices("tpu")
  def test_hessian(self):
    R = np.random.RandomState(0).randn
    A = R(4, 4)
    x = R(4)

    f = lambda x: jnp.dot(x, jnp.dot(A, x))
    assert np.allclose(hessian(f)(x), A + A.T)

  def test_std_basis(self):
    basis = api._std_basis(jnp.zeros(3))
    assert getattr(basis, "shape", None) == (3, 3)
    assert np.allclose(basis, np.eye(3))

    basis = api._std_basis(jnp.zeros((3, 3)))
    assert getattr(basis, "shape", None) == (9, 3, 3)
    assert np.allclose(basis, np.eye(9).reshape(9, 3, 3))

    basis = api._std_basis([0., (jnp.zeros(3), jnp.zeros((3, 4)))])
    assert isinstance(basis, list) and len(basis) == 2
    assert getattr(basis[0], "shape", None) == (16,)
    assert isinstance(basis[1], tuple) and len(basis[1]) == 2
    assert getattr(basis[1][0], "shape", None) == (16, 3)
    assert getattr(basis[1][1], "shape", None) == (16, 3, 4)

  @jtu.skip_on_devices("tpu")
  def test_jacobian_on_pytrees(self):
    for jacfun in [jacfwd, jacrev]:
      ans = jacfun(lambda x, y: (x, y))(0., 1.)
      expected = (1., 0.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), 1)(0., 1.)
      expected = (0., 1.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), (0, 1))(0., 1.)
      expected = ((1., 0.),
                  (0., 1.),)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x: x[:2])((1., 2., 3.))
      expected = ((1., 0., 0.),
                  (0., 1., 0.))
      self.assertAllClose(ans, expected, check_dtypes=False)

      R = np.random.RandomState(0).randn
      x = R(2)
      y = R(3)
      ans = jacfun(lambda x, y: {'x': x, 'xy': jnp.outer(x, y)})(x, y)
      expected = {'x': np.eye(2),
                  'xy': np.kron(np.eye(2), y[:, None]).reshape(2, 3, 2)}
      self.assertAllClose(ans, expected, check_dtypes=False)

@jtu.skip_on_devices("tpu")
|
|
def test_hessian_on_pytrees(self):
|
|
ans = hessian(lambda x: jnp.array(x)**2)((1., 2.))
|
|
expected = ((np.array([2., 0.]), np.array([0., 0.])),
|
|
(np.array([0., 0.]), np.array([0., 2.])))
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
@jtu.skip_on_devices("tpu")
|
|
def test_issue1372(self):
|
|
def quad(x):
|
|
return jnp.dot(x, x)
|
|
|
|
def f(x, u):
|
|
return quad(x) + quad(u)
|
|
|
|
x, u = jnp.ones(5), jnp.ones(2)
|
|
|
|
rev = jacrev
|
|
fwd = jacfwd
|
|
|
|
# Diagonal entries
|
|
self.assertEqual(rev(rev(f, 0), 0)(x, u).shape, (5, 5))
|
|
self.assertEqual(rev(fwd(f, 0), 0)(x, u).shape, (5, 5))
|
|
self.assertEqual(fwd(rev(f, 0), 0)(x, u).shape, (5, 5))
|
|
self.assertEqual(fwd(fwd(f, 0), 0)(x, u).shape, (5, 5))
|
|
self.assertEqual(rev(rev(f, 1), 1)(x, u).shape, (2, 2))
|
|
self.assertEqual(rev(fwd(f, 1), 1)(x, u).shape, (2, 2))
|
|
self.assertEqual(fwd(rev(f, 1), 1)(x, u).shape, (2, 2))
|
|
self.assertEqual(fwd(fwd(f, 1), 1)(x, u).shape, (2, 2))
|
|
|
|
# Off-diagonal entries by reverse-mode on the outside
|
|
self.assertEqual(rev(rev(f, 1), 0)(x, u).shape, (2, 5))
|
|
self.assertEqual(rev(fwd(f, 1), 0)(x, u).shape, (2, 5))
|
|
self.assertEqual(rev(rev(f, 0), 1)(x, u).shape, (5, 2))
|
|
self.assertEqual(rev(fwd(f, 0), 1)(x, u).shape, (5, 2))
|
|
|
|
# Off-diagonal entries by forward-mode on the outside
|
|
self.assertEqual(fwd(rev(f, 1), 0)(x, u).shape, (2, 5))
|
|
self.assertEqual(fwd(fwd(f, 1), 0)(x, u).shape, (2, 5))
|
|
self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
|
|
self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))
|
|
|
|
|
|
def test_large_device_constant(self):
|
|
ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6))) # doesn't crash
|
|
self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False)
|
|
|
|
def test_grad_and_aux_basic(self):
|
|
g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
|
|
self.assertAllClose(g, grad(lambda x: x**3)(3.))
|
|
self.assertAllClose(aux, [9.], check_dtypes=False)
|
|
|
|
def test_grad_and_aux_error(self):
|
|
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
|
grad(lambda x: (1, 2, 3), has_aux=True)(1.)
|
|
|
|
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
|
grad(lambda x: x, has_aux=True)(1.)
|
|
|
|
with self.assertRaisesRegex(TypeError, "two-element tuple"):
|
|
grad(lambda x: (x,), has_aux=True)(1.)
|
|
|
|
def test_grad_and_aux_nested(self):
|
|
def f(x):
|
|
g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
|
|
return aux[0]
|
|
|
|
f2 = lambda x: x**3
|
|
|
|
self.assertEqual(grad(f)(4.), grad(f2)(4.))
|
|
self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
|
|
self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
|
|
|
|
def f(x):
|
|
g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
|
|
return aux[0] * jnp.sin(x)
|
|
|
|
f2 = lambda x: x**3 * jnp.sin(x)
|
|
|
|
self.assertEqual(grad(f)(4.), grad(f2)(4.))
|
|
self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
|
|
self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))
|
|
|
|
def test_grad_and_aux_constant(self):
|
|
g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
|
|
self.assertEqual(g, grad(lambda x: x**3)(4.))
|
|
self.assertEqual(aux, [4.])
|
|
|
|
g, aux = grad(lambda x: (x**3, [x**2, 4.]), has_aux=True)(4.)
|
|
self.assertEqual(g, grad(lambda x: x**3)(4.))
|
|
self.assertEqual(aux, [4.**2, 4.])
|
|
|
|
def test_grad_and_aux_no_tracers(self):
|
|
# see https://github.com/google/jax/issues/1950
|
|
def f(x):
|
|
aux = dict(identity=x, p1=x+1)
|
|
return x ** 2, aux
|
|
|
|
_, aux = jax.grad(f, has_aux=True)(3.)
|
|
self.assertIsInstance(aux, dict)
|
|
for val in aux.values():
|
|
self.assertNotIsInstance(val, core.Tracer)
|
|
|
|
  def test_jvp_mismatched_arguments(self):
    self.assertRaisesRegex(
        TypeError,
        ("primal and tangent arguments to jax.jvp must have the same tree "
         "structure"),
        lambda: api.jvp(lambda x, y: x * y, (np.float32(2),), ()))
    # Primals and tangents must both be tuples or both be lists.
    self.assertRaisesRegex(
        TypeError,
        ("primal and tangent arguments to jax.jvp must have the same tree "
         "structure"),
        lambda: api.jvp(lambda x, y: x * y, (np.float32(2),), [np.float32(2)]))
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp do not match.",
        lambda: api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),)))
    # If primals and tangents are not of the same shape then raise error
    fun = lambda x: x+1
    with self.assertRaisesRegex(
        ValueError, "jvp called with different primal and tangent shapes"):
      api.jvp(fun, (jnp.array([1.,2.,3.]),), (jnp.array([1.,2.,3.,4.]),))
    with self.assertRaisesRegex(
        ValueError, "jvp called with different primal and tangent shapes"):
      api.jvp(fun, (jnp.float32(10.),), (jnp.array([1.,2.,3.], dtype=jnp.float32),))
    with self.assertRaisesRegex(
        ValueError, "jvp called with different primal and tangent shapes"):
      api.jvp(fun, (jnp.array([1.,2.,3.], dtype=jnp.float32),), (jnp.float32(20.),))
    with self.assertRaisesRegex(
        ValueError, "jvp called with different primal and tangent shapes"):
      api.jvp(fun, (jnp.array([1.,2.,3.]),), (20.,))

  def test_jvp_non_tuple_arguments(self):
    def f(x, y): return x + y
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
        lambda: api.jvp(f, 0., (1.,)))
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
        lambda: api.jvp(f, (0.,), np.array([1., 2.])))

  def test_vjp_mismatched_arguments(self):
    _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4))
    self.assertRaisesRegex(
        TypeError,
        "Tree structure of cotangent input.*does not match",
        lambda: pullback((np.float32(7), np.float32(100))))
    self.assertRaisesRegex(
        TypeError,
        "Type of cotangent input to vjp pullback.*is not the expected tangent type",
        lambda: pullback((np.float16(42))))

  def test_vjp_bad_cotangent_shape(self):
    x = np.ones((2, 5), dtype=np.float32)
    y = np.ones((5, 3), dtype=np.float32)
    def f_jax(x, y):
      return jnp.matmul(x, y)
    res, pullback = jax.vjp(f_jax, x, y)
    with self.assertRaisesRegex(
        ValueError,
        "Shape of cotangent input to vjp pullback function .* must be the same as the shape of corresponding primal input .*"):
      pullback(np.ones((2, 4), dtype=np.float32))

  def test_jvp_jit_cached(self):
    """Bug in caching in presence of JVP and JIT."""

    def func(x):
      def inner(y):
        return y * x

      # Must have two calls to the inner jit (the second one hits the cache)
      res1 = api.jit(inner)(4.)
      res2 = api.jit(inner)(5.)
      return res1 + res2

    self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)))

  def test_linear_transpose_abstract(self):
    x = types.SimpleNamespace(shape=(3,), dtype=np.dtype(np.float32))
    y = jnp.arange(3, dtype=np.float32)
    transpose_fun = api.linear_transpose(lambda x: 2 * x, x)
    z, = transpose_fun(y)
    self.assertArraysEqual(2 * y, z, check_dtypes=True)

  def test_linear_transpose_integer(self):
    f = lambda x: 2 * x
    transpose = api.linear_transpose(f, 1)
    actual, = transpose(3)
    expected = 6
    self.assertEqual(actual, expected)

  def test_linear_transpose_error(self):
    with self.assertRaisesRegex(
        TypeError, "linear_transpose only supports"):
      api.linear_transpose(lambda x: 2. * x, 1)
    transpose_fun = api.linear_transpose(lambda x: [x, x], 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent tree does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: jnp.stack([x, x]), 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: 1j * x, 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: x, 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1j)

  def test_linear_transpose_complex(self):
    f = lambda x: (1 + 2j) * x
    transpose = api.linear_transpose(f, 1j)
    actual, = transpose(3 + 4j)
    expected = -5 + 10j
    self.assertEqual(actual, expected)

  def test_linear_transpose_zeros(self):
    f = lambda x: x[0]
    transpose = api.linear_transpose(f, [1., 2.])
    actual, = transpose(3.)
    expected = [3., 0.]
    self.assertEqual(actual, expected)

  def test_complex_grad_raises_error(self):
    self.assertRaises(TypeError, lambda: grad(lambda x: jnp.sin(x))(1 + 2j))

  def test_holomorphic_grad(self):
    out = grad(lambda x: jnp.sin(x), holomorphic=True)(1 + 2j)
    expected = 2.0327230070196656 - 3.0518977991518j
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_nonholomorphic_grad(self):
    zs = 0.5j * np.arange(5) + np.arange(5)

    def f(z):
      return jnp.sum(jnp.cos(jnp.abs(z)))

    ans = grad(f)(zs)
    expected = np.array([ 0.         + 0.j,
                         -0.80430663 + 0.40215331j,
                         -0.70368982 + 0.35184491j,
                          0.1886467  - 0.09432335j,
                          0.86873727 - 0.43436864j])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        atol=jtu.default_gradient_tolerance,
                        rtol=jtu.default_gradient_tolerance)

  def test_complex_output_jacrev_raises_error(self):
    self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j))

  def test_nonholomorphic_jacrev(self):
    # code based on https://github.com/google/jax/issues/603
    zs = 0.5j * np.arange(5) + np.arange(5)

    def f(z):
      return jnp.cos(jnp.linalg.norm(2 * z))

    ans = jacrev(f)(zs)
    expected = grad(f)(zs)
    self.assertAllClose(ans, expected)

  def test_complex_input_jacfwd_raises_error(self):
    self.assertRaises(TypeError, lambda: jacfwd(lambda x: jnp.sin(x))(1 + 2j))

  def test_legacy_devicearray_repr(self):
    dx = device_put(3.)
    str(dx.item())  # doesn't crash

  def test_devicearray_repr(self):
    x = device_put(jnp.zeros(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

    x = device_put(jnp.ones(3) + 1j * jnp.ones(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

  def test_devicearray_delete(self):
    x = device_put(1.)
    x.delete()
    self.assertRaisesRegex(RuntimeError, "DeviceArray has been deleted.",
                           lambda: repr(x))

  def test_devicearray_block_until_ready(self):
    x = device_put(1.)
    y = x.block_until_ready()
    # Tests mostly that block_until_ready() does not produce an error.
    self.assertTrue(y is x)

  def test_devicearray_weakref_friendly(self):
    x = device_put(1.)
    y = weakref.ref(x)
    self.assertEqual(y(), 1.)
    del x
    self.assertIsNone(y())

  def test_namedtuple_transparency(self):
    # See https://github.com/google/jax/issues/446
    Point = collections.namedtuple("Point", ["x", "y"])

    def f(pt):
      return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

    pt = Point(1., 2.)

    f(pt)  # doesn't crash
    g = api.grad(f)(pt)
    self.assertIsInstance(g, Point)

    f_jit = api.jit(f)
    self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)

  def test_namedtuple_subclass_transparency(self):
    # See https://github.com/google/jax/issues/806
    Point = collections.namedtuple("Point", ["x", "y"])

    class ZeroPoint(Point):
      def is_zero(self):
        return (self.x == 0) and (self.y == 0)

    pt = ZeroPoint(0., 0.)

    def f(pt):
      return 0. if pt.is_zero() else jnp.sqrt(pt.x ** 2 + pt.y ** 2)

    f(pt)  # doesn't crash
    _ = api.grad(f)(pt)
    self.assertIsInstance(pt, ZeroPoint)

  @parameterized.parameters(1, 2, 3)
  def test_shape_dtype_struct(self, i):
    s = api.ShapeDtypeStruct(shape=(i, 2, 3), dtype=jnp.float32)
    self.assertEqual(s.shape, (i, 2, 3))
    self.assertEqual(s.dtype, jnp.float32)
    self.assertEqual(s.ndim, 3)
    self.assertEqual(s.size, i * 2 * 3)
    self.assertLen(s, i)
    for f in (str, repr):
      self.assertEqual(
          f(s), "ShapeDtypeStruct(shape=({}, 2, 3), dtype=float32)".format(i))

  def test_shape_dtype_struct_scalar(self):
    s = api.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
    self.assertEmpty(s.shape)
    self.assertEqual(s.size, 1)
    self.assertEqual(s.ndim, 0)
    with self.assertRaisesRegex(TypeError, "len[(][)] of unsized object"):
      _ = len(s)

  def test_eval_shape(self):
    def fun(x, y):
      return jnp.tanh(jnp.dot(x, y) + 3.)

    x = jnp.ones((2, 3))
    y = jnp.ones((3, 4))
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_constants(self):
    def fun():
      x = jnp.ones((2, 3))
      y = jnp.ones((3, 4))
      return jnp.tanh(jnp.dot(x, y) + 3.)

    out_shape = api.eval_shape(fun)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_tuple_unpacking(self):
    def fun(x, y):
      a, b = x
      return a + b + y

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_tuple_itemgetting(self):
    def fun(x, y):
      return x[0] + x[1] + y

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_output_dict(self):
    def fun(x, y):
      return {'hi': x[0] + x[1] + y}

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)
    out_shape = tree_util.tree_map(np.shape, out_shape)

    self.assertEqual(out_shape, {'hi': (2,)})

  def test_eval_shape_shape_error(self):
    def fun(x, y):
      return jnp.tanh(jnp.dot(x, y) + 3.)

    x = jnp.ones((3, 3))
    y = jnp.ones((4, 4))

    self.assertRaises(TypeError, lambda: api.eval_shape(fun, x, y))

  def test_eval_shape_duck_typing(self):
    def fun(A, b, x):
      return jnp.dot(A, x) + b

    class MyArgArray(object):
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = np.dtype(dtype)

    A = MyArgArray((3, 4), jnp.float32)
    b = MyArgArray((5,), jnp.float32)
    x = MyArgArray((4, 5), jnp.float32)
    out_shape = api.eval_shape(fun, A, b, x)

    self.assertEqual(out_shape.shape, (3, 5))

  def test_eval_shape_duck_typing2(self):
    # https://github.com/google/jax/issues/5683
    class EasyDict(dict):
      def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

    x = EasyDict(shape=(3,), dtype=np.dtype('float32'))
    out_shape = api.eval_shape(lambda x: x, x)  # doesn't crash
    self.assertEqual(out_shape.shape, (3,))

  def test_eval_shape_names(self):
    def fun(x, y):
      return lax.psum(x, 'i') + y

    class MyArgArray(object):
      def __init__(self, shape, dtype, named_shape):
        self.shape = shape
        self.dtype = jnp.dtype(dtype)
        self.named_shape = named_shape

    x = MyArgArray((3, 2), jnp.float32, {'i': 10})
    y = MyArgArray((3, 2), jnp.float32, {'j': 5})
    with core.extend_axis_env('i', 10, None):
      with core.extend_axis_env('j', 5, None):
        out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.named_shape, {'j': 5})

  def test_issue_871(self):
    T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
    x = jnp.array([1, 2, 3])
    msg = ("linearized function called on tangent values inconsistent with "
           "the original primal values")

    y, f_jvp = api.linearize(jnp.sum, x)
    with self.assertRaisesRegex(ValueError, msg):
      f_jvp(T)

    y, f_jvp = api.linearize(api.jit(jnp.sum), x)
    with self.assertRaisesRegex(ValueError, msg):
      f_jvp(T)

  def test_partial_eval_lower(self):
    # this is a simplified model of a bug that arose when we first used @jit in
    # a jvp rule. it's in this file because we want to use make_jaxpr.

    # NOTE(mattjj): I no longer understand what this was meant to test. My guess
    # is it was related to staging out the broadcast into a jaxpr to be
    # transposed, but after #1749 that's no longer a problem. After changing
    # make_jaxpr (and jit) to stage out sub-calls fully, this test started to
    # fail; I left it in as skipped because deleting tests feels wrong.
    raise unittest.SkipTest("obsolete test")

    @api.jit
    def f(a, b, c):
      a = lax.broadcast(a, (2,))
      return lax.select(a, b, c)

    a = np.ones((3, 3), dtype=np.bool_)
    b = np.ones((2, 3, 3))
    c = np.ones((2, 3, 3))

    jaxpr = api.make_jaxpr(lambda b, c: f(a, b, c))(b, c)
    subjaxpr = next(eqn.params["call_jaxpr"] for eqn in jaxpr.jaxpr.eqns
                    if "call_jaxpr" in eqn.params)
    self.assertEqual(len(subjaxpr.eqns), 1)

  def test_grad_of_int_errors(self):
    # Errors without allow_int=True
    dfn = grad(lambda x: x ** 2)
    self.assertRaisesRegex(
        TypeError,
        (r"grad requires real- or complex-valued inputs \(input dtype that is a "
         r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
        lambda: dfn(3))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_jvp_of_int_identity(self):
    primals = (1,)
    tangents = (np.zeros(shape=(), dtype=float0),)

    _, out = api.jvp(lambda x: x, primals, tangents)
    self.assertEqual(out, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_jvp_of_int_add(self):
    primals = (2,)
    tangents = (np.zeros(shape=(), dtype=float0),)

    _, out_tangent = api.jvp(lambda x: x+1, primals, tangents)
    self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_jit_jvp_of_int(self):
    primals = (2,)
    tangents = (np.zeros(shape=(), dtype=float0),)

    _, out_tangent = api.jvp(jax.jit(lambda x: x+1), primals, tangents)
    self.assertEqual(out_tangent, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_vjp_of_int_index(self):
    primal, fn_vjp = api.vjp(lambda x, i: x[i], np.ones(2)*2, 1)
    tangent_x, tangent_i = fn_vjp(1.)
    self.assertEqual(primal, 2.)
    self.assertAllClose(tangent_x, jnp.array([0., 1.]))
    self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_vjp_of_int_shapes(self):
    out, fn_vjp = api.vjp(
        lambda x: lax.reshape(x, (2, 2)), np.ones((4, 1), dtype=int))
    tangent, = fn_vjp(out)
    self.assertArraysEqual(tangent, np.zeros(shape=(4, 1), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_jit_vjp_of_int(self):
    primal, fn_vjp = api.vjp(lambda x, y: x+y, 2, 1)
    tangent_x, tangent_i = jax.jit(fn_vjp)(1)
    self.assertEqual(primal, 3)
    self.assertEqual(tangent_x, np.zeros(shape=(), dtype=float0))
    self.assertEqual(tangent_i, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_vjp_of_int_fulllike(self):
    # Regression test for tangent and cotangent mismatch in convert_element_type
    # transpose rule wrt a ConstVar
    f = lax.full_like
    out, vjp = api.vjp(f, np.zeros((2, 2)), 1)
    self.assertAllClose(out, jnp.ones((2, 2)))
    tangent_x, tangent_y = vjp(out)
    self.assertAllClose(tangent_x, jnp.zeros((2, 2)))
    self.assertEqual(tangent_y, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_grad_of_int(self):
    # Need real-valued output, but testing integer input.
    out = api.grad(lambda x: x+0., allow_int=True)(1)
    self.assertEqual(out, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_grad_of_bool(self):
    def cond(pred):
      return lax.cond(pred, lambda _: 1., lambda _: 2., 1.)
    value, grd = api.value_and_grad(cond, allow_int=True)(True)
    self.assertEqual(value, 1.)
    self.assertEqual(grd, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_grad_of_int_index(self):
    grad_x, grad_i = api.grad(lambda x, i: x[i], argnums=(0, 1),
                              allow_int=True)(np.ones(2), 1)
    self.assertAllClose(grad_x, jnp.array([0., 1.]))
    self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_jit_grad_of_int(self):
    grad_f = api.grad(lambda x, i: x[i], argnums=(0, 1), allow_int=True)
    grad_x, grad_i = jax.jit(grad_f)(np.ones(2), 1)
    self.assertAllClose(grad_x, jnp.array([0., 1.]))
    self.assertEqual(grad_i, np.zeros(shape=(), dtype=float0))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_float0_reshape(self):
    # dtype-agnostic operations are supported
    float0_array = jax.grad(lambda x: jnp.sum(x+0.),
                            allow_int=True)(np.ones((2, 4), dtype=int))

    self.assertArraysEqual(float0_array.reshape((4, 2)),
                           np.zeros((4, 2), dtype=float0))
    self.assertArraysEqual(float0_array.transpose(),
                           np.zeros((4, 2), dtype=float0))

  def test_float0_error(self):
    # float0 is incompatible with other dtypes
    float0_array = jax.grad(lambda x: x+0., allow_int=True)(1)
    error_text = "float0s do not support any operations by design"

    with self.assertRaisesRegex(TypeError, error_text):
      # dispatch via DeviceArray
      _ = float0_array + jnp.zeros(())

    with self.assertRaisesRegex(TypeError, error_text):
      # dispatch via lax
      _ = lax.add(float0_array, jnp.zeros(()))

  def test_grad_complex_result_errors(self):
    dfn = grad(lambda x: x ** 2 + 1j)
    self.assertRaisesRegex(
        TypeError,
        (r"grad requires real-valued outputs \(output dtype that is a "
         r"sub-dtype of np.floating\), but got complex.*"),
        lambda: dfn(3.))

  def test_holomorphic_grad_of_float_errors(self):
    dfn = grad(lambda x: x ** 2, holomorphic=True)
    self.assertRaisesRegex(
        TypeError,
        (r"grad with holomorphic=True requires inputs with complex dtype, "
         r"but got float.*"),
        lambda: dfn(3.))

  def test_holomorphic_jacrev_of_float_errors(self):
    dfn = jacrev(lambda x: x ** 2, holomorphic=True)
    self.assertRaisesRegex(
        TypeError,
        (r"jacrev with holomorphic=True requires inputs with complex dtype, "
         r"but got float.*"),
        lambda: dfn(3.))

  def test_holomorphic_jacfwd_of_float_errors(self):
    dfn = jacfwd(lambda x: x ** 2, holomorphic=True)
    self.assertRaisesRegex(
        TypeError,
        (r"jacfwd with holomorphic=True requires inputs with complex dtype, "
         r"but got float.*"),
        lambda: dfn(3.))

  def test_jacfwd_of_complex_errors(self):
    dfn = jacfwd(lambda x: x ** 2)
    self.assertRaisesRegex(
        TypeError,
        (r"jacfwd requires real-valued inputs \(input dtype that is a "
         r"sub-dtype of np.floating\), but got complex.*"),
        lambda: dfn(3. + 1j))

  def test_xla_computation(self):
    # these tests basically check the examples in the xla_computation docstring

    def e(x):
      return jnp.sin(jnp.cos(x))
    c = api.xla_computation(e)(2.)
    self.assertIn('cosine', c.as_hlo_text())
    self.assertIn('sine', c.as_hlo_text())

    def f(x):
      return x - lax.psum(x, 'i')
    axis_env = [('i', 4)]
    c = api.xla_computation(f, axis_env=axis_env)(2)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())

    def g(x):
      rowsum = lax.psum(x, 'i')
      colsum = lax.psum(x, 'j')
      allsum = lax.psum(x, ('i', 'j'))
      return rowsum, colsum, allsum
    axis_env = [('i', 4), ('j', 2)]
    c = api.xla_computation(g, axis_env=axis_env)(5.)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())

    def h(x):
      rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
      colsum = lax.psum(x, 'j')
      return rowsum, colsum
    axis_env = [('i', 4), ('j', 2)]
    c = api.xla_computation(h, axis_env=axis_env)(5.)
    self.assertIn('all-reduce', c.as_hlo_text())
    self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())

  def test_xla_computation_args(self):
    def foo(x, y, z):
      return x + y + z

    c = api.xla_computation(foo)(1., 2., 3.)
    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.program_shape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xb.xla_client.PrimitiveType.TUPLE)

  def test_xla_computation_duck_typing(self):
    def foo(x, y, z):
      return x + y + z

    x = jax.ShapeDtypeStruct((), np.float32)
    y = jax.ShapeDtypeStruct((), np.float32)
    z = jax.ShapeDtypeStruct((), np.float32)

    c = api.xla_computation(foo)(x, y, z)
    self.assertEqual(len(c.program_shape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.program_shape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xb.xla_client.PrimitiveType.TUPLE)

  def test_staging_out_multi_replica(self):
    def f(x):
      return api.pmap(jnp.mean)(x)
    xla_comp = api.xla_computation(f)
    xla_comp(jnp.arange(8)).as_hlo_text()  # doesn't crash

  def test_xla_computation_instantiate_constant_outputs(self):
    def f():
      return jnp.zeros((3, 4))

    xla_comp = api.xla_computation(f)()
    out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
    self.assertEqual(out_shape.dimensions(), (3, 4))

  def test_xla_computation_static_argnums(self):
    def f(x, y):
      return x + y

    xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
    hlo_text = xla_comp.as_hlo_text()
    self.assertIn("constant(3)", hlo_text)
    # The static arguments should be removed from the function being compiled,
    # thus the function should have only a single argument.
    self.assertIn("parameter.1", hlo_text)
    self.assertNotIn("parameter.2", hlo_text)

  def test_xla_computation_return_shape(self):
    _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
                                        return_shape=True)(np.int32(1))
    expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
                api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
    self.assertEqual(shape_tree, expected)

def test_xla_computation_partitioned(self):
|
|
def f(x, y):
|
|
return jnp.dot(x, y) + 1
|
|
|
|
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
|
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
|
xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None),
|
|
out_parts=P(4, 1))(x, y)
|
|
hlo_text = xla_comp.as_hlo_text()
|
|
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
|
self.assertIn('sharding={replicated}', hlo_text)
|
|
self.assertIn('sharding={{devices=[4,1]0,1,2,3}}', hlo_text)
|
|
|
|
def test_xla_computation_replicated_and_partitioned(self):
|
|
def f(x, y):
|
|
return jnp.dot(x, y), lax.psum(x, 'i')
|
|
|
|
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
|
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
|
axis_env = [('i', 4)]
|
|
xla_comp = api.xla_computation(f, axis_env=axis_env,
|
|
in_parts=(P(2, 2), None),
|
|
out_parts=(P(4, 1), None))(x, y)
|
|
hlo_text = xla_comp.as_hlo_text()
|
|
self.assertIn('all-reduce', hlo_text)
|
|
self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)
|
|
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
|
self.assertIn('sharding={replicated}', hlo_text)
|
|
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)
|
|
|
|
def test_xla_computation_psum_constant(self):
|
|
f = lambda: jax.lax.psum(1, "i")
|
|
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
|
|
|
|
@jtu.skip_on_devices("cpu", "gpu")
|
|
@jtu.ignore_warning(message="Some donated buffers were not usable")
|
|
def test_xla_computation_donate_argnums(self):
|
|
api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash
|
|
|
|
def test_xla_computation_lower_fun_axis_env(self):
|
|
axis_name = 'i'
|
|
def fn(x):
|
|
y = lax.all_gather(
|
|
x, axis_name=axis_name)
|
|
return y * lax.axis_index(axis_name).astype(jnp.float32)
|
|
|
|
input_x = jnp.ones((5,6,4))
|
|
axis_env = [(axis_name, api.local_device_count())]
|
|
_ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x)
|
|
|
|
def test_concurrent_device_get_and_put(self):
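    # Exercises thread safety: concurrent device_put/device_get round-trips
    # from a thread pool should neither crash nor corrupt the values.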
    def f(x):
      for _ in range(100):
        y = jax.device_put(x)
        x = jax.device_get(y)
      return x

    xs = [np.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [f.result() for f in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x, y)

  def test_dtype_warning(self):
    # cf. issue #1230
    if config.x64_enabled:
      raise unittest.SkipTest("test only applies when x64 is disabled")

    def check_warning(warn, nowarn):
      with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        nowarn()  # get rid of extra startup warning

        prev_len = len(w)
        nowarn()
        assert len(w) == prev_len

        warn()
        assert len(w) > 0
        msg = str(w[-1].message)
        expected_prefix = "Explicitly requested dtype "
        self.assertEqual(expected_prefix, msg[:len(expected_prefix)])

        prev_len = len(w)
        nowarn()
        assert len(w) == prev_len

    check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
                  lambda: jnp.array([1, 2, 3], dtype="float32"))
    check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
                  lambda: jnp.array([1, 2, 3], dtype=float))
    check_warning(lambda: jnp.ones(3, dtype=np.float64),
                  lambda: jnp.ones(3))
    check_warning(lambda: jnp.ones(3, dtype=np.float64),
                  lambda: jnp.ones(3, dtype=float))
    check_warning(lambda: jnp.ones_like(3, dtype=np.int64),
                  lambda: jnp.ones_like(3, dtype=np.int32))
    check_warning(lambda: jnp.zeros(3, dtype="int64"),
                  lambda: jnp.zeros(3, dtype="int32"))
    check_warning(lambda: jnp.zeros_like(3, dtype="float64"),
                  lambda: jnp.zeros_like(3, dtype="float32"))
    check_warning(lambda: jnp.full((2, 3), 1, dtype="int64"),
                  lambda: jnp.full((2, 3), 1))
    check_warning(lambda: jnp.ones(3).astype("float64"),
                  lambda: jnp.ones(3).astype("float32"))
    check_warning(lambda: jnp.eye(3, dtype=np.float64),
                  lambda: jnp.eye(3))
    check_warning(lambda: jnp.arange(3, dtype=np.float64),
                  lambda: jnp.arange(3, dtype=np.float32))
    check_warning(lambda: jnp.linspace(0, 3, dtype=np.float64),
                  lambda: jnp.linspace(0, 3, dtype=np.float32))
    check_warning(lambda: jnp.tri(2, dtype="float64"),
                  lambda: jnp.tri(2, dtype="float32"))
    check_warning(lambda: jnp.arange(1).astype("float64"),
                  lambda: jnp.arange(1).astype(float))
    check_warning(lambda: jnp.arange(1.0).astype("int64"),
                  lambda: jnp.arange(1.0).astype(int))

  def test_error_for_invalid_dtype(self):
    with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"):
      lax.add(jnp.array(7), np.array("hello"))

  def test_vmap_preserves_docstr(self):
    def superfun(a):
      """Does things with stuff."""
      pass

    self.assertRegex(api.vmap(superfun).__doc__, "\n".join([
        "Vectorized version of superfun.*",
        "",
        "Original documentation:",
        "",
        superfun.__doc__,
    ]))

  def test_vmap_in_axes_list(self):
    # https://github.com/google/jax/issues/2367
    dictionary = {'a': 5., 'b': jnp.ones(2)}
    x = jnp.zeros(3)
    y = jnp.arange(3.)

    def f(dct, x, y):
      return dct['a'] + dct['b'] + x + y

    out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
    out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
    self.assertAllClose(out1, out2)

  def test_vmap_in_axes_tree_prefix_error(self):
    # https://github.com/google/jax/issues/795
    value_tree = jnp.ones(3)
    self.assertRaisesRegex(
        ValueError,
        "vmap in_axes specification must be a tree prefix of the corresponding "
        r"value, got specification \(0, 0\) for value tree "
        + re.escape(f"{tree_util.tree_structure((value_tree,))}."),
        lambda: api.vmap(lambda x: x, in_axes=(0, 0))(value_tree)
    )

  def test_vmap_in_axes_leaf_types(self):
    with self.assertRaisesRegex(
        TypeError, r"vmap in_axes must be an int, None, or .*"):
      api.vmap(lambda x: x, in_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.]))

  def test_vmap_out_axes_leaf_types(self):
    with self.assertRaisesRegex(
        TypeError, r"vmap out_axes must be an int, None, or .*"):
      api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.]))

  def test_vmap_unbatched_object_passthrough_issue_183(self):
    # https://github.com/google/jax/issues/183
    fun = lambda f, x: f(x)
    vfun = api.vmap(fun, (None, 0))
    ans = vfun(lambda x: x + 1, jnp.arange(3))
    self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)

  def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
    # https://github.com/google/jax/issues/705
    def h(a, b):
      return jnp.sum(a) + jnp.sum(b)

    X = np.random.randn(10, 4)
    U = np.random.randn(10, 2)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
        "so\n"
        "arg 0 has an axis to be mapped of size 10\n"
        "arg 1 has an axis to be mapped of size 2"):
      api.vmap(h, in_axes=(0, 1))(X, U)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
        r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        "so\n"
        "args 0, 2 have axes to be mapped of size 10\n"
        "arg 1 has an axis to be mapped of size 2"):
      api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        "the tree of axis sizes is:\n"
        r"\(10, \[2, 2\]\)"):
      api.vmap(h, in_axes=(0, 1))(X, [U, U])

    error = (r"vmap was requested to map its argument along axis 0, which "
             r"implies that its rank should be at least 1, but is only 0 "
             r"\(its shape is \(\)\)")
    with self.assertRaisesRegex(ValueError, error):
      # The mapped inputs cannot be scalars
      api.vmap(lambda x: x)(1.)

    with self.assertRaisesRegex(
        ValueError, "vmap must have at least one non-None value in in_axes"):
      # If the output is mapped, there must be a non-None in_axes
      api.vmap(lambda x: x, in_axes=None)(jnp.array([1., 2.]))

    error = (r"vmap was requested to map its argument along axis 1, which "
             r"implies that its rank should be at least 2, but is only 1 "
             r"\(its shape is \(2,\)\)")
    with self.assertRaisesRegex(ValueError, error):
      api.vmap(lambda x: x, in_axes=1)(jnp.array([1., 2.]))

    # Error is: TypeError: only integer scalar arrays can be converted to a scalar index
    with self.assertRaisesRegex(
        ValueError,
        "vmap out_axes specification must be a tree prefix of the "
        "corresponding value.*"):
      api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(jnp.array([1., 2.]))

    with self.assertRaisesRegex(
        ValueError,
        r"vmap has mapped output \(axis_name=foo\) but out_axes is None"):
      # If the output is mapped (user-named axis), then there must be some
      # out_axes specified.
      api.vmap(lambda x: x, out_axes=None, axis_name="foo")(jnp.array([1., 2.]))

    with self.assertRaisesRegex(
        ValueError,
        "vmap has mapped output but out_axes is None"):
      # If the output is mapped (unnamed axis), then there must be some out_axes
      # specified.
      api.vmap(lambda x: x, out_axes=None)(jnp.array([1., 2.]))

  def test_vmap_structured_in_axes(self):

    A, B, C, D = 2, 3, 4, 5
    K = 6  # batch size
    x = np.ones((K, A, B))  # batch axis in different locations
    y = np.ones((B, K, C))
    z = np.ones((C, D, K))

    def foo(tree_arg):
      x, (y, z) = tree_arg
      return jnp.dot(x, jnp.dot(y, z))

    tree = (x, (y, z))
    vfoo = api.vmap(foo, in_axes=((0, (1, 2)),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    Point = collections.namedtuple("Point", ["x", "y"])
    tree = (x, Point(y, z))
    vfoo = api.vmap(foo, in_axes=((0, Point(1, 2)),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    def foo(tree_arg):
      x, dct = tree_arg
      y, z = dct['a'], dct['b']
      return jnp.dot(x, jnp.dot(y, z))

    tree = (x, {'a': y, 'b': z})
    vfoo = api.vmap(foo, in_axes=((0, {'a': 1, 'b': 2}),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    tree = (x, collections.OrderedDict([('a', y), ('b', z)]))
    vfoo = api.vmap(
        foo, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

  def test_vmap_in_axes_bool_error(self):
    # https://github.com/google/jax/issues/6372
    with self.assertRaisesRegex(TypeError, "must be an int"):
      api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3))

  def test_pmap_in_axes_bool_error(self):
    # https://github.com/google/jax/issues/6372
    with self.assertRaisesRegex(TypeError, "must be an int"):
      api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1))

def test_pmap_global_cache(self):
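    # Repeated pmap calls with identical functions and parameters should
    # compile exactly once; later calls hit the global executable cache.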
    def f(x, y):
      return x, y

    x = np.ones((1, 1, 1))

    # All defaults
    with jtu.assert_num_jit_and_pmap_compilations(1):
      for _ in range(2):
        api.pmap(f)(x, x)

    # With axis name
    with jtu.assert_num_jit_and_pmap_compilations(1):
      for _ in range(2):
        api.pmap(f, 'i')(x, x)

    # With in_axes and out_axes
    for x_in, y_in, x_out, y_out in it.product(*((0, 1, 2) for _ in range(4))):
      with jtu.assert_num_jit_and_pmap_compilations(1):
        for _ in range(2):
          api.pmap(f, 'i', in_axes=(x_in, y_in), out_axes=(x_out, y_out))(x, x)

    # Forward-mode AD on the outside
    with jtu.assert_num_jit_and_pmap_compilations(1):
      for _ in range(2):
        api.jvp(api.pmap(f), (x, x), (x, x))

    # Reverse-mode AD on the outside. One compilation for forward, one for backward.
    with jtu.assert_num_jit_and_pmap_compilations(2):
      for _ in range(2):
        api.vjp(api.pmap(f), x, x)[1]((x, x))

  def test_device_array_repr(self):
    rep = jnp.ones(()) + 1.
    self.assertStartsWith(repr(rep), "DeviceArray")

  def test_device_array_hash(self):
    rep = jnp.ones(()) + 1.
    self.assertIsInstance(rep, jax.interpreters.xla.DeviceArray)
    self.assertNotIsInstance(rep, collections.abc.Hashable)
    with self.assertRaisesRegex(TypeError, 'unhashable type'):
      hash(rep)

  def test_grad_without_enough_args_error_message(self):
    # https://github.com/google/jax/issues/1696
    def f(x, y): return x + y
    df = api.grad(f, argnums=0)
    self.assertRaisesRegex(
        TypeError,
        "differentiating with respect to argnums=0 requires at least 1 "
        "positional arguments to be passed by the caller, but got only 0 "
        "positional arguments.",
        lambda: partial(df, x=0.)(y=1.))

  def test_grad_of_jit_compilation_caching(self):
    if not hasattr(self, "assertLogs"):
      raise unittest.SkipTest("test requires assertLogs (python 3)")

    lax.add(1, 2)  # make sure some initial warnings are already printed

    sin = api.jit(jnp.sin)

    prev_level = logging.get_verbosity()
    try:
      logging.set_verbosity('DEBUG')
      with self.assertLogs(level=logging.DEBUG) as l:
        ans1 = api.grad(sin)(2.)
        ans2 = api.grad(sin)(3.)
    finally:
      logging.set_verbosity(prev_level)
    self.assertLen(l.output, 2)

    self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
    self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)

def test_trivial_computations(self):
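    # jit of an identity-like function should pass input buffers straight
    # through: the assertIs checks demand the very same objects, not copies.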
    x = jnp.array([1, 2, 3])
    y = api.jit(lambda x: x)(x)
    self.assertIs(x, y)

    z1, z2 = api.jit(lambda x: (x, x))(x)
    self.assertIs(z1, z2)

    x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
    z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
    self.assertIs(z1, x2)
    self.assertIs(z3, x1)
    self.assertEqual(z2, 1)

  def test_nested_jit_hoisting(self):
    @api.jit
    def f(x, y):
      z = 2 * x
      return y + z, 3

    @api.jit
    def g(x):
      return f(2, x)

    jaxpr_subcomp = xla.jaxpr_subcomp

    jaxprs = []
    def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
      jaxprs.append(jaxpr)
      return jaxpr_subcomp(c, jaxpr, *args, **kwargs)

    try:
      xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
      ans = g(3)
    finally:
      xla.jaxpr_subcomp = jaxpr_subcomp

    self.assertEqual(ans, (7, 3))
    self.assertLen(jaxprs, 2)
    outer_jaxpr, inner_jaxpr = jaxprs

    self.assertLen(outer_jaxpr.eqns, 1)
    self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
    subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
    self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
    self.assertLen(inner_jaxpr.eqns, 2)
    self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
    self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')

def test_primitive_compilation_cache(self):
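    # A second application of the same primitive with identical shapes and
    # dtypes should hit the compilation cache, so only one compile happens.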
    with jtu.count_primitive_compiles() as count:
      lax.add(1, 2)
      lax.add(2, 3)
    self.assertEqual(count[0], 1)

  def test_arange_jit(self):
    # see https://github.com/google/jax/issues/553
    def fun(x):
      r = jnp.arange(x.shape[0])[x]
      return r

    jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32))  # doesn't crash

  def helper_save_tracer(self, x):
    self._saved_tracer = x
    return x

  def test_escaped_tracers_different_top_level_traces(self):
    api.jit(self.helper_save_tracer)(0.)
    with self.assertRaisesRegex(
        UnexpectedTracerError, "Encountered an unexpected tracer"):
      api.jit(lambda x: self._saved_tracer)(0.)

  def test_escaped_tracers_cant_lift_sublevels(self):
    api.jit(self.helper_save_tracer)(0.)
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        re.compile(
            "Encountered an unexpected tracer",
            re.DOTALL)):
      api.jit(lambda x: x)(self._saved_tracer)

  def test_escaped_tracers_tracer_from_higher_level(self):
    api.grad(self.helper_save_tracer)(0.)
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        re.compile(
            "Encountered an unexpected tracer.*Tracer from a higher level",
            re.DOTALL)):
      api.grad(lambda x: x)(self._saved_tracer)

  def test_escaped_tracers_incompatible_sublevel(self):
    def func1(x):
      api.jit(self.helper_save_tracer)(0.)
      # Use the tracer
      return x + self._saved_tracer
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        re.compile("Encountered an unexpected tracer",
                   re.DOTALL)):
      api.jit(func1)(2.)

  def test_escaped_tracers_cant_lift(self):
    def func1(x):
      api.grad(self.helper_save_tracer)(0.)
      return x + self._saved_tracer
    with self.assertRaisesRegex(
        UnexpectedTracerError,
        re.compile("Encountered an unexpected tracer.*Can't lift",
                   re.DOTALL)):
      api.grad(func1)(2.)

  def test_escaped_tracers_not_among_input_tracers(self):
    def func1(x):
      api.grad(self.helper_save_tracer)(x)
      # Use the tracer
      return x + self._saved_tracer

    with self.assertRaisesRegex(
        UnexpectedTracerError,
        re.compile(
            "Encountered an unexpected tracer.*Tracer not among input tracers",
            re.DOTALL)):
      api.jit(func1)(2.)

  def test_escaped_tracer_omnistaging(self):
    count = 1

    @jit
    def f():
      nonlocal count
      count = jnp.add(count, 1)
    f()  # leaked a tracer! but currently undetected

    def f(x, c):
      jnp.add(count, 1)
      return None, None

    @jit
    def g():
      lax.scan(f, None, None, length=2)

    with self.assertRaisesRegex(UnexpectedTracerError,
                                "was created on line"):
      g()

  def test_escaped_tracer_omnistaging_top_trace(self):
    count = 1

    def f(_, __):
      nonlocal count
      count = jnp.add(count, 1)
      return None, None

    lax.scan(f, None, None, length=2)  # leaked a tracer! (of level 1!)

    with self.assertRaisesRegex(UnexpectedTracerError,
                                "was created on line"):
      # The following call will try to raise the ones array to the count
      # tracer's level, which is no longer live.
      jax.jit(jnp.add)(jnp.ones(()), count)

def test_escaped_tracer_transform_name(self):
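    # The leak error message should name the transformation under which the
    # escaped tracer was created (jit, pmap, or eval_shape below).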
    with self.assertRaisesRegex(UnexpectedTracerError,
                                "for jit"):
      jax.jit(self.helper_save_tracer)(1)
      _ = self._saved_tracer+1

    with self.assertRaisesRegex(UnexpectedTracerError,
                                "for pmap"):
      jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
      _ = self._saved_tracer+1

    with self.assertRaisesRegex(UnexpectedTracerError,
                                "for eval_shape"):
      jax.eval_shape(self.helper_save_tracer, 1)
      _ = self._saved_tracer+1

  def test_escaped_tracer_shape_dtype(self):
    with self.assertRaisesRegex(core.UnexpectedTracerError,
                                r"shape \(4, 3\) and dtype int32"):
      jax.jit(self.helper_save_tracer)(jnp.ones((4, 3), dtype=jnp.int32))
      _ = self._saved_tracer+1

  def test_pmap_static_kwarg_error_message(self):
    # https://github.com/google/jax/issues/3007
    def f(a, b):
      return a + b

    g = jax.pmap(f, static_broadcasted_argnums=(1,))

    msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
           r"called with only 1 positional argument. All static broadcasted "
           r"arguments must be passed positionally.")
    with self.assertRaisesRegex(ValueError, msg):
      g(jnp.ones((1, 1)), b=1)

  def test_vmap_unmapped_last(self):
    @partial(jax.vmap, out_axes=-1)
    def f(x):
      return np.zeros((2,))
    f(np.zeros((5,)))

  # TODO(jakevdp): re-enable this if possible.
  @unittest.skipIf(True, "broken by convert_element_type change.")
  def test_xla_constant_dedup(self):
    y = np.array([7, 14], dtype=np.float32)
    def f(x):
      return x + y + y

    x = np.array([1, 2], dtype=np.float32)
    hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
    hlo_lines = set([s.strip() for s in hlo_lines])
    self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
    self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)

  def test_eval_context(self):
    @jit
    def f():
      with core.eval_context():
        assert jnp.add(1, 1) == 2

    f()  # doesn't crash

def test_concrete_error_because_arg_unary(self):
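    # Python control flow on a traced value raises a concretization error;
    # the message should name the offending argument ('x' here).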
    @jax.jit
    def f(x):
      if x > 0:
        return x
      else:
        return 0

    msg = r"on the value of the argument 'x'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1)

  def test_concrete_error_because_arg_binary(self):
    @jax.jit
    def f(x, y):
      if x > y:
        return x
      else:
        return y

    msg = r"on the values of the arguments 'x' and 'y'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2)

  def test_concrete_error_because_arg_ternary(self):
    @jax.jit
    def f(x, y, z):
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the arguments 'x' and 'z'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, 3)

    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, z=3)

    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, y=2, z=3)

  def test_concrete_error_because_arg_varargs(self):
    @jax.jit
    def f(*args):
      x, y, z = args
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the argument 'args'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(1, 2, 3)

  def test_concrete_error_because_arg_kwargs(self):
    @jax.jit
    def f(**kwargs):
      x, y, z = kwargs['x'], kwargs['y'], kwargs['z']
      if x > z:
        return x
      else:
        return y

    msg = r"on the values of the argument 'kwargs'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f(x=1, y=2, z=3)

  def test_concrete_error_because_arg_pytree(self):
    @jax.jit
    def f(xy, z):
      x, y = xy
      if x > 0:
        return x
      else:
        return y

    msg = r"on the value of the argument 'xy'"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f((1, 2), z=3)

  def test_concrete_error_because_const(self):
    @jax.jit
    def f():
      assert jnp.add(1, 1) > 0

    msg = "on these lines"
    with self.assertRaisesRegex(core.ConcretizationTypeError, msg):
      f()

  def test_xla_computation_zeros_doesnt_device_put(self):
    with jtu.count_device_put() as count:
      api.xla_computation(lambda: jnp.zeros(3))()
    self.assertEqual(count[0], 0)

  def test_join_concrete_arrays_with_omnistaging(self):
    # https://github.com/google/jax/issues/4622
    x = jnp.array([1., 2., 3.])
    y = jnp.array([1., 2., 4.])

    @jit
    def f():
      core.lattice_join(core.ConcreteArray(x), core.ConcreteArray(y))

    f()  # doesn't crash

  def test_linearize_aval_error(self):
    # https://github.com/google/jax/issues/4622
    f = lambda x: x

    # these should not error
    _, f_jvp = api.linearize(f, 1.)
    f_jvp(1.)
    _, f_jvp = api.linearize(f, np.ones(2, np.int32))
    f_jvp(np.zeros(2, float0))

    # these should error
    _, f_jvp = api.linearize(f, 1.)
    with self.assertRaisesRegex(ValueError, "tangent values inconsistent"):
      f_jvp(1)
    _, f_jvp = api.linearize(f, np.ones(2, np.int32))
    with self.assertRaisesRegex(ValueError, "tangent values inconsistent"):
      f_jvp(np.ones(2, np.int32))

  def test_grad_of_token_consuming_primitive(self):
    # https://github.com/google/jax/issues/5463
    tokentest_p = core.Primitive("tokentest")
    tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p))
    tokentest_p.def_abstract_eval(lambda x, y: x)
    xla.translations[tokentest_p] = lambda c, x, y: x
    ad.defjvp(tokentest_p, (lambda g, x, token: x), None)

    token = jax.lax.create_token(123)
    arr = jnp.ones((3, 2))
    res, vjp_fun = jax.vjp(lambda x: tokentest_p.bind(x, token), arr)
    # Should not crash.
    vjp_fun(arr)

  def test_jit_returning_token(self):
    x = jax.jit(jax.lax.create_token)(1.0)
    self.assertIsInstance(x, jax.interpreters.xla.Token)

def test_leak_checker_catches_a_jit_leak(self):
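    # Appending a tracer to a list that outlives the trace is a leak; the
    # checker should report it when the jitted function runs.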
    with jax.checking_leaks():
      lst = []

      @jit
      def f(x):
        lst.append(x)
        return x

      with self.assertRaisesRegex(Exception, r"Leaked"):
        f(3)

  def test_leak_checker_catches_a_pmap_leak(self):
    with jax.checking_leaks():
      lst = []

      @api.pmap
      def f(x):
        lst.append(x)
        return x

      with self.assertRaisesRegex(Exception, r"Leaked"):
        f(np.ones(1))

  def test_leak_checker_catches_a_grad_leak(self):
    with jax.checking_leaks():
      lst = []

      def f(x):
        lst.append(x)
        return x

      with self.assertRaisesRegex(Exception, r"Leaked trace"):
        api.grad(f)(3.)

  def test_leak_checker_avoids_false_positives(self):
    with jax.checking_leaks():
      @jit
      def f(x):
        return x
      f(3)  # doesn't crash
      api.vmap(f)(np.arange(3))  # doesn't crash
      api.grad(f)(3.)  # doesn't crash

      @api.pmap
      def f(x):
        return x
      f(np.ones(1))  # doesn't crash
      api.vmap(f)(np.ones((1, 1)))  # doesn't crash

  def test_leak_checker_catches_a_scan_leak(self):
    with jax.checking_leaks():
      lst = []

      to_scan = lambda c, x: (lst.append(c) or jnp.sin(c), None)

      with self.assertRaisesRegex(Exception, r"Leaked trace"):
        lax.scan(to_scan, 1., np.arange(3.))

  def test_leak_checker_avoids_false_positives_scan(self):
    with jax.checking_leaks():
      to_scan = lambda c, x: (jnp.sin(c), None)
      lax.scan(to_scan, 1., np.arange(3.))  # doesn't crash

  def test_leak_checker_avoids_false_positives_scan_jvp(self):
    with jax.checking_leaks():
      to_scan = lambda c, x: (c, None)

      def f(x):
        lax.scan(to_scan, x, None, length=1)
      api.jvp(f, (3.,), (1.,))  # doesn't crash

  def test_leak_checker_avoids_false_positives_scan_vmap(self):
    with jax.checking_leaks():
      to_scan = lambda c, _: (1., None)

      @api.vmap
      def f(x):
        lax.scan(to_scan, x, None, length=1)
      f(np.arange(5.))  # doesn't crash

  def test_leak_checker_avoids_false_positives_scan_vmap_2(self):
    with jax.checking_leaks():
      to_scan = lambda c, _: (c, None)

      @api.vmap
      def f(x):
        lax.scan(to_scan, x, None, length=1)
      f(np.arange(5.))  # doesn't crash

  def test_leak_checker_catches_a_sublevel_leak(self):
    with jax.checking_leaks():
      @jit
      def f(x):
        lst = []
        @jit
        def g(x):
          lst.append(x)
          return x

        x = g(x)
        return x

      with self.assertRaisesRegex(Exception, r"Leaked sublevel"):
        f(3)

  def test_leak_checker_avoids_false_positive_custom_jvp(self):
    # see https://github.com/google/jax/issues/5636
    with jax.checking_leaks():
      @api.custom_jvp
      def t(y):
        return y

      def t_jvp(p, t):
        pass

      t.defjvp(t_jvp)

      @jit
      def s(y):
        return t(y)
      s(3)  # doesn't crash

  def test_default_backend(self):
    first_local_device = api.local_devices()[0]
    self.assertEqual(first_local_device.platform, api.default_backend())

  def test_dunder_jax_array(self):
    # https://github.com/google/jax/pull/4725

    class AlexArray:
      def __init__(self, jax_val):
        self.jax_val = jax_val
      def __jax_array__(self):
        return self.jax_val
      dtype = property(lambda self: self.jax_val.dtype)
      shape = property(lambda self: self.jax_val.shape)

    x = AlexArray(jnp.array([1., 2., 3.]))
    y = jnp.sin(x)
    self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.])))
    y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x)
    self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.])))

    x = AlexArray(jnp.array([[1., 2., 3.]]))
    y = api.pmap(jnp.sin)(x)
    self.assertAllClose(y, jnp.sin(jnp.array([[1., 2., 3.]])))

    x = jnp.array(1)
    a = AlexArray(x)
    for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]:
      self.assertEqual(f(x), f(a))

  def test_constant_handler_mro(self):
    # https://github.com/google/jax/issues/6129

    class Foo(enum.IntEnum):
      bar = 1

    @api.pmap
    def f(_):
      return Foo.bar

    ans = f(jnp.arange(1))  # doesn't crash
    expected = jnp.arange(1) + 1
    self.assertAllClose(ans, expected)

  def test_large_python_ints(self):
    with self.assertRaises(OverflowError):
      jnp.multiply(2 ** 100, 3.)

    out = lax.convert_element_type(2 ** 100, jnp.float32)  # doesn't crash
    self.assertArraysEqual(out, np.float32(2 ** 100))

  def test_dot_precision_context_manager(self):
    x = jnp.zeros((2, 2))

    with jax.default_matmul_precision(None):
      jnp.dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
    self.assertIn('precision=None', str(jaxpr))

    with jax.default_matmul_precision("bfloat16"):
      x @ x  # doesn't crash
      jaxpr = jax.make_jaxpr(op.matmul)(x, x)
    self.assertIn('Precision.DEFAULT', str(jaxpr))

    with jax.default_matmul_precision("tensorfloat32"):
      jnp.dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
    self.assertIn('Precision.HIGH', str(jaxpr))

    with jax.default_matmul_precision("float32"):
      jnp.dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
    self.assertIn('Precision.HIGHEST', str(jaxpr))

    dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
    with jax.default_matmul_precision("tensorfloat32"):
      dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(dot)(x, x)
    self.assertIn('Precision.HIGHEST', str(jaxpr))

  def test_dot_precision_flag(self):
    x = jnp.zeros((2, 2))

    prev_val = config._read("jax_default_matmul_precision")
    try:
      config.FLAGS.jax_default_matmul_precision = "tensorfloat32"
      jnp.dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
    finally:
      config.FLAGS.jax_default_matmul_precision = prev_val
    self.assertIn('Precision.HIGH', str(jaxpr))
    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

    prev_val = config._read("jax_default_matmul_precision")
    try:
      config.update('jax_default_matmul_precision', 'tensorfloat32')
      jnp.dot(x, x)  # doesn't crash
      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
    finally:
      config.update('jax_default_matmul_precision', prev_val)
    self.assertIn('Precision.HIGH', str(jaxpr))
    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

def test_dot_precision_forces_retrace(self):
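    # The default matmul precision is part of the cache key: changing it via
    # the context manager or the flag must force a retrace.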
    num_traces = 0

    def g(x):
      nonlocal num_traces
      num_traces += 1
      return jnp.dot(x, x)
    def f_cond(x):
      return lax.cond(True, g, g, x)

    @jax.jit
    def f_jit(x):
      nonlocal num_traces
      num_traces += 1
      return jnp.dot(x, x)

    for f in [f_jit, f_cond]:
      precision = config.jax_default_matmul_precision
      try:
        num_traces = 0
        x = jnp.zeros((2, 2))
        f(x)
        self.assertEqual(num_traces, 1)
        f(x)
        self.assertEqual(num_traces, 1)
        with jax.default_matmul_precision("tensorfloat32"):
          f(x)
          self.assertEqual(num_traces, 2)
          FLAGS.jax_default_matmul_precision = "float32"
          f(x)
          self.assertGreaterEqual(num_traces, 2)
        nt = num_traces
        f(x)
        self.assertEqual(num_traces, nt + 1)
        f(x)
        self.assertEqual(num_traces, nt + 1)
      finally:
        FLAGS.jax_default_matmul_precision = precision

  def test_rank_promotion_forces_retrace(self):
    num_traces = 0

    def g(x):
      nonlocal num_traces
      num_traces += 1
      return x + x
    def f_cond(x):
      return lax.cond(True, g, g, x)

    @jax.jit
    def f_jit(x):
      nonlocal num_traces
      num_traces += 1
      return x + x

    for f in [f_jit, f_cond]:
      allow_promotion = config.jax_numpy_rank_promotion
      try:
        num_traces = 0
        @jax.jit
        def f(x):
          nonlocal num_traces
          num_traces += 1
          return x + x
        x = jnp.zeros((2, 2))
        f(x)
        self.assertEqual(num_traces, 1)
        f(x)
        self.assertEqual(num_traces, 1)
        with jax.numpy_rank_promotion("warn"):
          f(x)
          self.assertEqual(num_traces, 2)
          FLAGS.jax_numpy_rank_promotion = "raise"
          f(x)
          self.assertGreaterEqual(num_traces, 2)
        nt = num_traces
        f(x)
        self.assertEqual(num_traces, nt + 1)
        f(x)
        self.assertEqual(num_traces, nt + 1)
      finally:
        FLAGS.jax_numpy_rank_promotion = allow_promotion

def test_backward_pass_ref_dropping(self):
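    # Cotangents produced by earlier backward-pass steps should be released
    # before later steps run; the weakref assertion in f_rev enforces this.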
    refs = []

    @api.custom_vjp
    def f(x):
      return x
    def f_fwd(x):
      return x, None
    def f_rev(_, g):
      assert len(refs) != 2 or refs[0]() is None
      zero = np.zeros(())
      refs.append(weakref.ref(zero))
      return (zero,)
    f.defvjp(f_fwd, f_rev)

    api.grad(lambda x: f(f(f(x))))(1.)

  def test_custom_vjp_scan_batching_edge_case(self):
    # https://github.com/google/jax/issues/5832
    @jax.custom_vjp
    def mul(x, coeff): return x * coeff
    def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff)
    def mul_bwd(res, g):
      x, coeff = res
      g_x = g * coeff
      g_coeff = (x * g).sum()
      return g_x, g_coeff
    mul.defvjp(mul_fwd, mul_bwd)

    def scan_over_mul(x, coeff):
      def f_(x, t):
        return mul(x, coeff), None
      y, _ = jax.lax.scan(f_, x, jnp.arange(3))
      return y

    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key, 2)
    x_batch = jax.random.normal(key1, (3, 2))
    covector_batch = jax.random.normal(key2, (3, 2))
    coeff = jnp.array(1.)

    batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0)
    res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff)
    vjp_fun(covector_batch)  # doesn't crash

    jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2,
                    modes=['rev'])

  def test_jit_inline(self):
    @partial(api.jit, inline=False)
    def f(x):
      return x * 2

    jaxpr = api.make_jaxpr(f)(3)
    self.assertIn('xla_call', str(jaxpr))

    @partial(api.jit, inline=True)
    def f(x):
      return x * 2

    jaxpr = api.make_jaxpr(f)(3)
    self.assertNotIn('xla_call', str(jaxpr))

  # Repro for https://github.com/google/jax/issues/7229.
  def test_compute_with_large_transfer(self):
    def f(x, delta):
      return x + jnp.asarray(delta, x.dtype)

    # A large and potentially unaligned array to trigger non-zero-copy and
    # async device array copy.
    xs = np.random.uniform(0., 1., size=(10, 131, 111, 3)).astype(np.float32)
    for x in xs:
      delta = np.random.uniform(-0.5, 0.5, size=())
      jitted_f = api.jit(f)
      np.testing.assert_allclose(jitted_f(x, delta), f(x, delta))

  def test_vjp_fun_jit(self):
    # test that the function returned by vjp can be returned
    # from and passed to jitted functions
    f = lambda x: 2. * x

    @partial(jit, static_argnums=0)
    def linearize_vjp(f, x):
      _, vjp_fun = api.vjp(f, x)
      return vjp_fun

    linearized = linearize_vjp(f, 1.)
    actual = jit(lambda f, x: f(x))(linearized, 3.)
    expected = (6.,)
    self.assertEqual(actual, expected)

  def test_linearize_fun_jit(self):
    # test that the function returned by linearize can be returned
    # from and passed to jitted functions
    f = lambda x: 2. * x

    @partial(jit, static_argnums=0)
    def linearize(f, x):
      _, jvp_fun = api.linearize(f, x)
      return jvp_fun

    linearized = linearize(f, 1.)
    actual = jit(lambda f, x: f(x))(linearized, 3.)
    expected = 6.
    self.assertEqual(actual, expected)

  def test_linear_transpose_fun_jit(self):
    # test that the function returned by linear_transpose can be returned
    # from and passed to jitted functions
    f = lambda x: 2. * x

    @partial(jit, static_argnums=0)
    def transpose(f, x):
      return api.linear_transpose(f, x)

    transposed = transpose(f, 1.)
    actual = jit(lambda f, x: f(x))(transposed, 3.)
    expected = (6.,)
    self.assertEqual(actual, expected)


class RematTest(jtu.JaxTestCase):

def test_remat_basic(self):
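    # remat recomputes the forward pass during the backward pass; the impl
    # counters below check that running the linearized function re-runs one
    # sin and two cos applications.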
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      x, _ = g(x)
      return x

    ans = f(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans, f_lin = api.linearize(f, 2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = f_lin(3.)
    expected = np.cos(np.sin(2.)) * np.cos(2.) * 3.
    self.assertAllClose(ans, expected, check_dtypes=False)

    sin_calls = []
    cos_calls = []
    sin_impl = lax.sin_p.impl
    cos_impl = lax.cos_p.impl
    try:
      lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
      lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
      f_lin(3.)
    finally:
      lax.sin_p.def_impl(sin_impl)
      lax.cos_p.def_impl(cos_impl)
    self.assertEqual(len(sin_calls), 1)
    self.assertEqual(len(cos_calls), 2)

  def test_remat_freevars(self):
    def f1(x):
      y = 2 * jnp.sin(x)
      z = jnp.cos(x) * jnp.sin(y)
      return z

    def f2(x):
      y = 2 * jnp.sin(x)
      z = api.remat(lambda x: jnp.cos(x) * jnp.sin(y))(x)
      return z

    ans, f_lin = api.linearize(f2, 2.)
    expected, f_lin_expected = api.linearize(f1, 2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = f_lin(3.)
    expected = f_lin_expected(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_grad_python_control_flow(self):
    @partial(api.remat, concrete=True)
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      x, _ = g(x)
      return x

    ans = f(2.)
    expected = np.sin(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f)(2.)
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_jit(self):
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x))

    def f_(x):
      return g(x)
    f = api.jit(f_)

    ans = f(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f)(2.)
    expected = np.cos(np.sin(2.)) * np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(f_))(2.)
    expected = np.cos(np.sin(2.)) * np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_vmap(self):
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x))

    x = np.arange(3.)

    ans = api.vmap(g)(x)
    expected = np.sin(np.sin(x))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jacfwd(g)(x)
    expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jacrev(g)(x)
    expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_higher_order_autodiff(self):
    def f(x):
      return lax.cos(lax.sin(x))
    g = api.remat(f)

    ans = api.grad(api.grad(g))(3.)
    expected = api.grad(api.grad(f))(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_scan(self):
    to_scan = lambda c, x: (jnp.sin(c), None)

    def f_noremat(x):
      y, _ = lax.scan(to_scan, x, np.arange(3.))
      return y

    def f_yesremat(x):
      y, _ = lax.scan(api.remat(to_scan), x, np.arange(3.))
      return y

    ans = f_yesremat(4.)
    expected = f_noremat(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f_yesremat)(4.)
    expected = api.grad(f_noremat)(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
    scan_eqn, = jaxpr.jaxpr.eqns
    self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

    jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
    scan_eqn, = jaxpr.jaxpr.eqns
    self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

  def test_remat_no_redundant_flops(self):
    # see https://github.com/google/jax/pull/1749#issuecomment-558267584

    @api.jit
    def g(x):
      return f(2., x)

    @api.remat
    def f(x, y):
      return jnp.sin(x) * y

    # We swap out sin_p's impl rule to count how many times it's invoked
    called = []
    sin_impl = lax.sin_p.impl
    try:
      lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
      api.grad(g)(3.)
    finally:
      lax.sin_p.def_impl(sin_impl)
    num_calls = len(called)
    self.assertLessEqual(num_calls, 1)

def test_remat_binomial_checkpointing(self):
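    # Binomial checkpointing: recursively split the list of functions in half
    # and wrap each half in remat, so live residuals grow logarithmically
    # with the length of the composed chain.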
    def binom_checkpoint(funs):
      if len(funs) == 1:
        return funs[0]
      else:
        f1 = binom_checkpoint(funs[:len(funs)//2])
        f2 = binom_checkpoint(funs[len(funs)//2:])
        return api.remat(lambda x: f1(f2(x)))

    f1 = binom_checkpoint([jnp.sin, jnp.sin, jnp.sin, jnp.sin])
    f2 = lambda x: jnp.sin(jnp.sin(jnp.sin(jnp.sin(x))))
    x = 4.
    self.assertAllClose(f1(x), f2(x), check_dtypes=False)
    self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)

  def test_remat_symbolic_zeros(self):
    # code from https://github.com/google/jax/issues/1907

    key = jax.random.PRNGKey(0)
    key, split = jax.random.split(key)
    n = 5

    def func(D0):
      def shift(R, dR, **unused_kwargs):
        return R + dR

      def apply_fn(R):
        return D0 * R

      Rinit = jax.random.uniform(split, (n, 3), minval=0.0, maxval=5.0,
                                 dtype=jnp.float32)

      def move(R, i):
        F = apply_fn(R)
        return shift(R, 0.001 * F), jnp.array([0.])

      move = api.remat(move)
      R, temp = lax.scan(move, Rinit, jnp.arange(2))
      return R[0, 0]

    api.grad(func)(5.0)  # doesn't crash

  def test_remat_jit2(self):
    @api.jit
    def f(x):
      y = 2 * x

      @api.remat
      def g():
        return y

      return g()

    self.assertAllClose(f(3), 6, check_dtypes=False)

  def test_remat_nontrivial_env(self):
    # simplified from https://github.com/google/jax/issues/2030

    @api.remat
    def foo(state, dt=0.5, c=1):
      u, u_t = state
      u_tt = c**2 * u
      u_t = u_t + u_tt * dt
      return (u, u_t)

    @partial(api.jit, static_argnums=(1,))
    def _multi_step(state, count, dt, c):
      f = lambda s, _: (foo(s, dt, c), _)
      return lax.scan(f, state, None, count)

    def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):
      return _multi_step(state, count, dt, c)

    def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1):
      init = (u0, jnp.zeros_like(u0))
      (uf, _), _ = multi_step(init, steps, dt, c)
      return ((uf - target) ** 2).mean()

    target = jnp.zeros((128, 128))
    u0 = jnp.ones_like(target)
    loss(u0, target, 10)  # doesn't crash

  def test_remat_jit3(self):
    # https://github.com/google/jax/issues/2180
    def f(w, x):
      a = jnp.dot(x, w)
      b = jnp.einsum("btd,bTd->btT", a, a)
      c = jnp.einsum("btT,btd->btd", b, a)
      return jnp.sum(c)

    w = jnp.ones([1, 1])
    x = jnp.ones([1, 1, 1])
    f = api.remat(f)
    api.grad(f)(w, x)  # doesn't crash

    @api.jit
    def mul(a, b):
      return a * b

    def f(w, x):
      a = mul(w, x)
      b = mul(a, a)
      return b

    w = 1.
    x = 1.
    f = api.remat(f)
    api.grad(f)(w, x)  # doesn't crash

  def test_remat_scan2(self):
    # https://github.com/google/jax/issues/1963

    def scan_bug(x0):
      f = lambda x, _: (x + 1, None)
      def scanned_f(x, _):
        return lax.scan(f, x, xs=None, length=1)[0], None
      x, _ = jax.remat(scanned_f)(x0, None)
      return x

    jax.grad(scan_bug)(1.0)  # doesn't crash

  def test_remat_jit_static_argnum_omnistaging(self):
    # https://github.com/google/jax/issues/2833
    def named_call(f):
      def named_f(*args):
        f_ = lu.wrap_init(lambda: (f(*args),))
        out, = core.call_p.bind(f_)
        return out
      return named_f

    def f(a_bool, y):
      if a_bool:
        return y + 1
      else:
        return y

    api.jit(named_call(f), static_argnums=0)(True, 1)  # no crash

  def test_remat_eval_counter(self):
    # https://github.com/google/jax/issues/2737
    add_one_p = Primitive('add_one')
    add_one = add_one_p.bind

    num_evals = 0

    @contextmanager
    def assertEvals(n):
      start = num_evals
      yield
      assert num_evals - start == n

    def add_one_impl(x):
      nonlocal num_evals
      num_evals += 1
      return x + 1
    add_one_p.def_impl(add_one_impl)

    def add_one_jvp(pin, tin):
      pout = add_one(pin[0])
      return pout, pout * tin[0]
    ad.primitive_jvps[add_one_p] = add_one_jvp

    add_one_p.def_abstract_eval(lambda x: x)

    v = np.zeros((1,))

    f = jax.remat(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, 1 call made while transposing f
    with assertEvals(3):
      vjp(v)

    @jax._src.util.curry
    def call(f, *args):
      return jax.core.call(
          jax.linear_util.wrap_init(lambda *args: [f(*args)]),
          *args, name='foo')[0]

    f = call(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, no reevaluation for transposition of f
    with assertEvals(2):
      vjp(v)

  def test_escaped_tracer_remat(self):
    # b/169779185
    def f():
      seq = [jnp.zeros([])]
      def g():
        seq[0] += 1  # this is line 7 btw
        return seq[0]

      api.remat(g)()
      api.remat(g)()

    with self.assertRaisesRegex(UnexpectedTracerError, "global state"):
      api.jit(f)()

def test_no_cse_widget_on_primals(self):
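    # remat blocks XLA common-subexpression elimination on the backward pass
    # by wrapping the recomputation in a while/conditional "widget"; the
    # primal-only computation should stay free of it.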
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      x, _ = g(x)
      return x

    c = api.xla_computation(f)(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())

    c = api.xla_computation(grad(f))(2.)
    text = c.as_hlo_text()
    self.assertTrue('while' in text or 'conditional' in text)

  def test_no_cse_widget_with_prevent_cse_false(self):
    @partial(api.remat, prevent_cse=False)
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      x, _ = g(x)
      return x

    c = api.xla_computation(f)(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())

    c = api.xla_computation(grad(f))(2.)
    self.assertNotIn('while', c.as_hlo_text())
    self.assertNotIn('conditional', c.as_hlo_text())


class JaxprTest(jtu.JaxTestCase):

  def test_scalar_literals(self):
    jaxpr = api.make_jaxpr(lambda x: x + 2)(42)
    self.assertLen(jaxpr.jaxpr.constvars, 0)

  def test_abstract_inputs(self):
    jaxpr = api.make_jaxpr(lambda x: x + 2.)(
        types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)))
    self.assertEqual(jaxpr.in_avals[0].shape, ())
    self.assertEqual(jaxpr.in_avals[0].dtype, np.float32)

  def test_const(self):
    def fun(x):
      return (x, 1., np.zeros(1))

    expected = """
    { lambda a ; b.
      let
      in (b, 1.0, a) }
    """

    jaxpr = api.make_jaxpr(fun)(0.)
    self.assertMultiLineStrippedEqual(expected, str(jaxpr))

def test_cond(self):
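    # Both branches of lax.cond are staged into a single cond primitive; the
    # boolean predicate is first converted to an int32 branch index.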
    def f(x):
      return lax.cond(x >= 0.,
                      x + 1.,
                      lambda xt: xt + x,
                      x + 2.,
                      lambda xf: xf - x)
    expected = """
    { lambda ; a.
      let b = ge a 0.0
          c = add a 1.0
          d = add a 2.0
          e = convert_element_type[ new_dtype=int32
                                    weak_type=False ] b
          f = cond[ branches=( { lambda ; e_ a b c.
                                 let d = sub c a
                                 in (d,) }
                               { lambda ; a f_ b c.
                                 let d = add b a
                                 in (d,) } )
                    linear=(False, False, False, False) ] e a a c d
      in (f,) }
    """
    jaxpr = api.make_jaxpr(f)(3.)
    self.assertMultiLineStrippedEqual(expected, str(jaxpr))

  def test_make_jaxpr_static_argnums(self):
    def f(x, y):
      return x + y

    jaxpr = api.make_jaxpr(f, static_argnums=(1,))(2, 3)
    self.assertIn('3', str(jaxpr))

  def test_make_jaxpr_return_shape(self):
    _, shape_tree = api.make_jaxpr(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
                                   return_shape=True)(np.int32(1))
    expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
                api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
    self.assertEqual(shape_tree, expected)

  def test_make_jaxpr_axis_env(self):
    def f(x):
      return x - lax.psum(x, 'i')
    jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
    self.assertIn('psum', str(jaxpr))

  def test_make_jaxpr_named(self):
    def f(x):
      return x - lax.psum(x, 'i')

    x = api.ShapeDtypeStruct(
        shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10})
    jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x)
    named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars]
    self.assertEqual(named_shapes, [{'i': 10}, {}])

  @parameterized.parameters(True, False)
  def test_vjp_reduce_axes_jaxpr(self, gy_batched):
    def f(w, x):
      return jnp.sin(jnp.dot(x, w))

    w = api.ShapeDtypeStruct(
        shape=(3, 4), dtype=jnp.float32, named_shape={})
    x = api.ShapeDtypeStruct(
        shape=(3,), dtype=jnp.float32, named_shape={'batch': 2})
    gy = api.ShapeDtypeStruct(
        shape=(4,), dtype=jnp.float32,
        named_shape={'batch': 2} if gy_batched else {})

    # per-example
    jaxpr, shapes = api.make_jaxpr(
        lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)],
        return_shape=True)(w, x, gy)
    expected = (api.ShapeDtypeStruct(
        shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x)
    self.assertEqual(shapes, expected)
    self.assertNotIn('psum', str(jaxpr))

    # reduced
    jaxpr, shapes = api.make_jaxpr(
        lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy),
        axis_env=[('batch', 2)],
        return_shape=True)(w, x, gy)
    expected = (w, x)
    self.assertEqual(shapes, expected)
    self.assertIn('psum', str(jaxpr))


class CustomJVPTest(jtu.JaxTestCase):

def test_basic(self):
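    # The custom JVP rule deliberately reports 2 * cos(x) rather than the
    # true derivative cos(x), so the assertions can tell the rule was used.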
|
|
@api.custom_jvp
|
|
def f(x):
|
|
return jnp.sin(x)
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
return f(x), 2 * jnp.cos(x) * g
|
|
f.defjvp(f_jvp)
|
|
|
|
x = 3.
|
|
self.assertAllClose(f(x), jnp.sin(x))
|
|
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
|
(jnp.sin(x), 2 * jnp.cos(x)))
|
|
self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
|
|
|
|
def test_invariance(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
return jnp.cos(2 * x) / 2.
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
return (f(x), 3 * g)
|
|
f.defjvp(f_jvp)
|
|
def f2(x):
|
|
y, _ = api.jvp(f, (x,), (x,))
|
|
return y
|
|
def f3(x):
|
|
y, _ = api.jvp(f2, (x,), (x,))
|
|
return y
|
|
x = 1.
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
api.jvp(f2, (x,), (x,)),
|
|
check_dtypes=False)
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
api.jvp(f3, (x,), (x,)),
|
|
check_dtypes=False)
|
|
|
|
def test_python_control_flow(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
if x > 0:
|
|
return jnp.sin(x)
|
|
else:
|
|
return jnp.cos(x)
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
if x > 0:
|
|
return f(x), 2 * g
|
|
else:
|
|
return f(x), 3 * g
|
|
f.defjvp(f_jvp)
|
|
x = 2.
|
|
self.assertAllClose(f(x), jnp.sin(x))
|
|
self.assertAllClose(f(-x), jnp.cos(-x))
|
|
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
|
(jnp.sin(x), 2.),
|
|
check_dtypes=False)
|
|
self.assertAllClose(api.jvp(f, (-x,), (1.,)),
|
|
(jnp.cos(-x), 3.),
|
|
check_dtypes=False)
|
|
self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False)
|
|
self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False)
|
|
|
|
def test_vmap(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
assert jnp.ndim(x) == 0
|
|
return jnp.sin(x)
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
assert jnp.ndim(x) == jnp.ndim(g) == 0
|
|
return f(x), 2 * jnp.cos(x) * g
|
|
f.defjvp(f_jvp)
|
|
|
|
x = jnp.arange(3.)
|
|
xx = jnp.arange(6.).reshape(2, 3)
|
|
|
|
# vmap of f
|
|
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
|
|
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
|
|
|
|
# vmap of jvp of f
|
|
self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
|
|
(jnp.sin(x), 2 * jnp.cos(x) * x))
|
|
self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
|
|
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
|
|
|
# jvp of vmap of f
|
|
self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
|
|
(jnp.sin(x), 2 * jnp.cos(x) * x))
|
|
self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
|
|
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
|
|
|
# vmap of jvp of vmap of f
|
|
self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
|
|
(jnp.sin(xx), 2 * jnp.cos(xx) * xx))
|
|
|
|
def test_jit(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
return jnp.sin(x)
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
return f(x), 2 * jnp.cos(x) * g
|
|
f.defjvp(f_jvp)
|
|
|
|
x = 3.
|
|
|
|
# jit
|
|
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
|
|
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
|
|
|
|
# jit of jvp
|
|
self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
|
|
(jnp.sin(x), 2 * jnp.cos(x) * x),
|
|
check_dtypes=False)
|
|
|
|
# jvp of jit
|
|
self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)),
|
|
(jnp.sin(x), 2 * jnp.cos(x) * x),
|
|
check_dtypes=False)
|
|
|
|
def test_pytrees(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
return {'b': jnp.sin(x['a'])}
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
g, = tangents
|
|
return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']}
|
|
f.defjvp(f_jvp)
|
|
x = {'a': 3.}
|
|
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
({'b': jnp.sin(x['a'])},
|
|
{'b': 2 * jnp.cos(x['a']) * x['a']}),
|
|
check_dtypes=False)
|
|
|
|
def test_kwargs(self):
|
|
# from https://github.com/google/jax/issues/1938
|
|
@api.custom_jvp
|
|
def my_fun(x, y, c=1.):
|
|
return c * (x + y)
|
|
def my_jvp(primals, tangents):
|
|
x, y, c = primals
|
|
t_x, t_y, t_c = tangents
|
|
return my_fun(x, y, c), t_c
|
|
my_fun.defjvp(my_jvp)
|
|
f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
|
|
f(10., 5.) # doesn't crash
|
|
api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash
|
|
|
|
  def test_initial_style(self):
    @api.custom_jvp
    def f(x):
      return 3 * x
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * g
    f.defjvp(f_jvp)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(foo))(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(foo))(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.jit(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(api.grad(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(api.grad(foo)))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_initial_style_vmap(self):
    @api.custom_jvp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * g
    f.defjvp(f_jvp)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.vmap(api.jit(foo))(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.vmap(foo))(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracers_error_message(self):
    def f(x):
      @api.custom_jvp
      def g(y):
        return x + y
      def g_jvp(primals, tangents):
        return g(x), 2 * primals[0]
      g.defjvp(g_jvp)
      return g(1.)

    self.assertRaises(ad.CustomJVPException, lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaises(ad.CustomJVPException, lambda: api.grad(f)(3.))

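  # With nondiff_argnums, the marked arguments are passed to the JVP rule
  # positionally ahead of the (primals, tangents) pair and get no tangents of
  # their own.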
  def test_nondiff_arg(self):
    @partial(api.custom_jvp, nondiff_argnums=(0,))
    def app(f, x):
      return f(x)
    def app_jvp(f, primals, tangents):
      (x,), (t,) = primals, tangents
      return app(f, x), 3 * t
    app.defjvp(app_jvp)

    ans = app(lambda x: 2 * x, 1)
    expected = 2
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))
    expected = (2., 3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_jit_tracer(self):
    @partial(api.custom_jvp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_jvp(x, primals, tangents):
      (y,), (t_y,) = primals, tangents
      return f(x, y), 5 * t_y
    f.defjvp(f_jvp)

    @jit
    def g(x, y):
      return f(x, y)

    ans = api.jvp(lambda y: g(2., y), (3.,), (1.,))
    expected = (6., 5.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_hiding_jvp_tracer(self):
    def f(x):
      @partial(api.custom_jvp, nondiff_argnums=(0,))
      def g(h, x):
        return h(x)
      @g.defjvp
      def g_jvp(h, primals, tangents):
        x, = primals
        t, = tangents
        return g(h, x), 2. * t
      h = lambda y: x + y  # capture x
      return g(h, x)

    with self.assertRaisesRegex(ad.CustomJVPException, "Detected differentiation"):
      api.jvp(f, (2.,), (1.,))

  def test_vmap_axes(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_pmap(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_missing_jvp_rule_error_message(self):
    @api.custom_jvp
    def foo(x):
      return x ** 2

    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: foo(2))
    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: api.jvp(foo, (2.,), (1.,)))
    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: api.grad(foo)(2.))

  def test_jvp_rule_inconsistent_pytree_structures_error_message(self):
    @api.custom_jvp
    def f(x):
      return (x**2,)

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return f(x), [2 * x * t, x]

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce primal and tangent outputs "
            "with equal container (pytree) structures, but got "
            "{} and {} respectively.".format(
                tree_util.tree_structure((1,)),
                tree_util.tree_structure([1, 2]))
        ),
        lambda: api.jvp(f, (2.,), (1.,)))

  def test_primal_tangent_aval_disagreement_error_message(self):
    @api.custom_jvp
    def f(x):
      return x ** 2

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return f(x), jnp.reshape(t, (1,))

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce primal and tangent outputs "
            "with equal shapes and dtypes, but got float32[] and float32[1] "
            "respectively."),
        lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),)))

  def test_jvp_rule_doesnt_return_pair_error_message(self):
    # https://github.com/google/jax/issues/2516

    @api.custom_jvp
    def f(x):
      return x ** 2

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return t

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce a pair (list or tuple of length two) "
            "representing primal and tangent outputs, got 1.0"),
        lambda: api.jvp(f, (2.,), (1.,)))

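  # Each application of the custom function triggers its rule independently,
  # including the applications staged into the lax.scan body below.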
def test_multiple_rule_invocations(self):
|
|
@jax.custom_jvp
|
|
def expit(x):
|
|
return 1 / (1 + lax.exp(-x))
|
|
|
|
@expit.defjvp
|
|
def _expit_jvp(primals, tangents):
|
|
(x,), (t,) = primals, tangents
|
|
ans = expit(x)
|
|
t_out = t * ans * (1 - ans)
|
|
return ans, t_out
|
|
|
|
def scanned_fun(c, _):
|
|
return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
|
|
|
def foo(x):
|
|
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
|
return c[-1]
|
|
|
|
# just make sure these don't crash
|
|
foo(3.)
|
|
grad(foo)(3.)
|
|
grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.))
|
|
|
|
def test_hard_stuff(self):
|
|
arr = jnp.ones((5, 2, 2))
|
|
api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash
|
|
|
|
def test_hard_stuff2(self):
|
|
@jax.custom_jvp
|
|
def f(x):
|
|
return lax.tie_in(x, np.zeros(x.shape, x.dtype))
|
|
|
|
@f.defjvp
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
t, = tangents
|
|
return f(x), t
|
|
|
|
# don't crash
|
|
jax.jit(jax.vmap(f))(jnp.arange(3.))
|
|
jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
|
|
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
|
|
jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
|
|
jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))
|
|
|
|
def test_hard_stuff3(self):
|
|
@jax.custom_jvp
|
|
def relu(x):
|
|
return jnp.maximum(x, 0)
|
|
|
|
@relu.defjvp
|
|
def _relu_jvp(primals, tangents):
|
|
x, = primals
|
|
t, = tangents
|
|
return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
|
|
|
|
def scanned_fun(c, _):
|
|
return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
|
|
|
def f(x):
|
|
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
|
return c[-1]
|
|
|
|
# don't crash
|
|
jax.jit(jax.vmap(f))(jnp.arange(3.))
|
|
jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
|
|
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
|
|
jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
|
|
jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),))
|
|
|
|
def test_eval_shape(self):
|
|
@jax.custom_jvp
|
|
def expit(x):
|
|
return 1 / (1 + lax.exp(-x))
|
|
|
|
@expit.defjvp
|
|
def _expit_jvp(primals, tangents):
|
|
(x,), (t,) = primals, tangents
|
|
ans = expit(x)
|
|
t_out = t * ans * (1 - ans)
|
|
return ans, t_out
|
|
|
|
# don't crash
|
|
api.eval_shape(expit, jnp.ones((2, 3)))
|
|
api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3)))
|
|
|
|
def test_jaxpr_zeros(self):
|
|
# from https://github.com/google/jax/issues/2657
|
|
@api.custom_jvp
|
|
def f(A, b):
|
|
return A @ b
|
|
|
|
def f_jvp(primals, tangents):
|
|
A, b = primals
|
|
dA, db = tangents
|
|
z = f(A, b)
|
|
dz = A @ db + dA @ b
|
|
return z, dz
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
def experiment(theta):
|
|
def step(q, _):
|
|
z = f(jnp.eye(3), jnp.ones(3) * theta)
|
|
q += z[0]
|
|
return q, q
|
|
|
|
q = 0.
|
|
q, _ = lax.scan(step, q, None, 4)
|
|
return q
|
|
|
|
grad(experiment)(1.) # doesn't crash
|
|
|
|
def test_linear_in_scan(self):
|
|
@api.custom_jvp
|
|
def f(x):
|
|
return -x
|
|
|
|
@f.defjvp
|
|
def f_jvp(primals, tangents):
|
|
x, = primals
|
|
x_dot, = tangents
|
|
return f(x), f(x_dot)
|
|
|
|
def foo(x):
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
return out
|
|
|
|
ans = api.grad(foo)(3.)
|
|
expected = -1.
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def test_custom_jvps_first_rule_is_none(self):
|
|
# https://github.com/google/jax/issues/3389
|
|
@api.custom_jvp
|
|
def f(x, y):
|
|
return x ** 2 * y
|
|
|
|
f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot)
|
|
ans = grad(f, 1)(2., 3.) # doesn't crash
|
|
expected = 12.
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
def test_concurrent_initial_style(self):
|
|
# https://github.com/google/jax/issues/3843
|
|
def unroll(param, sequence):
|
|
def scan_f(prev_state, inputs):
|
|
return prev_state, jax.nn.sigmoid(param * inputs)
|
|
return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])
|
|
|
|
def run():
|
|
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
|
|
|
|
expected = run()
|
|
|
|
# we just don't want this to crash
|
|
n_workers = 2
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
|
|
futures = []
|
|
for _ in range(n_workers):
|
|
futures.append(e.submit(run))
|
|
results = [f.result() for f in futures]
|
|
for ans in results:
|
|
self.assertAllClose(ans, expected)
|
|
|
|
def test_nondiff_argnums_vmap_tracer(self):
|
|
# https://github.com/google/jax/issues/3964
|
|
@partial(jax.custom_jvp, nondiff_argnums=(0, 2))
|
|
def sample(shape, param, seed):
|
|
return jax.random.uniform(key=seed, shape=shape, minval=param)
|
|
|
|
@sample.defjvp
|
|
def sample_jvp(shape, seed, primals, tangents):
|
|
param, = primals
|
|
dparam, = tangents
|
|
dparam = jnp.broadcast_to(dparam, shape)
|
|
samples = sample(shape, param, seed)
|
|
return samples, samples * dparam # dummy jvp for proof of concept
|
|
|
|
# check these don't crash
|
|
jax.vmap(lambda seed: sample((2,3), 1., seed))(
|
|
jax.random.split(jax.random.PRNGKey(1), 10))
|
|
jax.jvp(lambda x: sample((2, 3), x, jax.random.PRNGKey(1)),
|
|
(1.,), (1.,))
|
|
|
|
def test_fun_with_nested_calls_2(self):
|
|
def call(f, *args):
|
|
f = api.custom_jvp(f)
|
|
f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents)))
|
|
return f(*args)
|
|
|
|
def fun_with_nested_calls_2(x):
|
|
def bar(y):
|
|
def baz(w):
|
|
q = call(lambda x: y, x)
|
|
q = q + call(lambda: y)
|
|
q = q + call(lambda y: w + y, y)
|
|
q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q
|
|
return q
|
|
return api.jit(baz)(x)
|
|
return call(bar, x)
|
|
|
|
# test these don't crash
|
|
self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.),
|
|
fun_with_nested_calls_2(3.))
|
|
api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))
|
|
|
|
def test_closure_with_vmap(self):
|
|
# https://github.com/google/jax/issues/3822
|
|
alpha = np.float32(2.)
|
|
|
|
def sample(seed):
|
|
@api.custom_jvp
|
|
def f(alpha):
|
|
return jax.random.gamma(seed, alpha, shape=[])
|
|
|
|
@f.defjvp
|
|
def f_jvp(primal, tangent):
|
|
alpha = primal
|
|
dalpha = tangent
|
|
sample = f(alpha)
|
|
partial_alpha = lax.random_gamma_grad(alpha, sample)
|
|
return sample, partial_alpha * dalpha
|
|
return f(alpha)
|
|
|
|
api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3)) # don't crash
|
|
|
|
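  # float0 is the tangent dtype for integer/bool primals: a zero-sized dtype
  # meaning "no derivative", so integer outputs get float0 tangents regardless
  # of what the rule returns for them.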
  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_float0(self):
    @api.custom_jvp
    def f(x, y):
      return x, y
    def f_jvp(primals, _):
      # we need a defined (non-float0) tangent to trigger the rule
      return primals, (2., 1)
    f.defjvp(f_jvp)

    primals = (2., 3)
    tangents = (np.ones(()), np.zeros((), float0),)
    expected_tangents = (2., np.zeros((), float0))
    self.assertArraysEqual(api.jvp(f, primals, tangents),
                           (primals, expected_tangents))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_float0_initial_style(self):
    @api.custom_jvp
    def f(x, y):
      return x, y
    def f_jvp(primals, _):
      x, y = primals
      return (x, y), (2., 1)
    f.defjvp(f_jvp)

    def foo(x, y):
      out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1)
      return out

    primals = (2., 3)
    tangents = (np.ones(()), np.zeros((), float0),)
    expected_tangents = (2., np.zeros((), float0))
    self.assertArraysEqual(api.jvp(foo, primals, tangents),
                           (primals, expected_tangents))

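  # Under api.remat the primal computation is replayed on the backward pass,
  # so the custom rule must also compose with partial evaluation.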
  def test_remat(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    @api.remat
    def g(x):
      return f(f(x))

    ans = g(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g)(2.)
    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_higher_order(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    def g(x):
      return f(f(x))

    ans = api.grad(api.grad(api.remat(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.remat(api.grad(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
    expected = api.grad(api.grad(api.grad(g)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_initial_style_vmap_2(self):
    # This is like test_initial_style_vmap except the primal function closes
    # over an array constant.
    y = jnp.array([1., 2., 3.])

    @api.custom_jvp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x * jnp.sum(y)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * g
    f.defjvp(f_jvp)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_jvp_vmap_broadcasting_interaction(self):
    # https://github.com/google/jax/issues/6452
    def f2(y, z):
      v1 = z
      v2 = jnp.sum(y) + z
      return jnp.logaddexp(v1, v2)

    def f1(y, z):
      v = api.vmap(lambda _y: f2(_y, z))(y)
      return jnp.sum(v)

    y = jnp.ones((3, 2))
    f = lambda z: f1(y, z)
    z = 0.1
    val, g = api.value_and_grad(f)(z)
    self.assertEqual(val.shape, ())
    self.assertEqual(g.shape, ())

  def test_custom_jvp_vmap_broadcasting_interaction_2(self):
    # https://github.com/google/jax/issues/5849
    @api.custom_jvp
    def transform(box, R):
      if jnp.isscalar(box) or box.size == 1:
        return R * box
      elif box.ndim == 2:
        return jnp.einsum('ij,j->i', box, R)
      raise ValueError()

    @transform.defjvp
    def transform_jvp(primals, tangents):
      box, R = primals
      dbox, dR = tangents
      return (transform(box, R), dR + transform(dbox, R))

    def periodic_general(box):
      def displacement_fn(Ra, Rb, **kwargs):
        _box = kwargs.get('box', box)
        return transform(_box, Ra - Rb)

      return displacement_fn

    N = 250

    scalar_box = 1.0
    displacement = periodic_general(scalar_box)

    key = jax.random.PRNGKey(0)
    R = jax.random.uniform(key, (N, 2))

    def energy_fn(box):
      d = partial(displacement, box=box)
      d = api.vmap(api.vmap(d, (None, 0)), (0, None))
      return jnp.sum(d(R, R) ** 2)

    self.assertEqual(grad(energy_fn)(scalar_box).shape, ())

  def test_custom_jvp_implicit_broadcasting(self):
    # https://github.com/google/jax/issues/6357
    if config.x64_enabled:
      raise unittest.SkipTest("test only applies when x64 is disabled")

    @jax.custom_jvp
    def projection_unit_simplex(x: jnp.ndarray) -> jnp.ndarray:
      """Projection onto the unit simplex."""
      s = 1.0
      n_features = x.shape[0]
      u = jnp.sort(x)[::-1]
      cssv = jnp.cumsum(u) - s
      ind = jnp.arange(n_features) + 1
      cond = u - cssv / ind > 0
      idx = jnp.count_nonzero(cond)
      threshold = cssv[idx - 1] / idx.astype(x.dtype)
      return jax.nn.relu(x - threshold)

    @projection_unit_simplex.defjvp
    def projection_unit_simplex_jvp(primals, tangents):
      x, = primals
      x_dot, = tangents
      primal_out = projection_unit_simplex(x)
      supp = primal_out > 0
      card = jnp.count_nonzero(supp)
      tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp
      return primal_out, tangent_out

    rng = np.random.RandomState(0)
    x = rng.rand(5).astype(np.float32)

    J_rev = jax.jacrev(projection_unit_simplex)(x)
    J_fwd = jax.jacfwd(projection_unit_simplex)(x)

    p = projection_unit_simplex(x)
    support = (p > 0).astype(jnp.int32)
    cardinality = jnp.count_nonzero(support)
    J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality
    self.assertAllClose(J_true, J_fwd)
    self.assertAllClose(J_true, J_rev)

    proj = jax.vmap(projection_unit_simplex)

    def fun(X):
      return jnp.sum(proj(X) ** 2)

    rng = np.random.RandomState(0)
    X = rng.rand(4, 5).astype(np.float32)
    U = rng.rand(4, 5)
    U /= np.sqrt(np.sum(U ** 2))
    U = U.astype(np.float32)

    eps = 1e-3
    dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps)
    dir_deriv = jnp.vdot(jax.grad(fun)(X), U)
    self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3)

  def test_vmap_inside_defjvp(self):
    # https://github.com/google/jax/issues/3201
    seed = 47
    key = jax.random.PRNGKey(seed)
    mat = jax.random.normal(key, (2, 3))

    @jax.custom_jvp
    def f(mat, aux):
      num_rows, num_cols = mat.shape
      return jnp.ones((num_rows, 1)) / num_cols

    @f.defjvp
    def f_jvp(primals, tangents):
      mat, aux = primals
      vec, _ = tangents
      output = f(*primals)
      num_rows, num_cols = mat.shape
      size = num_rows * num_cols
      # -----
      bd_mat = mat.reshape(1, 1, num_rows, num_cols)
      bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols))
      bd_mat = bd_mat.reshape(size, num_rows, num_cols)
      # -----
      rowsum = jnp.sum(mat, axis=1, keepdims=True)
      colsum = jnp.sum(mat, axis=0, keepdims=True)
      bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows))
      bd_colsum = jnp.tile(colsum, reps=(num_cols, 1))
      # -----
      bd_vec = vec.reshape(size, 1)
      # -----
      def operate(mx, val):
        buf = 0
        for i in range(2):
          buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i)
        buf = jnp.matmul(bd_rowsum, buf)
        return buf * val
      # -----
      # vectorizing here used to raise a shape error
      bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec)
      # -----
      bd_buf = bd_buf / aux
      jvp = jnp.sum(bd_buf, axis=0)
      jvp = jnp.mean(jvp, axis=1, keepdims=True)
      # -----
      # the JVP used to end successfully but still raise an error afterwards
      return (output, jvp)

    jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5)  # doesn't crash

  def test_custom_jvp_unbroadcasting(self):
    # https://github.com/google/jax/issues/3056
    a = jnp.array([1., 1.])

    @jax.custom_jvp
    def f(x):
      return a * x

    @f.defjvp
    def f_jvp(primals, tangents):
      x, = primals
      dx, = tangents
      return a * x, a * dx

    shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
    self.assertEqual(shape, ())


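# custom_vjp splits differentiation into two user-supplied pieces: f_fwd maps
# primal inputs to (primal_out, residuals), and f_bwd maps (residuals,
# output_cotangent) to a tuple of input cotangents, one per primal argument.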
class CustomVJPTest(jtu.JaxTestCase):

  def test_basic(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
    self.assertAllClose(api.value_and_grad(f)(x),
                        (jnp.sin(x), 2 * jnp.cos(x)))

  def test_invariance(self):
    @api.custom_vjp
    def f(x):
      return jnp.cos(2 * x) / 2.
    def f_fwd(x):
      return (f(x), x)
    def f_rev(x, g):
      return (g * 3,)
    f.defvjp(f_fwd, f_rev)
    def f2(x):
      y, _ = api.value_and_grad(f)(x)
      return y
    def f3(x):
      y, _ = api.value_and_grad(f2)(x)
      return y
    x = 1.
    self.assertAllClose(f(x), f2(x), check_dtypes=False)
    self.assertAllClose(f(x), f3(x), check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
                        check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
                        check_dtypes=False)

  def test_python_control_flow(self):
    @api.custom_vjp
    def f(x):
      if x > 0:
        return jnp.sin(x)
      else:
        return jnp.cos(x)
    def f_fwd(x):
      if x > 0:
        return f(x), x
      else:
        return f(x), x
    def f_rev(x, g):
      if x > 0:
        return (2 * g,)
      else:
        return (3 * g,)
    f.defvjp(f_fwd, f_rev)
    x = 2.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(f(-x), jnp.cos(-x))
    self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
                        check_dtypes=False)
    self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
                        check_dtypes=False)

  def test_vmap(self):
    @api.custom_vjp
    def f(x):
      assert jnp.ndim(x) == 0
      return jnp.sin(x)
    def f_fwd(x):
      assert jnp.ndim(x) == 0
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = jnp.arange(3.)
    xx = jnp.arange(6.).reshape(2, 3)

    # vmap of f
    self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
    self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))

    # vmap of grad of f
    self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
    self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
                        (jnp.sin(x), 2 * jnp.cos(x)))
    self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
    self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx)))

    # grad of vmap of f
    self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
                        2 * jnp.cos(x))
    self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
                        2 * jnp.cos(xx))

    # vmap of grad of vmap of f
    self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
                        2 * jnp.cos(xx))

  def test_jit(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.

    # jit
    self.assertAllClose(api.jit(f)(x), jnp.sin(x))
    self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))

    # jit of grad
    self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
                        check_dtypes=False)

    # grad of jit
    self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x),
                        check_dtypes=False)

  def test_pytrees(self):
    @api.custom_vjp
    def f(x):
      return {'b': jnp.sin(x['a'])}
    def f_fwd(x):
      return f(x), {'r': jnp.cos(x['a'])}
    def f_bwd(res, g):
      cos_x = res['r']
      return ({'a': 2 * cos_x * g['b']},)
    f.defvjp(f_fwd, f_bwd)
    x = {'a': 3.}
    self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
    self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
                        {'a': 2 * jnp.cos(x['a'])})

  def test_jvp_error(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))
    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(jit(f), (3.,), (1.,)))

  def test_kwargs(self):
    # from https://github.com/google/jax/issues/1938
    @api.custom_vjp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    my_fun.defvjp(lambda x, y, c=1.: (my_fun(x, y, c), None),
                  lambda _, g: (g, g, g))
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.grad(f)(10., 5.)  # doesn't crash

  def test_initial_style(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = 2. * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = -2. * jnp.sin(3.)
    self.assertAllClose(ans, expected)

  def test_initial_style_vmap(self):
    @api.custom_vjp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(jnp.arange(3.))
    expected = 3. * jnp.arange(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
    expected = 2. * jnp.cos(jnp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg(self):
    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def app(f, x):
      return f(x)
    def app_fwd(f, x):
      return app(f, x), jnp.cos(x)
    def app_rev(f, cos_x, g):
      return (cos_x * g,)
    app.defvjp(app_fwd, app_rev)

    ans = app(lambda x: 2 * x, 1)
    expected = 2
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
    expected = (2., jnp.cos(1.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer(self):
    # This test is now skipped because we decided not to support this behavior
    # anymore (namely, nondiff args can't be tracers), but
    # test_closed_over_tracer is a replacement test for analogous behavior that
    # we do support.
    raise unittest.SkipTest("removed support for tracers in nondiff args")

    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(x, cos_y, g):
      return (cos_y * g,)
    f.defvjp(f_fwd, f_rev)

    @jit
    def g(x, y):
      return f(x, y)

    ans = g(2, 3.)
    expected = 6.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g, 1)(2., 3.)
    expected = jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracer(self):
    # This test is similar to test_nondiff_arg_tracer except it uses lexical
    # closure rather than the nondiff_argnums mechanism. We decided to disallow
    # tracers in nondiff_argnums to greatly simplify bookkeeping while still
    # supporting the cases for which it is necessary.
    def outer(x):
      @api.custom_vjp
      def f(y):
        return x * y
      def f_fwd(y):
        return f(y), jnp.cos(y)
      def f_rev(cos_y, g):
        return (cos_y * g,)
      f.defvjp(f_fwd, f_rev)
      return f

    @jit
    def g(x, y):
      return outer(x)(y)

    ans = g(2, 3.)
    expected = 6.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g, 1)(2., 3.)
    expected = jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracer2(self):
    def outer(x):
      @api.custom_vjp
      def f(y):
        return x * y
      def f_fwd(y):
        return f(y), jnp.cos(y)
      def f_rev(cos_y, g):
        return (cos_y * g,)
      f.defvjp(f_fwd, f_rev)
      return f

    @api.vmap
    def g(x):
      return outer(x)(3.)

    ans = g(np.arange(3.))
    expected = np.arange(3.) * 3
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracer3(self):
    def outer(x):
      @api.custom_vjp
      def f(y):
        return x * y
      def f_fwd(y):
        return f(y), (x, jnp.cos(y))
      def f_rev(res, g):
        x, cos_y = res
        return (cos_y * g * x,)
      f.defvjp(f_fwd, f_rev)
      return api.grad(f)

    @api.vmap
    def g(x):
      return outer(x)(3.)

    ans = g(np.arange(3.))
    expected = np.cos(3.) * np.arange(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer_error(self):
    # This is similar to the old (now skipped) test_nondiff_arg_tracer, except
    # here we test for the error message that this usage pattern now raises.

    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(x, cos_y, g):
      return (cos_y * g,)
    f.defvjp(f_fwd, f_rev)

    @jit
    def g(x, y):
      return f(x, y)

    with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
      _ = g(2, 3.)
    with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"):
      _ = api.grad(g, 1)(2., 3.)

  def test_vmap_axes(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_pmap(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_missing_vjp_rule_error(self):
    @api.custom_vjp
    def foo(x):
      return x ** 2

    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: foo(2))
    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: api.grad(foo)(2.))

  def test_vjp_rule_inconsistent_pytree_structures_error(self):
    @api.custom_vjp
    def f(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(_, g):
      return (g, g)

    f.defvjp(foo_fwd, foo_bwd)

    f(2)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.".format(
                tree_util.tree_structure((1, 1)),
                tree_util.tree_structure((1,)))
        ),
        lambda: api.grad(f)(2.))

  def test_vjp_bwd_returns_non_tuple_error(self):
    @api.custom_vjp
    def f(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(_, g):
      return 2. * g  # Should be a tuple

    f.defvjp(foo_fwd, foo_bwd)
    with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
      api.grad(f)(3.)

  def test_issue2511(self):
    arr = jnp.ones((5, 2, 2))
    foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
    api.jit(foo)(arr)  # doesn't crash

  def test_lowering_out_of_traces(self):
    # https://github.com/google/jax/issues/2578

    class F(collections.namedtuple("F", ["a"])):
      def __call__(self, x):
        return jax.nn.relu(self.a) * x

    @jax.jit
    def g(f, x):
      return f(x)

    jax.grad(g, argnums=(1,))(F(2.0), 0.)  # doesn't crash

  def test_nondiff_argnums_stop_gradient(self):
    # This test is now skipped because we decided not to support this behavior
    # anymore (namely, nondiff args can't be tracers), but test_clip_gradient is
    # a replacement showing behavior we do support.
    raise unittest.SkipTest("removed support for tracers in nondiff args")

    # https://github.com/google/jax/issues/2784
    @partial(api.custom_vjp, nondiff_argnums=(0, 1))
    def _clip_gradient(lo, hi, x):
      return x  # identity function

    def clip_gradient_fwd(lo, hi, x):
      # return x, None
      return x, (hi, )

    def clip_gradient_bwd(lo, hi, _, g):
      return (jnp.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -1
      hi = x + 1  # causes things to break
      return _clip_gradient(lo, hi, x)

    jax.grad(clip_gradient)(1.)  # doesn't crash

  def test_clip_gradient(self):
    # https://github.com/google/jax/issues/2784
    @api.custom_vjp
    def _clip_gradient(lo, hi, x):
      return x  # identity function when not differentiating

    def clip_gradient_fwd(lo, hi, x):
      return x, (lo, hi,)

    def clip_gradient_bwd(res, g):
      lo, hi = res
      return (None, None, jnp.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -0.1
      hi = x + 0.1
      return _clip_gradient(lo, hi, x)

    g = jax.grad(clip_gradient)(0.1)  # doesn't crash
    self.assertAllClose(g, jnp.array(0.2))

  def test_nestable_vjp(self):
    # Verify that https://github.com/google/jax/issues/3667 is resolved.
    def f(x):
      return x ** 2

    @api.custom_vjp
    def g(x):
      return f(x)

    def g_fwd(x):
      y, f_vjp = api.vjp(f, x)
      return y, f_vjp

    def g_bwd(f_vjp, y_bar):
      return f_vjp(y_bar)

    g.defvjp(g_fwd, g_bwd)

    # Check that VJP can be nested in simple situations. For this to pass,
    # vjp has to return a PyTree.
    _, g_vjp = api.vjp(g, 1.0)
    y, = g_vjp(1.0)
    self.assertAllClose(y, jnp.array(2.0))

    # Check that VJP can be nested in complex situations. For this to pass,
    # vjp can't treat the closed-over tracer x as a static argument.
    @jit
    def z(x):
      _, g_vjp = api.vjp(g, x)
      return g_vjp
    y, = z(1.0)(3.0)
    self.assertAllClose(y, jnp.array(6.0))

  def test_initial_style_vmap_2(self):
    # https://github.com/google/jax/issues/4173
    x = jnp.ones((10, 3))

    # Create the custom function
    @api.custom_vjp
    def custom_fun(x):
      return x.sum()

    def forward(x):
      return x.sum(), (jnp.ones_like(x),)

    def backward(res, g):
      return g * res[0],

    custom_fun.defvjp(forward, backward)

    def train_fun(x):

      def summed_fun(x):
        return api.vmap(custom_fun)(x).sum()

      return api.grad(summed_fun)(x)

    def scan_body(carry, inputs):
      x = carry
      return carry, train_fun(x)

    scan_range = jnp.arange(4)
    lax.scan(scan_body, x, scan_range)  # don't crash

  def test_initial_style_vmap_3(self):
    # This is like test_initial_style_vmap except the primal function closes
    # over an array constant.
    y = jnp.array([1., 2., 3.])

    @api.custom_vjp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x * jnp.sum(y)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(jnp.arange(3.))
    expected = 3. * jnp.arange(3.) * 6
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
    expected = 2. * jnp.cos(jnp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_closes_over_tracer(self):
    def f(y):
      @jax.custom_vjp
      def f(x):
        return 2. * jnp.sin(x)

      def fwd(x):
        return f(x), ()

      def bwd(_, g):
        return (2. * jnp.cos(y) * g,)  # capture!

      f.defvjp(fwd, bwd)

      return jax.grad(f)(1.)

    ans = jax.jit(f)(2.)
    self.assertAllClose(ans, 2. * jnp.cos(2.))

    ans = jax.vmap(f)(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.grad(f)(4.)
    self.assertAllClose(ans, -2. * jnp.sin(4.))

  def test_fwd_closes_over_tracer(self):
    def f(y):
      @jax.custom_vjp
      def f(x):
        return 2. * jnp.sin(x)

      def fwd(x):
        return f(x), y

      def bwd(y, g):
        return (2. * jnp.cos(y) * g,)  # capture!

      f.defvjp(fwd, bwd)

      return jax.grad(f)(1.)

    ans = jax.jit(f)(2.)
    self.assertAllClose(ans, 2. * jnp.cos(2.))

    ans = jax.vmap(f)(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.jit(jax.vmap(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.vmap(jax.jit(f))(jnp.arange(3.))
    self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.)))

    ans = jax.grad(f)(4.)
    self.assertAllClose(ans, -2. * jnp.sin(4.))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_float0(self):
    @api.custom_vjp
    def f(x, _):
      return x
    def f_fwd(x, _):
      # we need a defined (non-float0) tangent to trigger the rule
      return x, (2., 1)
    def f_rev(*_):
      return (2., 1)
    f.defvjp(f_fwd, f_rev)

    x = 2.
    y = 3
    self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y),
                     (2., np.zeros(shape=(), dtype=float0)))

  @unittest.skipIf(numpy_version == (1, 21, 0),
                   "https://github.com/numpy/numpy/issues/19305")
  def test_float0_initial_style(self):
    @api.custom_vjp
    def f(x):
      return x
    def f_fwd(x):
      return x, (2., x)
    def f_rev(*_):
      return ((2., 1),)
    f.defvjp(f_fwd, f_rev)

    def foo(x, y):
      out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1)
      return out[0]

    x = 2.
    y = 3
    self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y),
                     (2., np.zeros(shape=(), dtype=float0)))

  def test_remat(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    @api.remat
    def g(x):
      return f(f(x))

    ans = g(2.)
    expected = np.sin(np.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g)(2.)
    expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_remat_higher_order(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def g(x):
      return f(f(x))

    ans = api.grad(api.grad(api.remat(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.remat(api.grad(g)))(2.)
    expected = api.grad(api.grad(g))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(api.grad(api.remat(g))))(2.)
    expected = api.grad(api.grad(api.grad(g)))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones(self):
    @api.custom_vjp
    def f(x, y):
      return x * jnp.sin(y)
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: f(x, x))(3.)
    expected = 2 * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones_vmap(self):
    @api.custom_vjp
    def f(x, y):
      return x * jnp.sin(y)
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.))
    expected = 2 * jnp.cos(jnp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_bwd_nones_pytree(self):
    @api.custom_vjp
    def f(xs, y):
      x1, x2 = xs
      return x1 * x2 * jnp.sin(y)
    def f_fwd(xs, y):
      return f(xs, y), jnp.cos(y)
    def f_rev(cos, g):
      return (None, 2 * cos * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(lambda x: f((x, x), x))(3.)
    expected = 2 * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_vjp_closure_4521(self):
    # https://github.com/google/jax/issues/4521
    @api.custom_vjp
    def g(x, y):
      return None
    def g_fwd(x, y):
      return None, y
    def g_bwd(residuals, z_bar):
      assert False

    g.defvjp(g_fwd, g_bwd)

    def f(xs, y):
      v_g = api.vmap(g, in_axes=(0, None), out_axes=None)
      v_g(xs, y)

    def scan_body(xs, _):
      y = jnp.zeros(1)
      _, vjp_f = api.vjp(f, xs, y)
      vjp_f(None)
      return xs, None

    lax.scan(scan_body, jnp.ones(5), None, 100)  # doesn't crash

  def test_float0_bwd_none(self):
    @api.custom_vjp
    def f(i, x):
      return jnp.sin(x)
    def f_fwd(i, x):
      return f(i, x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (None, 2 * cos_x * g)
    f.defvjp(f_fwd, f_rev)

    ans = api.grad(f, 1)(jnp.array([1, 2]), 3.)  # doesn't crash
    expected = 2 * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

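  # api.custom_gradient is the TF-style spelling: the decorated function
  # returns its primal output together with a VJP callable, and it is layered
  # on top of custom_vjp.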
  def test_custom_gradient(self):
    @api.custom_gradient
    def f(x):
      return x ** 2, lambda g: (g * x,)

    self.assertAllClose(f(3.), 9., check_dtypes=False)
    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
    self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)

  def test_custom_gradient_2(self):
    @api.custom_gradient
    def f(x, y):
      return x * y, lambda g: (y, x)

    self.assertAllClose(f(3., 4.), 12., check_dtypes=False)
    self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.),
                        check_dtypes=False)

  def test_custom_gradient_3(self):
    @api.custom_gradient
    def f(x):
      vjp = lambda g: (jnp.cos(x) * jnp.array([3., 4., 5.]),)
      return jnp.sum(jnp.sin(x)), vjp

    self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))),
                        check_dtypes=False)
    self.assertAllClose(
        api.grad(f)(jnp.arange(3.)),
        api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.array([3., 4., 5.]),
        check_dtypes=False)

  def test_custom_gradient_can_return_singleton_value_in_vjp(self):
    @api.custom_gradient
    def f(x):
      return x ** 2, lambda g: g * x

    self.assertAllClose(f(3.), 9., check_dtypes=False)
    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
    self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False)

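  # api.closure_convert(fn, *example_args) rewrites a closure into a function
  # that takes its closed-over array values as explicit trailing arguments,
  # returning that function plus the list of those values, so the values can
  # participate in custom_vjp differentiation.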
def test_closure_convert(self):
|
|
def cos_after(fn, x):
|
|
converted_fn, aux_args = api.closure_convert(fn, x)
|
|
self.assertLessEqual(len(aux_args), 1)
|
|
return _cos_after(converted_fn, x, *aux_args)
|
|
|
|
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
|
def _cos_after(fn, x, *args):
|
|
return jnp.cos(fn(x, *args))
|
|
|
|
def fwd(fn, x, *args):
|
|
y = _cos_after(fn, x, *args)
|
|
return y, (x, args)
|
|
|
|
def rev(fn, res, g):
|
|
x, args = res
|
|
x_bar = 17. * x
|
|
args_bars = [42. * a for a in args]
|
|
return (x_bar, *args_bars)
|
|
|
|
_cos_after.defvjp(fwd, rev)
|
|
|
|
def dist(c, x):
|
|
return jnp.sum((x - c) ** 2.)
|
|
|
|
def solve(c, x):
|
|
def closure(x):
|
|
return dist(c, x)
|
|
return cos_after(closure, x)
|
|
|
|
c, x = 2. * jnp.ones(2), jnp.ones(2)
|
|
expected = jnp.cos(dist(c, x))
|
|
self.assertAllClose(solve(c, x), expected, check_dtypes=False)
|
|
g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x)
|
|
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
|
|
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
|
|
|
|
def test_closure_convert_mixed_consts(self):
|
|
# Like test_closure_convert, but close over values that
|
|
# participate in AD as well as values that do not.
|
|
# See https://github.com/google/jax/issues/6415
|
|
|
|
def cos_after(fn, x):
|
|
converted_fn, aux_args = api.closure_convert(fn, x)
|
|
self.assertLessEqual(len(aux_args), 1)
|
|
return _cos_after(converted_fn, x, *aux_args)
|
|
|
|
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
|
def _cos_after(fn, x, *args):
|
|
return jnp.cos(fn(x, *args))
|
|
|
|
def fwd(fn, x, *args):
|
|
y = _cos_after(fn, x, *args)
|
|
return y, (x, args)
|
|
|
|
def rev(fn, res, g):
|
|
x, args = res
|
|
x_bar = 17. * x
|
|
args_bars = [42. * a for a in args]
|
|
return (x_bar, *args_bars)
|
|
|
|
_cos_after.defvjp(fwd, rev)
|
|
|
|
def dist(c, s, x):
|
|
return jnp.sum(s * (x - c) ** 2.)
|
|
|
|
def solve(c, s, x):
|
|
def closure(x):
|
|
return dist(c, s, x)
|
|
return cos_after(closure, x)
|
|
|
|
c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2)
|
|
expected = jnp.cos(dist(c, s, x))
|
|
self.assertAllClose(solve(c, s, x), expected, check_dtypes=False)
|
|
g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x)
|
|
self.assertAllClose(g_c, 42. * c, check_dtypes=False)
|
|
self.assertAllClose(g_x, 17. * x, check_dtypes=False)
|
|
|
|
|
|
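# api.linear_call(f, f_transpose, residual_args, linear_args) applies f and
# installs f_transpose as the rule used when the computation is transposed,
# e.g. under api.linear_transpose or reverse-mode autodiff.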
class CustomTransposeTest(jtu.JaxTestCase):

  def transpose(self, f, x_example):
    def transposed(y):
      x, = api.linear_transpose(f, x_example)(y)
      return x
    return transposed

  def test_linear_call(self):
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / r
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, y)
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_incorrect_transpose(self):
    def f(x, y):
      def fn(r, x): return x / r
      def tp(r, t): return t / (2. * r)  # nb: not the true transpose
      return x + api.linear_call(fn, tp, y, x)

    def f_ref(x, y):
      return x + x / y

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    self.assertAllClose(f(x, y), f_ref(x, y))

    f1 = lambda x: f(x, y)
    f1_ref = lambda x: f_ref(x, 2. * y)  # nb: double the reference divisor
    self.assertAllClose(self.transpose(f1, x)(x),
                        self.transpose(f1_ref, x)(x))

  def test_linear_call_transpose_transpose_transpose(self):
    def fn(r, x): return x / r
    def tp(r, t): return t / (2. * r)  # nb: untrue transpose
    def f_(x, y):
      return x + api.linear_call(fn, tp, y, x)

    x = jnp.ones(2) * 6.
    y = jnp.ones(2) * 3.
    f = lambda x: f_(x, y)
    ft = self.transpose(f, x)
    ftt = self.transpose(ft, x)
    fttt = self.transpose(ftt, x)
    self.assertAllClose(ft(x), x + tp(y, x))
    self.assertAllClose(f(x), ftt(x))
    self.assertAllClose(ft(x), fttt(x))

  def test_linear_call_scalar_to_vector(self):
    def f(c, x):
      def fn(_, x):
        return [x, x]

      def tp(_, t):
        t1, t2 = t
        return t1 + t2

      return api.linear_call(fn, tp, (), c * x)

    def f_ref(c, x):
      return [c * x, c * x]

    c, x = 2., 3.
    t = [4., 5.]
    self.assertAllClose(f(c, x), f_ref(c, x))
    self.assertAllClose(self.transpose(partial(f, c), x)(t),
                        self.transpose(partial(f_ref, c), x)(t))

  def test_linear_call_nested(self):
    # identity function with an untrue transpose of 0
    def id_(x):
      def f(_, x): return x
      def t(_, t): return 0.
      return api.linear_call(f, t, (), x)

    # identity function with an untrue transpose of 7, and where both
    # forward and transpose have custom transpositions that should
    # never end up invoked.
    def f(x):
      def f_(_, x): return id_(x)
      def t_(_, t): return id_(7.)
      return api.linear_call(f_, t_, (), x)

    x = 5.
    id_t = self.transpose(id_, x)
    id_tt = self.transpose(id_t, x)
    ft = self.transpose(f, x)
    ftt = self.transpose(ft, x)
    fttt = self.transpose(ftt, x)

    self.assertAllClose(id_(x), x)
    self.assertAllClose(id_t(x), 0.)
    self.assertAllClose(id_tt(x), x)

    self.assertAllClose(f(x), x)
    self.assertAllClose(ft(x), 7.)
    self.assertAllClose(ftt(x), x)
    self.assertAllClose(fttt(x), 7.)


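# jax.invertible marks a function whose inputs can be recomputed from its
# outputs, letting the VJP reconstruct intermediates by inverting the forward
# pass instead of storing them (cf. reversible residual networks).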
class InvertibleADTest(jtu.JaxTestCase):

  @jtu.ignore_warning(message="Values that an @invertible function closes")
  def test_invertible_basic(self):
    def f(x):
      return lax.mul(lax.mul(lax.exp(x), 4.), x)

    finv = jax.invertible(f)
    x = jnp.ones((5,))

    jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)

    # expected = """
    # { lambda ; a b.
    #   let c = exp a
    #       d = mul c 4.0
    #       e = mul d a
    #       f = mul b a
    #       g = div e a
    #       h = mul b g
    #       i = mul f 4.0
    #       j = div g 4.0
    #       k = mul f j
    #       _ = reduce_sum[ axes=(0,) ] k
    #       _ = log j
    #       l = mul i j
    #       m = add_any h l
    #   in (m,) }
    # """
    # self.assertMultiLineStrippedEqual(expected, str(jaxpr))  # no jaxpr test

    self.assertIn('div', str(jaxpr))
    self.assertIn('log', str(jaxpr))  # assumes no DCE
    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
                        jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
                        check_dtypes=True)

  def test_invertible_blocks(self):
    # NB: This is the reversible ResNet block
    def mk_reversible_block(f, g):
      @jax.custom_ivjp
      def rev_block(x1, x2):
        y1 = f(x2) + x1
        y2 = g(y1) + x2
        return y1, y2

      @rev_block.defivjp
      def rev_block_ivjp(xs, ys, dys):
        (y1, y2) = ys
        (dy1, dy2) = dys

        dgo, dx2 = dy2, dy2
        go, gvjp = jax.vjp(g, y1)
        dy1 += gvjp(dgo)[0]
        del gvjp
        x2 = y2 - go

        dfo, dx1 = dy1, dy1
        fo, fvjp = jax.vjp(f, x2)
        dx2 += fvjp(dfo)[0]
        del fvjp
        x1 = y1 - fo

        return (x1, x2), (dx1, dx2)

      return rev_block

    rev_block = mk_reversible_block(jnp.sin, jnp.cos)

    def g(x1, x2):
      for i in range(2):
        x1, x2 = rev_block(x1, x2)
      return x1, x2

    def reduce(f, x1, x2):
      y1, y2 = f(x1, x2)
      return np.sum(y1) + np.sum(y2)

    x = np.ones((1,))
    # FIXME: This breaks when argnums is left as default (i.e. 0), because JVP
    # prunes zero tangents from call primitives.
    self.assertAllClose(jax.value_and_grad(partial(reduce, jax.invertible(g)), argnums=(0, 1))(x, x + 2),
                        jax.value_and_grad(partial(reduce, g), argnums=(0, 1))(x, x + 2),
                        check_dtypes=True)

  def test_invertible_partial_diff(self):
    # Check that we don't have to differentiate with respect to inputs
    # of the invertible function.
    def f(x, y):
      return lax.mul(lax.mul(lax.exp(x), 4.), x), lax.add(y, 4.)

    finv = jax.invertible(f)
    o = np.ones((5,))
    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x, o)[0]))(o),
                        jax.value_and_grad(lambda x: np.sum(finv(x, o)[0]))(o),
                        check_dtypes=True)

  def test_invertible_pytree(self):
    def f(x, y):
      return lax.add(lax.mul(lax.exp(x[0]), x[1]), y)

    finv = jax.invertible(f)
    o = np.ones((5,))
    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f((x, x), x)[0]))(o),
                        jax.value_and_grad(lambda x: np.sum(finv((x, x), x)[0]))(o),
                        check_dtypes=True)


class BufferDonationTest(jtu.BufferDonationTestCase):

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_pmap_donate_argnums_invalidates_input(self):
    move = api.pmap(lambda x: x + x - x, donate_argnums=0)
    n = jax.local_device_count()
    x = api.pmap(lambda x: x)(jnp.ones([n]))
    y = move(x)
    self.assertDeleted(x)
    np.testing.assert_allclose(y, [1.] * n)

  def test_pmap_nested_donate_ignored(self):
    pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
    a = api.pmap(lambda x: x)(jnp.array([1]))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   pmap_fun(a)

    pmap_fun(a)  # doesn't crash


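# api.named_call attaches a name to a function so it appears as a named
# subcomputation in the lowered XLA HLO, mainly for profiling and debugging.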
class NamedCallTest(jtu.JaxTestCase):

  def test_default_name(self):

    @api.named_call
    def my_test_function(x):
      return x**2

    @jax.jit
    def f(x):
      return my_test_function(x)

    c = jax.xla_computation(f)(2)
    self.assertIn("my_test_function", c.as_hlo_text())

  def test_non_jaxtype_arg(self):
    # For the test to fail without the invalid JaxType filter we need to pass
    # in a valid JaxType that forces the invalid Jaxtype to be raised to an
    # abstract value.
    def f(not_a_jaxtype, a_jaxtype):
      # then Jax needs to try and evaluate the abstractified non-JaxType
      if not_a_jaxtype:
        return a_jaxtype
      return 0

    f = api.named_call(f, name="test")
    out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1)
    self.assertEqual(out, 1)

  @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat)
  def test_jax_transforms(self, transform):
    f = jnp.sum
    x = jnp.array([1.])

    unnamed_out = transform(f)(x)
    named_out = transform(api.named_call(f, name="test"))(x)

    self.assertEqual(unnamed_out, named_out)

  def test_static_argnums(self):
    f = api.named_call(lambda x, y: y if x else None, name="test")
    f = jax.jit(f, static_argnums=(0,))
    out = f(True, 5)
    self.assertEqual(out, 5)

  def test_partial_eval(self):
    f = api.named_call(lambda x, y: y if x else None, name="test")
    f = jax.jit(functools.partial(f, True))
    out = f(5)
    self.assertEqual(out, 5)

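  # Python ints that don't fit in the canonical integer dtype should raise
  # OverflowError on conversion rather than silently wrapping around.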
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_jit_type={}_func={}".format(jit_type, func),
       "jit_type": jit_type, "func": func}
      for func in ['identity', 'asarray', 'device_put']
      for jit_type in [None, "python", "cpp"]
      if not (jit_type is None and func == 'identity')))
  def test_integer_overflow(self, jit_type, func):
    funcdict = {
        'identity': lambda x: x,
        'asarray': jnp.asarray,
        'device_put': api.device_put,
    }
    jit = {
        'python': api._python_jit,
        'cpp': api._cpp_jit,
        None: lambda x: x,
    }
    f = jit[jit_type](funcdict[func])

    int_dtype = dtypes.canonicalize_dtype(jnp.int_)
    int_max = np.iinfo(int_dtype).max
    int_min = np.iinfo(int_dtype).min

    self.assertEqual(f(int_max).dtype, int_dtype)
    self.assertEqual(f(int_min).dtype, int_dtype)
    self.assertRaises(OverflowError, f, int_max + 1)
    self.assertRaises(OverflowError, f, int_min - 1)


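# The "No GPU/TPU found" warning is emitted when the backend first initializes
# on a CPU-only machine, unless the platform is explicitly pinned to CPU.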
class BackendsTest(jtu.JaxTestCase):

  @unittest.skipIf(not sys.executable, "test requires sys.executable")
  @jtu.skip_on_devices("gpu", "tpu")
  def test_cpu_warning_suppression(self):
    warning_expected = (
      "import jax; "
      "jax.numpy.arange(10)")
    warning_not_expected = (
      "import jax; "
      "jax.config.update('jax_platform_name', 'cpu'); "
      "jax.numpy.arange(10)")

    result = subprocess.run([sys.executable, '-c', warning_expected],
                            check=True, capture_output=True)
    assert "No GPU/TPU found" in result.stderr.decode()

    result = subprocess.run([sys.executable, '-c', warning_not_expected],
                            check=True, capture_output=True)
    assert "No GPU/TPU found" not in result.stderr.decode()


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())