rocm_jax/tests/api_test.py
2020-09-21 19:33:14 -07:00

3723 lines
113 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from contextlib import contextmanager
import copy
from functools import partial
import re
import unittest
import types
import warnings
import weakref
import functools
from absl import logging
from absl.testing import absltest, parameterized
import numpy as np
import concurrent.futures
import jax
import jax.numpy as jnp
from jax import jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
from jax.lib import version
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class CPPJitTest(jtu.JaxTestCase):
  """Shared tests between the Python and the C++ jax.jit implementations.

  Because the Python implementation supports more features, we need to have the
  Python tests that extend the C++ tests (and not the other way around).
  """

  @property
  def jit(self):
    """Returns the jit implementation under test (the C++ one here).

    Raises:
      unittest.SkipTest: if `jax.lib.version` is older than (0, 1, 56).
    """
    # Right now, the CPP tests also test the Python code-path when jaxlib is
    # too old.
    # TODO(jblespiau,phawkins): Remove this when jaxlib has been released.
    # This is in the future, because we are making a breaking change to
    # Tensorflow.
    if version < (0, 1, 56):
      raise unittest.SkipTest("Disabled because it depends on some future "
                              "release of jax_jit.cc within jaxlib.")
    return jax.api._cpp_jit

  def test_jit_of_noncallable(self):
    """jit rejects a non-callable argument with a TypeError."""
    with self.assertRaisesRegex(TypeError, "Expected a callable value.*"):
      self.jit(3)

  def test_jit_of_generator(self):
    """jit rejects a generator function with a TypeError."""

    def gen(x):
      yield x

    with self.assertRaisesRegex(
        TypeError, "Expected a function, got a generator function.*"):
      self.jit(gen)

  @parameterized.parameters([
      # Integer support
      (1, 2, 3, 4, 5),
      # Numpy array support
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
          np.asarray(4, np.int32),
          np.asarray(5, np.int32),
      ),
  ])
  def test_jit_static_args(self, one, two, three, four, five):
    """Checks when `static_argnums` calls re-run the Python body of `f`."""
    side = []  # Gains one entry each time the Python body of `f` executes.
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z, flag=False, flag2=False):
      del flag2  # unused
      assert flag
      side.append(None)
      return 100 * x + 10 * y + z

    f1 = self.jit(f, static_argnums=(3, 4))
    self.assertEqual(f1(one, two, three, True, False), 123)
    self.assertLen(side, 1)
    self.assertEqual(f1(one, two, three, True, False), 123)
    self.assertLen(side, 1)  # Obvious cache hit.
    self.assertEqual(f1(two, one, three, True, False), 213)
    self.assertLen(side, 1)  # Should cache hit because same signature.
    self.assertEqual(f1(two, one, three, True, True), 213)
    self.assertLen(side, 2)

    side[:] = []
    f2 = self.jit(f, static_argnums=(0, 2, 3, 4))
    self.assertEqual(f2(one, two, three, True, False), 123)
    self.assertLen(side, 1)
    self.assertEqual(f2(one, three, three, True, False), 133)
    self.assertLen(side, 1)
    self.assertEqual(f2(two, two, three, True, False), 223)
    self.assertLen(side, 2)
    self.assertEqual(f2(two, four, three, True, False), 243)
    self.assertLen(side, 2)
    self.assertEqual(f2(two, four, three, True, True), 243)
    self.assertLen(side, 3)
    self.assertEqual(f2(two, five, three, True, True), 253)
    self.assertLen(side, 3)

  @parameterized.parameters([
      (1, 2, 3),
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
      ),
  ])
  def test_jit_kwargs(self, one, two, three):
    """Checks when keyword-argument calls re-run the Python body of `f`."""
    side = []  # Gains one entry each time the Python body of `f` executes.
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z):
      side.append(None)
      return 100 * x + 10 * y + z

    f = self.jit(f)
    self.assertEqual(f(one, two, three), 123)
    self.assertLen(side, 1)
    self.assertEqual(f(one, two, three), 123)
    self.assertLen(side, 1)

    self.assertEqual(f(one, two, z=three), 123)
    self.assertLen(side, 2)  # actually recompiles from kwarg
    self.assertEqual(f(one, two, z=three), 123)
    self.assertLen(side, 2)  # but should still cache

    f(one, two, z=np.zeros(3))  # doesn't crash
    if FLAGS.jax_enable_x64:
      # In the above call, three is of a new type (int64), thus it should
      # trigger a new compilation.
      self.assertLen(side, 3)

  def test_jit_device(self):
    """The `device=` argument places the result on the requested device."""
    device = xb.devices()[-1]
    x = self.jit(lambda x: x, device=device)(3.)
    self.assertIsInstance(x, xla.DeviceArray)
    self.assertEqual(x.device_buffer.device(), device)

  def test_complex_support(self):
    """A Python complex scalar can be passed through a jitted function."""
    self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

  def test_jit_with_many_args_works(self):
    """A 500-element list argument is accepted."""

    @self.jit
    def f(args_list):
      return sum(args_list)

    self.assertEqual(f(list(range(500))), sum(range(500)))

  # Jit and Donate arguments

  def assertDeleted(self, x):
    """Asserts that every device buffer of `x` reports being deleted."""
    self._assertDeleted(x, True)

  def assertNotDeleted(self, x):
    """Asserts that no device buffer of `x` reports being deleted."""
    self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    """Asserts `is_deleted()` equals `deleted` on each buffer backing `x`."""
    if hasattr(x, "device_buffer"):
      self.assertEqual(x.device_buffer.is_deleted(), deleted)
    else:
      for buffer in x.device_buffers:
        self.assertEqual(buffer.is_deleted(), deleted)

  def test_jit_donate_argnums_warning_raised(self):
    """Donated buffers that cannot be used produce a single UserWarning."""
    x = jnp.array([1.0, 2.0], jnp.float32)
    y = jnp.array([1, 2], jnp.int32)
    f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      f(x, y)

      self.assertLen(w, 1)
      self.assertTrue(issubclass(w[-1].category, UserWarning))
      self.assertIn(
          "Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
          str(w[-1].message))

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_invalidates_input(self):
    """A donated input's buffer is deleted after the call."""
    # We can't just use `lambda x: x` because JAX simplifies this away to an
    # empty XLA computation.
    move = self.jit(lambda x: x + x - x, donate_argnums=0)
    x = jnp.ones([])
    y = move(x)
    self.assertDeleted(x)
    self.assertEqual(y, 1.)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_static_argnums(self):
    """`donate_argnums` and `static_argnums` can be combined."""
    jit_fun = self.jit(
        lambda a, b, c, d: ((a + b + c), (a + b + d)),
        static_argnums=(0, 1),
        donate_argnums=(2, 3))

    a = jnp.array(1)
    b = jnp.array(2)
    c = jax.device_put(jnp.array([1., 1.]))
    d = jax.device_put(jnp.array([1., 1., 1.]))
    e, f = jit_fun(a, b, c, d)
    np.testing.assert_allclose(e, jnp.array([4., 4.]))
    np.testing.assert_allclose(f, jnp.array([4., 4., 4.]))
    self.assertNotDeleted(a)
    self.assertNotDeleted(b)
    self.assertDeleted(c)
    self.assertDeleted(d)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jnp_array_copy(self):
    """A `copy=True` array stays readable after the original is donated."""
    # https://github.com/google/jax/issues/3412

    @partial(self.jit, donate_argnums=(0,))
    def _test(array):
      return array.at[0].set(77)

    x = jnp.asarray([0, 1])
    x_copy = jnp.array(x, copy=True)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      _test(x)  # donation

    # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
    print(x_copy)  # doesn't crash

  def test_jit_global_cache(self):
    """A second `jit(f)` wrapper does not re-run the Python body of `f`."""

    def f(x):
      assert python_should_be_executing
      return x

    python_should_be_executing = True
    self.jit(f)(2)
    python_should_be_executing = False
    self.jit(f)(3)

  def test_jit_shallow_copy(self):
    """`copy.copy` can be applied to a value inside a jitted function."""

    def f(x):
      return copy.copy(x)

    self.jit(f)(1)

  def test_jit_deep_copy(self):
    """`copy.deepcopy` can be applied to a value inside a jitted function."""

    def f(x):
      return copy.deepcopy(x)

    self.jit(f)(1)

  def test_disable_jit(self):
    """Under `disable_jit` the Python body runs on every call."""
    effects = []

    @self.jit
    def f(x):
      effects.append(1)
      return x

    with api.disable_jit():
      f(2)
      f(2)
    self.assertLen(effects, 2)

    f(2)
    f(2)
    self.assertLen(effects, 3)

  def test_static_argnum_on_method(self):
    """jit with `static_argnums=(0,)` can decorate a method."""

    class A:

      @functools.partial(self.jit, static_argnums=(0,))
      def my_func_jit(self, x):
        return x + 2

    A().my_func_jit(3)

  def test_static_argnum_on_static_method_is_not_supported(self):
    """jit with `static_argnums` rejects a `staticmethod` object."""
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit, static_argnums=(0,))
        @staticmethod
        def my_staticmethod_jit(x):
          return x + 2

  def test_classmethod_is_not_supported(self):
    """jit rejects a `classmethod` object."""
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit)
        @classmethod
        def my_classmethod_jit(cls, x):
          return x + 2

  def test_concurrent_jit(self):
    """A jitted function can be called from several threads at once."""

    @self.jit
    def f(x):
      return x + x - 3.

    xs = [np.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [future.result() for future in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x * 2 - 3., y)
class PythonJitTest(CPPJitTest):
  """Runs the shared jit tests, plus Python-only ones, on the Python jit."""

  @property
  def jit(self):
    """Returns the Python jit implementation under test."""
    return jax.api._python_jit

  def test_jit_reference_dropping(self):
    """The captured array is freed once both `f` and its jit wrapper are gone."""
    x = jnp.ones(10)
    f = (lambda x: lambda: x)(x)  # reference to x in f's closure
    g = self.jit(f)
    x = weakref.ref(x)  # no more strong ref to x in this scope
    self.assertIsNotNone(x())  # x is still around
    f()  # f runs
    g()  # g runs
    g()  # g runs a second time
    del f  # delete the raw callable
    self.assertIsNotNone(x())  # x is still around
    g()  # g still runs
    del g  # no more references to x
    self.assertIsNone(x())  # x is gone

  def test_trivial_computations(self):
    """Outputs that are just inputs come back as the same objects."""
    x = jnp.array([1, 2, 3])
    y = self.jit(lambda x: x)(x)
    self.assertIs(x, y)

    z1, z2 = self.jit(lambda x: (x, x))(x)
    self.assertIs(z1, z2)

    x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
    z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
    self.assertIs(z1, x2)
    self.assertIs(z3, x1)
    self.assertEqual(z2, 1)

  def test_jit_bad_input(self):
    """The jit under test raises a TypeError for a string argument."""

    def f(x):
      return x

    # Use `self.jit` (not the module-level `jit`) so the implementation this
    # class selects is the one being exercised.
    with self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
      self.jit(f)("foo")

  def test_jit_on_all_devices(self):
    """One jitted function gives correct results on every local device."""
    # Verifies we can run the same computation on every device present, even
    # if they are, for example, different models of GPU.
    data = np.random.rand(1000).astype(np.float32)
    f = self.jit(jnp.negative)
    for device in jax.local_devices():
      x = device_put(data, device=device)
      np.testing.assert_array_equal(-data, f(x))

  def test_jit_nested_donate_ignored(self):
    """`donate_argnums` on a nested jit does not raise."""
    jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
    a = jax.device_put(jnp.array(1))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   jit_fun(a)

    jit_fun(a)  # doesn't crash
class APITest(jtu.JaxTestCase):
def test_grad_bad_input(self):
  """grad raises a TypeError when called on a string argument."""
  def f(x):
    return x

  with self.assertRaisesRegex(
      TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"):
    grad(f)("foo")
def test_grad_argnums(self):
def f(x, y, z, flag=False):
assert flag
return 1.0 * x + 2.0 * y + 3.0 * z
assert grad(f)(1.0, 1.0, 1.0, flag=True) == 1.0
assert grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == 2.0
assert grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (3.0, 1.0)
def test_value_and_grad_argnums(self):
  """value_and_grad returns (value, grads) for each form of `argnums`."""
  def f(x, y, z, flag=False):
    assert flag
    return 1.0 * x + 2.0 * y + 3.0 * z

  y = f(1.0, 1.0, 1.0, flag=True)
  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O` and report both operands on failure.
  self.assertEqual(
      api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True), (y, 1.0))
  self.assertEqual(
      api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True), (y, 2.0))
  self.assertEqual(
      api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True),
      (y, (3.0, 1.0)))
def test_grad_of_jit(self):
  """grad of a jitted function runs the function's Python body only once."""
  side = []  # Gains one entry each time the Python body of `f` executes.

  @jit
  def f(x):
    side.append(None)
    return x * x

  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O`.
  self.assertEqual(grad(f)(1.0), 2.0)
  self.assertLen(side, 1)
  self.assertEqual(grad(f)(2.0), 4.0)
  self.assertLen(side, 1)
def test_jit_of_grad(self):
  """jit of grad of a jitted function runs the Python body only once."""
  side = []  # Gains one entry each time the Python body of `f` executes.

  @jit
  def f(x):
    side.append(None)
    return x * x

  g = jit(grad(f))
  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O`.
  self.assertEqual(g(1.0), 2.0)
  self.assertLen(side, 1)
  self.assertEqual(g(2.0), 4.0)
  self.assertLen(side, 1)
def test_bad_input(self):
  """Both grad and jit raise a TypeError when called on a string argument."""
  def f(x):
    return x

  pattern = ".* 'foo' of type <.*'str'> is not a valid JAX type"
  with self.assertRaisesRegex(TypeError, pattern):
    grad(f)("foo")
  with self.assertRaisesRegex(TypeError, pattern):
    jit(f)("foo")
def test_grad_tuple_output(self):
  """grad of a tuple-returning function raises a TypeError."""
  def call_grad():
    return grad(lambda x: (x,x))(1.0)

  jtu.check_raises(call_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_grad_unit_output(self):
  """grad of a function returning an empty tuple raises a TypeError."""
  def call_grad():
    return grad(lambda x: ())(np.zeros(3))

  jtu.check_raises(call_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_grad_nonscalar_output(self):
  """grad of a function returning a length-3 array raises a TypeError."""
  def call_grad():
    return grad(lambda x: x)(np.zeros(3))

  jtu.check_raises(call_grad, TypeError,
                   "Gradient only defined for scalar-output functions. ")
def test_unwrapped_numpy(self):
  """Applying raw `np.exp` to a value under grad raises an error."""
  def f(x):
    return np.exp(x)

  self.assertRaisesRegex(
      Exception, "The numpy.ndarray conversion .*",
      lambda: grad(f)(np.zeros(3)))
def test_binop_mismatch(self):
  """Adding shapes (3,) and (4,) raises a TypeError, eagerly and under grad."""
  def f(x, y):
    return x + y

  expected_msg = "add got incompatible shapes for broadcasting: (3,), (4,)."

  def eager_call():
    return f(jnp.zeros(3), jnp.zeros(4))

  def grad_call():
    return grad(f)(np.zeros(3), np.zeros(4))

  jtu.check_raises(eager_call, TypeError, expected_msg)
  jtu.check_raises(grad_call, TypeError, expected_msg)
def test_dot_mismatch(self):
  """jnp.dot on shapes (3,) and (4,) raises a TypeError under grad."""
  def f(x, y):
    return jnp.dot(x, y)

  with self.assertRaisesRegex(
      TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\)."):
    grad(f)(np.zeros(3), np.zeros(4))
def test_abstract_error_message(self):
  """Python casts under jit raise a TypeError that suggests `x.astype`."""
  for castfun in [float, complex, int]:
    def f(x):
      return castfun(x)

    pattern = f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`"
    with self.assertRaisesRegex(TypeError, pattern):
      jit(f)(1.0)
def test_switch_value_jit(self):
  """Branching on a value works under grad but raises under jit."""
  def f(x):
    y = x > 0
    if y:
      return x
    else:
      return -x

  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O`.
  self.assertEqual(grad(f)(1.0), 1.0)
  self.assertEqual(grad(f)(-1.0), -1.0)
  with self.assertRaisesRegex(core.ConcretizationTypeError,
                              "Abstract tracer value"):
    jit(f)(1)
def test_range_err(self):
  """`range(n)` works with a static `n` and raises a TypeError otherwise."""
  def f(x, n):
    for i in range(n):
      x = x + i
    return x

  # unittest assertion (rather than a bare `assert`) keeps this check active
  # under `python -O`.
  self.assertEqual(jit(f, static_argnums=(1,))(0, 5), 10)
  with self.assertRaisesRegex(
      TypeError,
      "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
      "|Abstract value passed to .*)"):
    jit(f)(0, 5)
def test_casts(self):
  """`hex`, `oct` and `int` on a jit-traced value raise a TypeError."""
  pattern = (
      "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
      "|Abstract tracer value encountered where concrete value is expected.*)")
  for castfun in [hex, oct, int]:
    f = lambda x: castfun(x)
    with self.assertRaisesRegex(TypeError, pattern):
      jit(f)(0)
def test_unimplemented_interpreter_rules(self):
  """Each missing rule of a new primitive raises NotImplementedError.

  Rules are registered one at a time, so the order of the checks matters.
  """
  foo_p = Primitive('foo')

  def foo(x):
    return foo_p.bind(x)

  def call_eager():
    return foo(1.0)

  def call_jit():
    return jit(foo)(1.0)

  def call_grad():
    return grad(foo)(1.0)

  # No rules registered at all.
  jtu.check_raises(call_eager, NotImplementedError,
                   "Evaluation rule for 'foo' not implemented")
  jtu.check_raises(call_jit, NotImplementedError,
                   "Abstract evaluation for 'foo' not implemented")
  jtu.check_raises(call_grad, NotImplementedError,
                   "Differentiation rule for 'foo' not implemented")

  # With an abstract-eval rule, jit gets as far as XLA translation.
  foo_p.def_abstract_eval(lambda x: x)
  jtu.check_raises(call_jit, NotImplementedError,
                   "XLA translation rule for primitive 'foo' not found")

  # With an impl and a JVP rule, grad gets as far as transposition.
  foo_p.def_impl(lambda x: x)
  ad.defjvp(foo_p, lambda g, x: foo(g))
  jtu.check_raises(call_grad, NotImplementedError,
                   "Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
def test_device_put_and_get(self):
  """device_put / device_get round-trip an array and a nested container."""
  x = np.arange(12.).reshape((3, 4)).astype("float32")
  dx = api.device_put(x)
  self.assertIsInstance(dx, xla.DeviceArray)
  x2 = api.device_get(dx)
  self.assertIsInstance(x2, np.ndarray)
  # unittest assertions (rather than bare `assert`) keep the value checks
  # active under `python -O`.
  self.assertTrue(np.all(x == x2))

  y = [x, (2 * x, 3 * x)]
  dy = api.device_put(y)
  y2 = api.device_get(dy)
  self.assertIsInstance(y2, list)
  self.assertIsInstance(y2[0], np.ndarray)
  self.assertTrue(np.all(y2[0] == x))
  self.assertIsInstance(y2[1], tuple)
  self.assertIsInstance(y2[1][0], np.ndarray)
  self.assertTrue(np.all(y2[1][0] == 2 * x))
  self.assertIsInstance(y2[1][1], np.ndarray)
  self.assertTrue(np.all(y2[1][1] == 3 * x))
def test_device_get_scalar(self):
  """device_get leaves a Python int in a container as an int."""
  x = np.arange(12.).reshape((3, 4)).astype("float32")
  x = api.device_put(x)
  self.assertIsInstance(x, xla.DeviceArray)
  y = [x, 2]
  y2 = api.device_get(y)
  self.assertIsInstance(y2, list)
  self.assertIsInstance(y2[0], np.ndarray)
  # unittest assertion (rather than a bare `assert`) keeps this check active
  # under `python -O`.
  self.assertTrue(np.all(y2[0] == x))
  self.assertIsInstance(y2[1], int)
  self.assertEqual(y2[1], 2)
@parameterized.parameters([(3,)], [(2, 0)])
def test_device_put_across_devices(self, shape):
  """device_put moves an already-placed array to a second local device."""
  local = api.local_devices()
  if len(local) < 2:
    self.skipTest("this test requires multiple devices")
  first_device, second_device = local[:2]
  data = np.random.randn(*shape).astype(np.float32)

  on_first = api.device_put(data, device=first_device)
  self.assertEqual(on_first.device_buffer.device(), first_device)

  on_second = api.device_put(on_first, device=second_device)
  self.assertEqual(on_second.device_buffer.device(), second_device)
  np.testing.assert_array_equal(data, np.array(on_second))

  # Make sure these don't crash
  api.device_put(on_first)
  api.device_put(on_second)
@jtu.skip_on_devices("cpu")
def test_device_put_across_platforms(self):
  """device_put can place NumPy, device and scalar values on a CPU device."""
  default_device = jax.devices()[0]
  cpu_device = jax.devices("cpu")[0]

  np_arr = np.array([1,2,3])
  scalar = 1
  device_arr = jnp.array([1,2,3])
  # unittest assertion (rather than a bare `assert`) keeps this precondition
  # active under `python -O`.
  self.assertIs(device_arr.device_buffer.device(), default_device)

  for val in [np_arr, device_arr, scalar]:
    x = api.device_put(val, device=cpu_device)
    self.assertEqual(x.device_buffer.device(), cpu_device)
def test_device_put_sharded_array(self):
  """device_put_sharded places one shard on each local device, in order."""
  devices = api.local_devices()
  shards = [np.arange(i, i + 4) for i in range(len(devices))]
  sharded = api.device_put_sharded(shards, devices)
  self.assertEqual(len(sharded.device_buffers), len(devices))
  placed_in_order = all(
      b.device() == d for b, d in zip(sharded.device_buffers, devices))
  self.assertTrue(placed_in_order)
  self.assertAllClose(sharded, jnp.stack(shards))
def test_device_put_sharded_pytree(self):
  """device_put_sharded shards each leaf of a list of tuples."""
  devices = api.local_devices()
  shards = [(i, np.arange(i, i + 4)) for i in range(len(devices))]
  scalars, vectors = api.device_put_sharded(shards, devices)

  self.assertAllClose(scalars, jnp.array([a for a, _ in shards]))
  scalars_in_order = all(
      b.device() == d for b, d in zip(scalars.device_buffers, devices))
  self.assertTrue(scalars_in_order)

  self.assertAllClose(vectors, jnp.vstack([b for _, b in shards]))
  vectors_in_order = all(
      b.device() == d for b, d in zip(vectors.device_buffers, devices))
  self.assertTrue(vectors_in_order)
@jtu.skip_on_devices("tpu")
def test_jacobian(self):
  """jacfwd and jacrev agree with each other and with the known Jacobian."""
  R = np.random.RandomState(0).randn
  A = R(4, 3)
  x = R(3)

  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O`.
  f = lambda x: jnp.dot(A, x)
  self.assertTrue(np.allclose(jacfwd(f)(x), A))
  self.assertTrue(np.allclose(jacrev(f)(x), A))

  f = lambda x: jnp.tanh(jnp.dot(A, x))
  self.assertTrue(np.allclose(jacfwd(f)(x), jacrev(f)(x)))
@jtu.skip_on_devices("tpu")
def test_hessian(self):
  """The Hessian of x^T A x is A + A^T."""
  R = np.random.RandomState(0).randn
  A = R(4, 4)
  x = R(4)

  f = lambda x: jnp.dot(x, jnp.dot(A, x))
  # unittest assertion (rather than a bare `assert`) keeps this check active
  # under `python -O`.
  self.assertTrue(np.allclose(hessian(f)(x), A + A.T))
def test_std_basis(self):
  """`api._std_basis` returns the expected shapes for arrays and pytrees."""
  # unittest assertions (rather than bare `assert`) keep these checks active
  # under `python -O` and report the offending value on failure.
  basis = api._std_basis(jnp.zeros(3))
  self.assertEqual(getattr(basis, "shape", None), (3, 3))
  self.assertTrue(np.allclose(basis, np.eye(3)))

  basis = api._std_basis(jnp.zeros((3, 3)))
  self.assertEqual(getattr(basis, "shape", None), (9, 3, 3))
  self.assertTrue(np.allclose(basis, np.eye(9).reshape(9, 3, 3)))

  # 1 + 3 + 12 = 16 scalar entries across the three leaves.
  basis = api._std_basis([0., (jnp.zeros(3), jnp.zeros((3, 4)))])
  self.assertIsInstance(basis, list)
  self.assertLen(basis, 2)
  self.assertEqual(getattr(basis[0], "shape", None), (16,))
  self.assertIsInstance(basis[1], tuple)
  self.assertLen(basis[1], 2)
  self.assertEqual(getattr(basis[1][0], "shape", None), (16, 3))
  self.assertEqual(getattr(basis[1][1], "shape", None), (16, 3, 4))
@jtu.skip_on_devices("tpu")
def test_jacobian_on_pytrees(self):
  """jacfwd and jacrev accept pytree inputs and outputs."""
  pair = lambda x, y: (x, y)
  for jacfun in [jacfwd, jacrev]:
    # Default argnums (0), then argnums=1, then a tuple of argnums.
    self.assertAllClose(jacfun(pair)(0., 1.), (1., 0.), check_dtypes=False)
    self.assertAllClose(jacfun(pair, 1)(0., 1.), (0., 1.), check_dtypes=False)
    self.assertAllClose(
        jacfun(pair, (0, 1))(0., 1.),
        ((1., 0.),
         (0., 1.),),
        check_dtypes=False)

    # Tuple input, tuple-slice output.
    ans = jacfun(lambda x: x[:2])((1., 2., 3.))
    expected = ((1., 0., 0.),
                (0., 1., 0.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    # Dict output.
    randn = np.random.RandomState(0).randn
    x = randn(2)
    y = randn(3)
    ans = jacfun(lambda x, y: {'x': x, 'xy': jnp.outer(x, y)})(x, y)
    expected = {'x': np.eye(2),
                'xy': np.kron(np.eye(2), y[:, None]).reshape(2, 3, 2)}
    self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_hessian_on_pytrees(self):
  """hessian of an elementwise square over a tuple input is nested tuples."""
  square = lambda x: jnp.array(x)**2
  expected = ((np.array([2., 0.]), np.array([0., 0.])),
              (np.array([0., 0.]), np.array([0., 2.])))
  self.assertAllClose(hessian(square)((1., 2.)), expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_issue1372(self):
  """Every jacrev/jacfwd composition yields the expected second-derivative shape."""
  def quad(x):
    return jnp.dot(x, x)

  def f(x, u):
    return quad(x) + quad(u)

  x, u = jnp.ones(5), jnp.ones(2)

  rev = jacrev
  fwd = jacfwd

  # Diagonal entries
  for argnum, shape in ((0, (5, 5)), (1, (2, 2))):
    for outer, inner in ((rev, rev), (rev, fwd), (fwd, rev), (fwd, fwd)):
      self.assertEqual(outer(inner(f, argnum), argnum)(x, u).shape, shape)

  # Off-diagonal entries: reverse-mode on the outside first, then
  # forward-mode on the outside.
  for outer in (rev, fwd):
    for inner_argnum, outer_argnum, shape in ((1, 0, (2, 5)), (0, 1, (5, 2))):
      for inner in (rev, fwd):
        self.assertEqual(
            outer(inner(f, inner_argnum), outer_argnum)(x, u).shape, shape)
def test_large_device_constant(self):
  """A jitted function handles a two-million-element array."""
  size = int(2e6)
  double = jit(lambda x: 2 * x)
  ans = double(jnp.ones(size))  # doesn't crash
  self.assertAllClose(ans, np.ones(size) * 2., check_dtypes=False)
def test_grad_and_aux_basic(self):
  """With `has_aux=True`, grad returns the gradient plus the auxiliary output."""
  cube_with_aux = lambda x: (x**3, [x**2])
  g, aux = grad(cube_with_aux, has_aux=True)(3.)
  expected_g = grad(lambda x: x**3)(3.)
  self.assertAllClose(g, expected_g)
  self.assertAllClose(aux, [9.], check_dtypes=False)
def test_grad_and_aux_nested(self):
  """The aux output of an inner grad can itself be differentiated."""
  def aux_only(x):
    _, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
    return aux[0]

  cube = lambda x: x**3

  self.assertEqual(grad(aux_only)(4.), grad(cube)(4.))
  self.assertEqual(jit(grad(aux_only))(4.), grad(cube)(4.))
  self.assertEqual(jit(grad(jit(aux_only)))(4.), grad(cube)(4.))

  def aux_times_sin(x):
    _, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
    return aux[0] * jnp.sin(x)

  cube_times_sin = lambda x: x**3 * jnp.sin(x)

  self.assertEqual(grad(aux_times_sin)(4.), grad(cube_times_sin)(4.))
  self.assertEqual(jit(grad(aux_times_sin))(4.), grad(cube_times_sin)(4.))
  self.assertEqual(jit(grad(jit(aux_times_sin)))(4.), grad(cube_times_sin)(4.))
def test_grad_and_aux_constant(self):
  """Constants in the aux output come back unchanged."""
  expected_g = grad(lambda x: x**3)(4.)

  g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
  self.assertEqual(g, expected_g)
  self.assertEqual(aux, [4.])

  g, aux = grad(lambda x: (x**3, [x**2, 4.]), has_aux=True)(4.)
  self.assertEqual(g, grad(lambda x: x**3)(4.))
  self.assertEqual(aux, [4.**2, 4.])
def test_grad_and_aux_no_tracers(self):
  """The aux output of grad contains no `core.Tracer` values."""
  # see https://github.com/google/jax/issues/1950
  def f(x):
    return x ** 2, dict(identity=x, p1=x+1)

  _, aux = jax.grad(f, has_aux=True)(3.)
  self.assertIsInstance(aux, dict)
  for leaf in aux.values():
    self.assertNotIsInstance(leaf, core.Tracer)
def test_jvp_mismatched_arguments(self):
  """jvp raises a TypeError when primals and tangents do not line up."""
  structure_msg = ("primal and tangent arguments to jax.jvp must have the same tree "
                   "structure")
  mul = lambda x, y: x * y

  with self.assertRaisesRegex(TypeError, structure_msg):
    api.jvp(mul, (np.float32(2),), ())

  # If primals and tangents must both be tuples or both lists
  with self.assertRaisesRegex(TypeError, structure_msg):
    api.jvp(mul, (np.float32(2),), [np.float32(2)])

  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must have equal types"):
    api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),))
def test_jvp_non_tuple_arguments(self):
  """jvp raises a TypeError when primals or tangents are not a tuple or list."""
  def f(x, y):
    return x + y

  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple."):
    api.jvp(f, 0., (1.,))

  with self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray."):
    api.jvp(f, (0.,), np.array([1., 2.]))
def test_vjp_mismatched_arguments(self):
  """The vjp pullback raises a TypeError for a mismatched cotangent."""
  _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4))

  # Wrong tree structure: a pair where a single value is expected.
  with self.assertRaisesRegex(
      TypeError, "Tree structure of cotangent input.*does not match"):
    pullback((np.float32(7), np.float32(100)))

  # Wrong dtype: float16 cotangent.
  with self.assertRaisesRegex(
      TypeError,
      "Type of cotangent input to vjp pullback.*does not match type"):
    pullback((np.float16(42)))
def test_jvp_jit_cached(self):
  """Bug in caching in presence of JVP and JIT."""

  def func(x):
    def inner(y):
      return y * x

    # Must have two calls to the inner jit (the second one hits the cache)
    first = api.jit(inner)(4.)
    second = api.jit(inner)(5.)
    return first + second

  primal_and_tangent = api.jvp(func, (5.,), (1.,))
  self.assertAllClose((45., 9.), primal_and_tangent)
def test_linear_transpose_abstract(self):
  """linear_transpose accepts an object that only has `shape` and `dtype`."""
  abstract_x = types.SimpleNamespace(shape=(3,), dtype=np.float32)
  cotangent = jnp.arange(3, dtype=np.float32)
  transpose_fun = api.linear_transpose(lambda x: 2 * x, abstract_x)
  (result,) = transpose_fun(cotangent)
  self.assertArraysEqual(2 * cotangent, result, check_dtypes=True)
def test_linear_transpose_error(self):
  """linear_transpose raises a TypeError for bad inputs and cotangents."""
  with self.assertRaisesRegex(
      TypeError, "linear_transpose only supports float and complex inputs"):
    api.linear_transpose(lambda x: x, 1)

  # (function to transpose, cotangent passed in, expected error pattern)
  cases = [
      (lambda x: [x, x], 1.0, "cotangent tree does not match"),
      (lambda x: jnp.stack([x, x]), 1.0, "cotangent type does not match"),
      (lambda x: 1j * x, 1.0, "cotangent type does not match"),
      (lambda x: x, 1j, "cotangent type does not match"),
  ]
  for fun, cotangent, pattern in cases:
    transpose_fun = api.linear_transpose(fun, 1.0)
    with self.assertRaisesRegex(TypeError, pattern):
      transpose_fun(cotangent)
def test_linear_transpose_complex(self):
  """Transposing multiplication by (1 + 2j) maps 3 + 4j to -5 + 10j."""
  scale = lambda x: (1 + 2j) * x
  transpose = api.linear_transpose(scale, 1j)
  (actual,) = transpose(3 + 4j)
  self.assertEqual(actual, -5 + 10j)
def test_complex_grad_raises_error(self):
  """grad without `holomorphic=True` raises a TypeError on a complex input."""
  with self.assertRaises(TypeError):
    grad(lambda x: jnp.sin(x))(1 + 2j)
def test_holomorphic_grad(self):
  """grad with `holomorphic=True` differentiates sin at a complex point."""
  dsin = grad(lambda x: jnp.sin(x), holomorphic=True)
  expected = 2.0327230070196656 - 3.0518977991518j
  self.assertAllClose(dsin(1 + 2j), expected, check_dtypes=False)
def test_nonholomorphic_grad(self):
  """grad of a real-valued function of complex inputs matches stored values."""
  zs = 0.5j * np.arange(5) + np.arange(5)

  def f(z):
    return jnp.sum(jnp.cos(jnp.abs(z)))

  expected = np.array([ 0.        +0.j,
                       -0.80430663+0.40215331j,
                       -0.70368982+0.35184491j,
                        0.1886467 -0.09432335j,
                        0.86873727-0.43436864j])
  tol = jtu.default_gradient_tolerance
  self.assertAllClose(grad(f)(zs), expected, check_dtypes=False,
                      atol=tol, rtol=tol)
def test_complex_output_jacrev_raises_error(self):
  """jacrev of sin at a complex point raises a TypeError."""
  with self.assertRaises(TypeError):
    jacrev(lambda x: jnp.sin(x))(1 + 2j)
def test_nonholomorphic_jacrev(self):
  """jacrev agrees with grad on a real-valued function of complex inputs."""
  # code based on https://github.com/google/jax/issues/603
  zs = 0.5j * np.arange(5) + np.arange(5)

  def f(z):
    return jnp.cos(jnp.linalg.norm(2 * z))

  from_jacrev = jacrev(f)(zs)
  from_grad = grad(f)(zs)
  self.assertAllClose(from_jacrev, from_grad)
def test_complex_input_jacfwd_raises_error(self):
  """jacfwd of sin at a complex point raises a TypeError."""
  with self.assertRaises(TypeError):
    jacfwd(lambda x: jnp.sin(x))(1 + 2j)
def test_legacy_devicearray_repr(self):
  """`str()` of a device scalar's `.item()` does not raise."""
  item = device_put(3.).item()
  str(item)  # doesn't crash
def test_devicearray_repr(self):
  """`repr()` works on real and complex DeviceArrays."""
  real_and_complex = (jnp.zeros(3), jnp.ones(3) + 1j * jnp.ones(3))
  for value in real_and_complex:
    x = device_put(value)
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash
def test_devicearray_delete(self):
  """`repr()` of a deleted DeviceArray raises a ValueError."""
  x = device_put(1.)
  x.delete()
  with self.assertRaisesRegex(ValueError, "DeviceArray has been deleted."):
    repr(x)
def test_devicearray_block_until_ready(self):
  """`block_until_ready()` returns the array it was called on."""
  x = device_put(1.)
  # Tests mostly that block_until_ready() does not produce an error.
  self.assertTrue(x.block_until_ready() is x)
def test_namedtuple_transparency(self):
  """A namedtuple is accepted by grad and jit, and grad returns the same type."""
  # See https://github.com/google/jax/issues/446
  Point = collections.namedtuple("Point", ["x", "y"])

  def f(pt):
    return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

  pt = Point(1., 2.)

  f(pt)  # doesn't crash
  self.assertIsInstance(api.grad(f)(pt), Point)

  f_jit = api.jit(f)
  self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)
def test_namedtuple_subclass_transparency(self):
  """grad accepts an instance of a namedtuple subclass."""
  # See https://github.com/google/jax/issues/806
  Point = collections.namedtuple("Point", ["x", "y"])

  class ZeroPoint(Point):
    def is_zero(self):
      return (self.x == 0) and (self.y == 0)

  pt = ZeroPoint(0., 0.)

  def f(pt):
    if pt.is_zero():
      return 0.
    return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

  f(pt)  # doesn't crash
  _ = api.grad(f)(pt)
  # NOTE(review): this checks the *input* `pt`, not the gradient discarded
  # above; confirm whether the gradient's type was the intended subject.
  self.assertIsInstance(pt, ZeroPoint)
@parameterized.parameters(1, 2, 3)
def test_shape_dtype_struct(self, i):
  """ShapeDtypeStruct exposes shape, dtype, ndim, size, len, str and repr."""
  shape = (i, 2, 3)
  s = api.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)
  self.assertEqual(s.shape, shape)
  self.assertEqual(s.dtype, jnp.float32)
  self.assertEqual(s.ndim, 3)
  self.assertEqual(s.size, i * 2 * 3)
  self.assertLen(s, i)
  expected_text = "ShapeDtypeStruct(shape=({}, 2, 3), dtype=float32)".format(i)
  for to_text in (str, repr):
    self.assertEqual(to_text(s), expected_text)
def test_shape_dtype_struct_scalar(self):
  """A scalar ShapeDtypeStruct has size 1, ndim 0 and no `len()`."""
  scalar = api.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
  self.assertEmpty(scalar.shape)
  self.assertEqual(scalar.size, 1)
  self.assertEqual(scalar.ndim, 0)
  self.assertRaisesRegex(
      TypeError, "len[(][)] of unsized object", lambda: len(scalar))
def test_eval_shape(self):
  """eval_shape reports the output shape of a matmul-based function."""
  def fun(x, y):
    return jnp.tanh(jnp.dot(x, y) + 3.)

  lhs = jnp.ones((2, 3))
  rhs = jnp.ones((3, 4))
  self.assertEqual(api.eval_shape(fun, lhs, rhs).shape, (2, 4))
def test_eval_shape_constants(self):
  """eval_shape works on a function that takes no arguments."""
  def fun():
    lhs = jnp.ones((2, 3))
    rhs = jnp.ones((3, 4))
    return jnp.tanh(jnp.dot(lhs, rhs) + 3.)

  self.assertEqual(api.eval_shape(fun).shape, (2, 4))
def test_eval_shape_tuple_unpacking(self):
  """eval_shape accepts a tuple argument that the function unpacks."""
  def fun(x, y):
    a, b = x
    return a + b + y

  pair = (jnp.ones(2), jnp.ones(2))
  self.assertEqual(api.eval_shape(fun, pair, 3.).shape, (2,))
def test_eval_shape_tuple_itemgetting(self):
  """eval_shape accepts a tuple argument that the function indexes."""
  def fun(x, y):
    return x[0] + x[1] + y

  pair = (jnp.ones(2), jnp.ones(2))
  self.assertEqual(api.eval_shape(fun, pair, 3.).shape, (2,))
def test_eval_shape_output_dict(self):
  """eval_shape preserves a dict output structure."""
  def fun(x, y):
    return {'hi': x[0] + x[1] + y}

  pair = (jnp.ones(2), jnp.ones(2))
  result = api.eval_shape(fun, pair, 3.)
  shapes = tree_util.tree_map(np.shape, result)
  self.assertEqual(shapes, {'hi': (2,)})
def test_eval_shape_shape_error(self):
  """eval_shape raises a TypeError for incompatible operand shapes."""
  def fun(x, y):
    return jnp.tanh(jnp.dot(x, y) + 3.)

  lhs = jnp.ones((3, 3))
  rhs = jnp.ones((4, 4))
  with self.assertRaises(TypeError):
    api.eval_shape(fun, lhs, rhs)
def test_eval_shape_duck_typing(self):
  """eval_shape accepts arguments that only carry `shape` and `dtype`."""
  def fun(A, b, x):
    return jnp.dot(A, x) + b

  class MyArgArray(object):
    def __init__(self, shape, dtype):
      self.shape = shape
      self.dtype = dtype

  matrix = MyArgArray((3, 4), jnp.float32)
  bias = MyArgArray((5,), jnp.float32)
  operand = MyArgArray((4, 5), jnp.float32)
  self.assertEqual(api.eval_shape(fun, matrix, bias, operand).shape, (3, 5))
def test_issue_871(self):
  """A linearized function raises ValueError for a mismatched tangent."""
  T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
  x = jnp.array([1, 2, 3])
  expected_msg = ("linearized function called on tangent values "
                  "inconsistent with the original primal values.")

  # Plain function first, then its jitted version.
  _, f_jvp = api.linearize(jnp.sum, x)
  jtu.check_raises(lambda: f_jvp(T), ValueError, expected_msg)

  _, f_jvp = api.linearize(api.jit(jnp.sum), x)
  jtu.check_raises(lambda: f_jvp(T), ValueError, expected_msg)
def test_partial_eval_lower(self):
  """Always skipped; the body below the skip is kept only for reference."""
  # this is a simplified model of a bug that arose when we first used @jit in
  # a jvp rule. it's in this file because we want to use make_jaxpr.

  # NOTE(mattjj): I no longer understand what this was meant to test. My guess
  # is it was related to staging out the broadcast into a jaxpr to be
  # transposed, but after #1749 that's no longer a problem. After changing
  # make_jaxpr (and jit) to stage out sub-calls fully, this test started to
  # fail; I left it in as skipped because deleting tests feels wrong.
  self.skipTest("obsolete test")

  @api.jit
  def f(a, b, c):
    a = lax.broadcast(a, (2,))
    return lax.select(a, b, c)

  a = np.ones((3, 3), dtype=np.bool_)
  b = np.ones((2, 3, 3))
  c = np.ones((2, 3, 3))

  jaxpr = api.make_jaxpr(lambda b, c: f(a, b, c))(b, c)
  call_jaxprs = (eqn.params["call_jaxpr"] for eqn in jaxpr.jaxpr.eqns
                 if "call_jaxpr" in eqn.params)
  subjaxpr = next(call_jaxprs)
  self.assertEqual(len(subjaxpr.eqns), 1)
def test_grad_of_int_errors(self):
  # Calling grad's result on the Python int 3 must raise a TypeError whose
  # message matches the pattern below.
  dfn = grad(lambda x: x ** 2)
  expected_msg = (
      r"grad requires real- or complex-valued inputs \(input dtype that is a "
      r"sub-dtype of np.floating or np.complexfloating\), but got int.*.")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3)
def test_grad_complex_result_errors(self):
  # The differentiated function adds 1j to its result; grad must reject the
  # complex output with a TypeError matching the pattern below.
  dfn = grad(lambda x: x ** 2 + 1j)
  expected_msg = (
      r"grad requires real-valued outputs \(output dtype that is a "
      r"sub-dtype of np.floating\), but got complex.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3.)
def test_holomorphic_grad_of_float_errors(self):
  # grad(..., holomorphic=True) applied to the float 3. must raise a
  # TypeError matching the pattern below.
  dfn = grad(lambda x: x ** 2, holomorphic=True)
  expected_msg = (
      r"grad with holomorphic=True requires inputs with complex dtype, "
      r"but got float.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3.)
def test_holomorphic_jacrev_of_float_errors(self):
  # jacrev(..., holomorphic=True) applied to the float 3. must raise a
  # TypeError matching the pattern below.
  dfn = jacrev(lambda x: x ** 2, holomorphic=True)
  expected_msg = (
      r"jacrev with holomorphic=True requires inputs with complex dtype, "
      r"but got float.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3.)
def test_holomorphic_jacfwd_of_float_errors(self):
  # jacfwd(..., holomorphic=True) applied to the float 3. must raise a
  # TypeError matching the pattern below.
  dfn = jacfwd(lambda x: x ** 2, holomorphic=True)
  expected_msg = (
      r"jacfwd with holomorphic=True requires inputs with complex dtype, "
      r"but got float.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3.)
def test_jacfwd_of_complex_errors(self):
  # jacfwd without holomorphic=True applied to the complex value 3. + 1j
  # must raise a TypeError matching the pattern below.
  dfn = jacfwd(lambda x: x ** 2)
  expected_msg = (
      r"jacfwd requires real-valued inputs \(input dtype that is a "
      r"sub-dtype of np.floating\), but got complex.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    dfn(3. + 1j)
def test_xla_computation(self):
  # these tests basically check the examples in the xla_computation docstring

  # Case 1: no axis environment; both elementwise ops appear in the HLO text.
  def sin_of_cos(x):
    return jnp.sin(jnp.cos(x))

  hlo = api.xla_computation(sin_of_cos)(2.).as_hlo_text()
  self.assertIn('cosine', hlo)
  self.assertIn('sine', hlo)

  # Case 2: a psum over a single named axis 'i' of size 4.
  def minus_psum(x):
    return x - lax.psum(x, 'i')

  hlo = api.xla_computation(minus_psum, axis_env=[('i', 4)])(2).as_hlo_text()
  self.assertIn('all-reduce', hlo)
  self.assertIn('replica_groups={{0,1,2,3}}', hlo)

  # Case 3: psums over 'i', over 'j', and over both, with sizes 4 and 2.
  def row_col_all_sums(x):
    rowsum = lax.psum(x, 'i')
    colsum = lax.psum(x, 'j')
    allsum = lax.psum(x, ('i', 'j'))
    return rowsum, colsum, allsum

  two_axis_env = [('i', 4), ('j', 2)]
  hlo = api.xla_computation(
      row_col_all_sums, axis_env=two_axis_env)(5.).as_hlo_text()
  for needle in ('all-reduce',
                 'replica_groups={{0,2,4,6},{1,3,5,7}}',
                 'replica_groups={{0,1},{2,3},{4,5},{6,7}}',
                 'replica_groups={{0,1,2,3,4,5,6,7}}'):
    self.assertIn(needle, hlo)

  # Case 4: same axes, but the psum over 'i' uses explicit axis_index_groups.
  def grouped_row_and_col_sums(x):
    rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
    colsum = lax.psum(x, 'j')
    return rowsum, colsum

  hlo = api.xla_computation(
      grouped_row_and_col_sums, axis_env=two_axis_env)(5.).as_hlo_text()
  for needle in ('all-reduce',
                 'replica_groups={{0,2},{4,6},{1,3},{5,7}}',
                 'replica_groups={{0,1},{2,3},{4,5},{6,7}}'):
    self.assertIn(needle, hlo)
def test_xla_computation_args(self):
  # By default each of the three arguments is its own XLA parameter; with
  # tuple_args=True there is a single parameter of TUPLE element type.
  def foo(x, y, z):
    return x + y + z

  untupled = api.xla_computation(foo)(1., 2., 3.)
  self.assertEqual(len(untupled.program_shape().parameter_shapes()), 3)

  tupled = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
  param_shapes = tupled.program_shape().parameter_shapes()
  self.assertEqual(len(param_shapes), 1)
  only_param = param_shapes[0]
  self.assertEqual(only_param.xla_element_type(),
                   xb.xla_client.PrimitiveType.TUPLE)
def test_xla_computation_duck_typing(self):
  # xla_computation is given jax.ShapeDtypeStruct placeholders (scalar
  # float32) in place of concrete values, both with the default argument
  # passing and with tuple_args=True.
  def foo(x, y, z):
    return x + y + z

  x = jax.ShapeDtypeStruct((), np.float32)
  y = jax.ShapeDtypeStruct((), np.float32)
  z = jax.ShapeDtypeStruct((), np.float32)

  c = api.xla_computation(foo)(x, y, z)
  self.assertEqual(len(c.program_shape().parameter_shapes()), 3)

  # Pass the duck-typed placeholders here as well; concrete floats would only
  # repeat what test_xla_computation_args already covers.
  c = api.xla_computation(foo, tuple_args=True)(x, y, z)
  param_shapes = c.program_shape().parameter_shapes()
  self.assertEqual(len(param_shapes), 1)
  self.assertEqual(param_shapes[0].xla_element_type(),
                   xb.xla_client.PrimitiveType.TUPLE)
def test_staging_out_multi_replica(self):
  # Smoke test: building the HLO text for a function that calls pmap
  # internally on an 8-element input must not raise.
  def pmapped_mean(x):
    return api.pmap(jnp.mean)(x)

  make_computation = api.xla_computation(pmapped_mean)
  computation = make_computation(jnp.arange(8))
  computation.as_hlo_text()  # doesn't crash
def test_xla_computation_instantiate_constant_outputs(self):
  # The zero-argument function returns a (3, 4) array of zeros; the computed
  # program's single result must have those dimensions.
  def f():
    return jnp.zeros((3, 4))

  # instantiate_const_outputs is only passed when omnistaging is disabled.
  if config.omnistaging_enabled:
    extra_kwargs = {}
  else:
    extra_kwargs = dict(instantiate_const_outputs=True)
  xla_comp = api.xla_computation(f, **extra_kwargs)()

  result_shapes = xla_comp.program_shape().result_shape().tuple_shapes()
  out_shape, = result_shapes
  self.assertEqual(out_shape.dimensions(), (3, 4))
def test_xla_computation_static_argnums(self):
  # Argument 1 is marked static, so its value 3 is expected to show up as a
  # constant in the HLO text.
  def add(x, y):
    return x + y

  make_computation = api.xla_computation(add, static_argnums=(1,))
  hlo_text = make_computation(2, 3).as_hlo_text()
  self.assertIn("constant(3)", hlo_text)
  # The static arguments should be removed from the function being compiled,
  # thus the function should have only a single argument.
  self.assertIn("parameter.1", hlo_text)
  self.assertNotIn("parameter.2", hlo_text)
def test_xla_computation_return_shape(self):
  # With return_shape=True the second returned value is a pytree of
  # ShapeDtypeStructs mirroring the function's outputs.
  def fun(x):
    return x + 1, jnp.zeros(2, jnp.float32)

  make_computation = api.xla_computation(fun, return_shape=True)
  _, shape_tree = make_computation(np.int32(1))

  scalar_int = api.ShapeDtypeStruct(shape=(), dtype=jnp.int32)
  float_pair = api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)
  self.assertEqual(shape_tree, (scalar_int, float_pair))
def test_xla_computation_partitioned(self):
  # in_parts / out_parts given to xla_computation must be reflected as
  # sharding annotations in the HLO text.
  def f(x, y):
    return jnp.dot(x, y) + 1

  x = jax.ShapeDtypeStruct((8, 8), np.float32)
  y = jax.ShapeDtypeStruct((8, 16), np.float32)
  make_computation = api.xla_computation(
      f, in_parts=(P(2, 2), None), out_parts=P(4, 1))
  hlo_text = make_computation(x, y).as_hlo_text()

  for needle in ('sharding={devices=[2,2]0,1,2,3}',
                 'sharding={replicated}',
                 'sharding={{devices=[4,1]0,1,2,3}}'):
    self.assertIn(needle, hlo_text)
def test_xla_computation_replicated_and_partitioned(self):
  # Combines a named axis 'i' of size 4 (for the psum) with in_parts and
  # out_parts; both the all-reduce and the sharding annotations must appear
  # in the HLO text.
  def f(x, y):
    return jnp.dot(x, y), lax.psum(x, 'i')

  x = jax.ShapeDtypeStruct((8, 8), np.float32)
  y = jax.ShapeDtypeStruct((8, 16), np.float32)
  make_computation = api.xla_computation(
      f,
      axis_env=[('i', 4)],
      in_parts=(P(2, 2), None),
      out_parts=(P(4, 1), None))
  hlo_text = make_computation(x, y).as_hlo_text()

  for needle in ('all-reduce',
                 'replica_groups={{0,1,2,3}}',
                 'sharding={devices=[2,2]0,1,2,3}',
                 'sharding={replicated}',
                 'sharding={{devices=[4,1]0,1,2,3}, {replicated}}'):
    self.assertIn(needle, hlo_text)
def test_xla_computation_psum_constant(self):
  # Smoke test (omnistaging only): psum of the constant 1 over axis "i" can
  # be staged out by xla_computation without raising.
  if not config.omnistaging_enabled:
    raise unittest.SkipTest("test requires omnistaging")

  def psum_of_one():
    return jax.lax.psum(1, "i")

  make_computation = api.xla_computation(psum_of_one, axis_env=[("i", 2)])
  make_computation()  # doesn't crash
@jtu.skip_on_devices("cpu", "gpu")
def test_xla_computation_donate_argnums(self):
  # Smoke test: xla_computation accepts donate_argnums and can be called.
  make_computation = api.xla_computation(lambda x: None, donate_argnums=(0,))
  make_computation(3)  # doesn't crash
def test_concurrent_device_get_and_put(self):
  # Ten worker tasks each bounce an array through device_put / device_get
  # 100 times; every result must still match its input.
  def round_trip(x):
    for _ in range(100):
      on_device = jax.device_put(x)
      x = jax.device_get(on_device)
    return x

  xs = [np.random.randn(i) for i in range(10)]
  with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(round_trip, x) for x in xs]
    ys = [future.result() for future in futures]

  for x, y in zip(xs, ys):
    self.assertAllClose(x, y)
def test_dtype_warning(self):
  # cf. issue #1230
  # With x64 disabled, each call that explicitly asks for a 64-bit dtype must
  # emit a warning starting with "Explicitly requested dtype ", while the
  # paired 32-bit / default call must emit none.
  if FLAGS.jax_enable_x64:
    # Report a skip (as the other conditional tests in this file do) rather
    # than silently passing.
    raise unittest.SkipTest("test only applies when x64 is disabled")

  def check_warning(warn, nowarn):
    # `warn` and `nowarn` are zero-argument callables. unittest assertions
    # are used instead of bare `assert` so the checks survive `python -O`.
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")

      nowarn()  # get rid of extra startup warning

      prev_len = len(w)
      nowarn()
      self.assertLen(w, prev_len)

      warn()
      # The warning must have been added by this very call.
      self.assertGreater(len(w), prev_len)
      msg = str(w[-1].message)
      expected_prefix = "Explicitly requested dtype "
      self.assertEqual(expected_prefix, msg[:len(expected_prefix)])

      prev_len = len(w)
      nowarn()
      self.assertLen(w, prev_len)

  check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
                lambda: jnp.array([1, 2, 3], dtype="float32"),)
  check_warning(lambda: jnp.ones(3, dtype=np.float64),
                lambda: jnp.ones(3))
  check_warning(lambda: jnp.ones_like(3, dtype=np.int64),
                lambda: jnp.ones_like(3, dtype=np.int32))
  check_warning(lambda: jnp.zeros(3, dtype="int64"),
                lambda: jnp.zeros(3, dtype="int32"))
  check_warning(lambda: jnp.zeros_like(3, dtype="float64"),
                lambda: jnp.zeros_like(3, dtype="float32"))
  check_warning(lambda: jnp.full((2, 3), 1, dtype="int64"),
                lambda: jnp.full((2, 3), 1))
  check_warning(lambda: jnp.ones(3).astype("float64"),
                lambda: jnp.ones(3).astype("float32"))
  check_warning(lambda: jnp.eye(3, dtype=np.float64),
                lambda: jnp.eye(3))
  check_warning(lambda: jnp.arange(3, dtype=np.float64),
                lambda: jnp.arange(3, dtype=np.float32))
  check_warning(lambda: jnp.linspace(0, 3, dtype=np.float64),
                lambda: jnp.linspace(0, 3, dtype=np.float32))
  check_warning(lambda: jnp.tri(2, dtype="float64"),
                lambda: jnp.tri(2, dtype="float32"))
def test_vmap_preserves_docstr(self):
  # The vmapped function's __doc__ must embed the wrapped function's own
  # docstring under an "Original documentation:" heading.
  def superfun(a):
    """Does things with stuff."""
    pass

  expected_doc_pattern = "\n".join([
      "Vectorized version of superfun.*",
      "",
      "Original documentation:",
      "",
      superfun.__doc__,
  ])
  vmapped_doc = api.vmap(superfun).__doc__
  self.assertRegex(vmapped_doc, expected_doc_pattern)
def test_vmap_in_axes_list(self):
  # https://github.com/google/jax/issues/2367
  # in_axes given as a list must give the same result as the equivalent tuple.
  def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y

  dictionary = {'a': 5., 'b': jnp.ones(2)}
  x = jnp.zeros(3)
  y = jnp.arange(3.)

  from_tuple = api.vmap(f, (None, 0, 0))(dictionary, x, y)
  from_list = api.vmap(f, [None, 0, 0])(dictionary, x, y)
  self.assertAllClose(from_tuple, from_list)
def test_vmap_in_axes_tree_prefix_error(self):
  # https://github.com/google/jax/issues/795
  # A two-element in_axes for a one-argument call must raise ValueError with
  # the tree-prefix message below.
  expected_msg = (
      "vmap in_axes specification must be a tree prefix of the corresponding "
      r"value, got specification \(0, 0\) for value tree "
      r"PyTreeDef\(tuple, \[\*\]\).")
  with self.assertRaisesRegex(ValueError, expected_msg):
    api.vmap(lambda x: x, in_axes=(0, 0))(jnp.ones(3))
def test_vmap_in_axes_leaf_types(self):
  # An array used as an in_axes leaf must be rejected with a TypeError.
  bad_axes = (jnp.array([1., 2.]),)
  arg = jnp.array([1., 2.])
  with self.assertRaisesRegex(
      TypeError, r"vmap in_axes must be an int, None, or .*"):
    api.vmap(lambda x: x, in_axes=bad_axes)(arg)
def test_vmap_out_axes_leaf_types(self):
  # An array used as an out_axes leaf must be rejected with a TypeError.
  bad_axes = (jnp.array([1., 2.]),)
  arg = jnp.array([1., 2.])
  with self.assertRaisesRegex(
      TypeError, r"vmap out_axes must be an int, None, or .*"):
    api.vmap(lambda x: x, out_axes=bad_axes)(arg)
def test_vmap_unbatched_object_passthrough_issue_183(self):
  # https://github.com/google/jax/issues/183
  # The first argument (a Python callable) has in_axes None and is handed to
  # the mapped function as-is.
  def apply(f, x):
    return f(x)

  vmapped_apply = api.vmap(apply, (None, 0))
  ans = vmapped_apply(lambda x: x + 1, jnp.arange(3))
  self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
  # https://github.com/google/jax/issues/705
  # Walks through the ValueErrors vmap raises for inconsistent or invalid
  # in_axes / out_axes, checking the message of each.
  def h(a, b):
    return jnp.sum(a) + jnp.sum(b)

  X = np.random.randn(10, 4)
  U = np.random.randn(10, 2)
  vec = jnp.array([1., 2.])

  # Two arguments whose mapped axes have different sizes.
  two_arg_msg = (
      "vmap got inconsistent sizes for array axes to be mapped:\n"
      r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
      r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
      "so\n"
      "arg 0 has an axis to be mapped of size 10\n"
      "arg 1 has an axis to be mapped of size 2")
  with self.assertRaisesRegex(ValueError, two_arg_msg):
    api.vmap(h, in_axes=(0, 1))(X, U)

  # Three arguments; the two that agree are grouped in the message.
  three_arg_msg = (
      "vmap got inconsistent sizes for array axes to be mapped:\n"
      r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
      r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
      r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
      "so\n"
      "args 0, 2 have axes to be mapped of size 10\n"
      "arg 1 has an axis to be mapped of size 2")
  with self.assertRaisesRegex(ValueError, three_arg_msg):
    api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)

  # A nested (list) argument; the message shows the tree of axis sizes.
  tree_msg = (
      "vmap got inconsistent sizes for array axes to be mapped:\n"
      "the tree of axis sizes is:\n"
      r"\(10, \[2, 2\]\)")
  with self.assertRaisesRegex(ValueError, tree_msg):
    api.vmap(h, in_axes=(0, 1))(X, [U, U])

  # The mapped inputs cannot be scalars
  with self.assertRaisesRegex(
      ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
    api.vmap(lambda x: x)(1.)

  # If the output is mapped, there must be a non-None in_axes
  with self.assertRaisesRegex(
      ValueError, "vmap must have at least one non-None value in in_axes"):
    api.vmap(lambda x: x, in_axes=None)(vec)

  # Axis 1 does not exist on a rank-1 argument.
  with self.assertRaisesRegex(
      ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
    api.vmap(lambda x: x, in_axes=1)(vec)

  # NOTE(review): an earlier comment here read "Error is: TypeError: only
  # integer scalar arrays can be converted to a scalar index"; the check
  # below expects a ValueError instead.
  out_axes_prefix_msg = (
      "vmap out_axes specification must be a tree prefix of the "
      "corresponding value.*")
  with self.assertRaisesRegex(ValueError, out_axes_prefix_msg):
    api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(vec)

  # If the output is mapped, then there must be some out_axes specified
  with self.assertRaisesRegex(
      ValueError, "vmap has mapped output but out_axes is None"):
    api.vmap(lambda x: x, out_axes=None)(vec)
def test_vmap_structured_in_axes(self):
  # in_axes may mirror the argument's pytree structure: here as a tuple, a
  # namedtuple, a dict and an OrderedDict, each with per-leaf batch axes.
  A, B, C, D = 2, 3, 4, 5
  K = 6  # batch size
  x = np.ones((K, A, B))  # batch axis in different locations
  y = np.ones((B, K, C))
  z = np.ones((C, D, K))
  expected_shape = (K, A, D)

  def foo_seq(tree_arg):
    x, (y, z) = tree_arg
    return jnp.dot(x, jnp.dot(y, z))

  vfoo = api.vmap(foo_seq, in_axes=((0, (1, 2)),))
  self.assertEqual(vfoo((x, (y, z))).shape, expected_shape)

  Point = collections.namedtuple("Point", ["x", "y"])
  vfoo = api.vmap(foo_seq, in_axes=((0, Point(1, 2)),))
  self.assertEqual(vfoo((x, Point(y, z))).shape, expected_shape)

  def foo_map(tree_arg):
    x, dct = tree_arg
    y, z = dct['a'], dct['b']
    return jnp.dot(x, jnp.dot(y, z))

  vfoo = api.vmap(foo_map, in_axes=((0, {'a': 1, 'b': 2}),))
  self.assertEqual(vfoo((x, {'a': y, 'b': z})).shape, expected_shape)

  ordered_tree = (x, collections.OrderedDict([('a', y), ('b', z)]))
  ordered_axes = (0, collections.OrderedDict([('a', 1), ('b', 2)]))
  vfoo = api.vmap(foo_map, in_axes=(ordered_axes,))
  self.assertEqual(vfoo(ordered_tree).shape, expected_shape)
def test_pmap_global_cache(self):
  # For each way of calling pmap (without and with the extra 'i' argument),
  # the Python body of `f` may run on the first call but must not run on the
  # second; the flag read inside `f` enforces this.
  def f(x):
    assert python_should_be_executing
    return x

  x = np.ones(1)

  for extra_args in ((), ('i',)):
    python_should_be_executing = True
    api.pmap(f, *extra_args)(x)
    python_should_be_executing = False
    api.pmap(f, *extra_args)(x)
def test_device_array_repr(self):
  # The repr of the result of a jnp computation starts with 'DeviceArray'.
  value = jnp.ones(()) + 1.
  self.assertStartsWith(repr(value), 'DeviceArray')
def test_grad_without_enough_args_error_message(self):
  # https://github.com/google/jax/issues/1696
  # Both arguments are supplied by keyword, so no positional argument exists
  # for argnums=0 to refer to.
  def f(x, y):
    return x + y

  df = api.grad(f, argnums=0)
  expected_msg = (
      "differentiating with respect to argnums=0 requires at least 1 "
      "positional arguments to be passed by the caller, but got only 0 "
      "positional arguments.")
  with self.assertRaisesRegex(TypeError, expected_msg):
    df(x=0., y=1.)
def test_grad_of_jit_compilation_caching(self):
  # Two grad-of-jit calls at different input values must together produce
  # exactly two DEBUG-level log records.
  if not hasattr(self, "assertLogs"):
    raise unittest.SkipTest("test requires assertLogs (python 3)")

  lax.add(1, 2)  # make sure some initial warnings are already printed

  sin = api.jit(jnp.sin)

  saved_verbosity = logging.get_verbosity()
  try:
    logging.set_verbosity('DEBUG')
    with self.assertLogs(level=logging.DEBUG) as captured:
      ans1 = api.grad(sin)(2.)
      ans2 = api.grad(sin)(3.)
  finally:
    # Restore the verbosity even if the block above raised.
    logging.set_verbosity(saved_verbosity)

  self.assertLen(captured.output, 2)
  for ans, point in ((ans1, 2.), (ans2, 3.)):
    self.assertAllClose(ans, np.cos(point), check_dtypes=False)
def test_trivial_computations(self):
  # Jitted functions that merely forward their inputs must return the very
  # same objects (checked with assertIs), and a Python constant output must
  # compare equal to itself.
  x = jnp.array([1, 2, 3])
  identity_out = api.jit(lambda x: x)(x)
  self.assertIs(x, identity_out)

  dup_a, dup_b = api.jit(lambda x: (x, x))(x)
  self.assertIs(dup_a, dup_b)

  x1 = jnp.array([1, 2])
  x2 = jnp.array([2, 3])
  first, middle, last = api.jit(lambda x, y: (y, 1, x))(x1, x2)
  self.assertIs(first, x2)
  self.assertIs(last, x1)
  self.assertEqual(middle, 1)
def test_nested_jit_hoisting(self):
  # Records every jaxpr passed to xla.jaxpr_subcomp while a jitted function
  # calling another jitted function runs, then checks the structure of the
  # two recorded jaxprs.
  @api.jit
  def f(x, y):
    z = 2 * x
    return y + z, 3

  @api.jit
  def g(x):
    return f(2, x)

  original_subcomp = xla.jaxpr_subcomp
  seen_jaxprs = []

  def recording_subcomp(c, jaxpr, *args, **kwargs):
    seen_jaxprs.append(jaxpr)
    return original_subcomp(c, jaxpr, *args, **kwargs)

  try:
    xla.jaxpr_subcomp = recording_subcomp
    ans = g(3)
  finally:
    # Always undo the monkeypatch.
    xla.jaxpr_subcomp = original_subcomp

  self.assertEqual(ans, (7, 3))
  self.assertLen(seen_jaxprs, 2)
  outer_jaxpr, inner_jaxpr = seen_jaxprs

  # The outer jaxpr consists of a single xla_call wrapping the inner one.
  self.assertLen(outer_jaxpr.eqns, 1)
  outer_eqn = outer_jaxpr.eqns[0]
  self.assertEqual(outer_eqn.primitive.name, 'xla_call')
  self.assertEqual(str(outer_eqn.params["call_jaxpr"]), str(inner_jaxpr))

  # The inner jaxpr is a mul followed by an add.
  self.assertLen(inner_jaxpr.eqns, 2)
  mul_eqn, add_eqn = inner_jaxpr.eqns
  self.assertEqual(mul_eqn.primitive.name, 'mul')
  self.assertEqual(add_eqn.primitive.name, 'add')
def test_primitive_compilation_cache(self):
  # Two lax.add calls on Python ints must trigger exactly one primitive
  # compilation, as counted by jtu.count_primitive_compiles.
  with jtu.count_primitive_compiles() as count:
    for lhs, rhs in ((1, 2), (2, 3)):
      lax.add(lhs, rhs)
  num_compiles = count[0]
  self.assertEqual(num_compiles, 1)
def test_arange_jit(self):
  # see https://github.com/google/jax/issues/553
  # Smoke test: indexing an arange built from the argument's static shape
  # works under jit.
  def fun(x):
    return jnp.arange(x.shape[0])[x]

  indices = jnp.array([0, 1, 2], dtype=jnp.int32)
  jit(fun)(indices)  # doesn't crash
def helper_save_tracer(self, x):
  # Stores `x` on the instance as `_saved_tracer` and returns it unchanged,
  # so the stored value can be read back after the call.
  self._saved_tracer = x
  return x
def test_escaped_tracers_different_top_level_traces(self):
  # A value saved during one jit call is returned from inside a second,
  # separate jit call; that must raise UnexpectedTracerError.
  api.jit(self.helper_save_tracer)(0.)

  def return_saved(x):
    return self._saved_tracer

  with self.assertRaisesRegex(
      core.UnexpectedTracerError, "Encountered an unexpected tracer"):
    api.jit(return_saved)(0.)
def test_escaped_tracers_cant_lift_sublevels(self):
  # A value saved during one jit call is passed as an argument to another
  # jitted function; that must raise UnexpectedTracerError.
  api.jit(self.helper_save_tracer)(0.)
  expected = re.compile("Encountered an unexpected tracer", re.DOTALL)
  with self.assertRaisesRegex(core.UnexpectedTracerError, expected):
    api.jit(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_tracer_from_higher_level(self):
  # A value saved during a grad call is passed as an argument to another
  # grad call; the error must mention "Tracer from a higher level".
  api.grad(self.helper_save_tracer)(0.)
  expected = re.compile(
      "Encountered an unexpected tracer.*Tracer from a higher level",
      re.DOTALL)
  with self.assertRaisesRegex(core.UnexpectedTracerError, expected):
    api.grad(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_incompatible_sublevel(self):
  # Inside a jitted function, a value saved by a nested jit call is combined
  # with the outer argument; that must raise UnexpectedTracerError.
  def func1(x):
    api.jit(self.helper_save_tracer)(0.)
    # Use the tracer
    return x + self._saved_tracer

  expected = re.compile("Encountered an unexpected tracer", re.DOTALL)
  with self.assertRaisesRegex(core.UnexpectedTracerError, expected):
    api.jit(func1)(2.)
def test_escaped_tracers_cant_lift(self):
  # Inside a grad-transformed function, a value saved by a nested grad call
  # is combined with the outer argument; the error must mention "Can't lift".
  def func1(x):
    api.grad(self.helper_save_tracer)(0.)
    return x + self._saved_tracer

  expected = re.compile("Encountered an unexpected tracer.*Can't lift",
                        re.DOTALL)
  with self.assertRaisesRegex(core.UnexpectedTracerError, expected):
    api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self):
  # Inside a jitted function, a value saved by a nested grad call on the
  # outer argument is reused; the error must mention "Tracer not among input
  # tracers".
  def func1(x):
    api.grad(self.helper_save_tracer)(x)
    # Use the tracer
    return x + self._saved_tracer

  expected = re.compile(
      "Encountered an unexpected tracer.*Tracer not among input tracers",
      re.DOTALL)
  with self.assertRaisesRegex(core.UnexpectedTracerError, expected):
    api.jit(func1)(2.)
def test_escaped_tracer_omnistaging(self):
  # A jitted function rebinds the enclosing `count` to the result of a jnp
  # op; using `count` afterwards inside a scan under jit must raise an
  # UnexpectedTracerError that says where the tracer was created.
  if not config.omnistaging_enabled:
    raise unittest.SkipTest("test is omnistaging-specific")

  count = 1

  @jit
  def leak():
    nonlocal count
    count = jnp.add(count, 1)

  leak()  # leaked a tracer! but currently undetected

  def scan_body(x, c):
    jnp.add(count, 1)
    return None, None

  @jit
  def g():
    lax.scan(scan_body, None, None, length=2)

  with self.assertRaisesRegex(core.UnexpectedTracerError,
                              "tracer created on line"):
    g()
def test_pmap_static_kwarg_error_message(self):
  # https://github.com/google/jax/issues/3007
  # Argument 1 is declared static-broadcasted but is then passed by keyword.
  def f(a, b):
    return a + b

  pmapped = jax.pmap(f, static_broadcasted_argnums=(1,))

  expected_msg = (
      r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
      r"called with only 1 positional argument. All static broadcasted "
      r"arguments must be passed positionally.")
  with self.assertRaisesRegex(ValueError, expected_msg):
    pmapped(jnp.ones((1, 1)), b=1)
def test_vmap_unmapped_last(self):
  # Smoke test: out_axes=-1 with an output that does not depend on the
  # mapped input.
  def constant_output(x):
    return np.zeros((2,))

  vmapped = jax.vmap(constant_output, out_axes=-1)
  vmapped(np.zeros((5,)))
def test_xla_constant_dedup(self):
  # The closed-over constant `y` is used twice in `f` but must appear only
  # once (as constant.1, with no constant.2) in the HLO text.
  y = np.array([7, 14], dtype=np.float32)

  def f(x):
    return x + y + y

  x = np.array([1, 2], dtype=np.float32)
  hlo_text = jax.xla_computation(f)(x).as_hlo_text()
  hlo_lines = {line.strip() for line in hlo_text.split('\n')}
  self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
  self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
def test_omnistaging_flag(self):
  # make_jaxpr of a constant addition has one equation when the
  # jax_omnistaging flag is set and none otherwise.
  # omnistaging can be enabled programmatically without setting the flag,
  # but that shouldn't happen in tests
  expected_num_eqns = 1 if FLAGS.jax_omnistaging else 0
  jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
  self.assertLen(jaxpr.jaxpr.eqns, expected_num_eqns)
def test_eval_context(self):
  # Inside core.eval_context(), within a jitted function, the result of
  # jnp.add can be compared and used in a Python assert.
  def body():
    with core.eval_context():
      assert jnp.add(1, 1) == 2

  jitted = jit(body)
  jitted()  # doesn't crash
def test_xla_computation_zeros_doesnt_device_put(self):
  # While xla_computation stages out jnp.zeros(3), xla.device_put is
  # replaced by a counting wrapper; it must never be called.
  if not config.omnistaging_enabled:
    raise unittest.SkipTest("test is omnistaging-specific")

  num_device_puts = 0
  original_device_put = xla.device_put

  def counting_device_put(*args, **kwargs):
    nonlocal num_device_puts
    num_device_puts += 1
    return original_device_put(*args, **kwargs)

  xla.device_put = counting_device_put
  try:
    api.xla_computation(lambda: jnp.zeros(3))()
  finally:
    # Always undo the monkeypatch.
    xla.device_put = original_device_put

  self.assertEqual(num_device_puts, 0)
class RematTest(jtu.JaxTestCase):
  """Tests of api.remat, alone and combined with other transformations."""

  def test_remat_basic(self):
    # Checks primal and tangent values of a remat-wrapped function, then
    # counts sin/cos impl calls made when applying the linearized function.
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      out, _ = g(x)
      return out

    primal_expected = np.sin(np.sin(2.))
    self.assertAllClose(f(2.), primal_expected, check_dtypes=False)

    ans, f_lin = api.linearize(f, 2.)
    self.assertAllClose(ans, primal_expected, check_dtypes=False)

    tangent_expected = np.cos(np.sin(2.)) * np.cos(2.) * 3.
    self.assertAllClose(f_lin(3.), tangent_expected, check_dtypes=False)

    sin_calls = []
    cos_calls = []
    sin_impl = lax.sin_p.impl
    cos_impl = lax.cos_p.impl

    def counting_sin(x):
      sin_calls.append(1)
      return sin_impl(x)

    def counting_cos(x):
      cos_calls.append(1)
      return cos_impl(x)

    try:
      lax.sin_p.def_impl(counting_sin)
      lax.cos_p.def_impl(counting_cos)
      f_lin(3.)
    finally:
      # Restore the original impl rules.
      lax.sin_p.def_impl(sin_impl)
      lax.cos_p.def_impl(cos_impl)
    self.assertEqual(len(sin_calls), 1)
    self.assertEqual(len(cos_calls), 2)

  def test_remat_freevars(self):
    # The remat-wrapped lambda closes over `y`; linearizing must agree with
    # the same function written without remat.
    def reference(x):
      y = 2 * jnp.sin(x)
      return jnp.cos(x) * jnp.sin(y)

    def with_remat(x):
      y = 2 * jnp.sin(x)
      return api.remat(lambda x: jnp.cos(x) * jnp.sin(y))(x)

    ans, f_lin = api.linearize(with_remat, 2.)
    expected, f_lin_expected = api.linearize(reference, 2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    self.assertAllClose(f_lin(3.), f_lin_expected(3.), check_dtypes=False)

  def test_remat_grad_python_control_flow(self):
    # With concrete=True the wrapped function branches on its argument with
    # a Python `if`.
    @partial(api.remat, concrete=True)
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      out, _ = g(x)
      return out

    self.assertAllClose(f(2.), np.sin(2.), check_dtypes=False)
    self.assertAllClose(api.grad(f)(2.), np.cos(2.), check_dtypes=False)

  def test_remat_jit(self):
    # remat under jit, grad-of-jit, and jit-of-grad.
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x))

    def f_(x):
      return g(x)

    f = api.jit(f_)

    self.assertAllClose(f(2.), np.sin(np.sin(2.)), check_dtypes=False)

    grad_expected = np.cos(np.sin(2.)) * np.cos(2.)
    self.assertAllClose(api.grad(f)(2.), grad_expected, check_dtypes=False)
    self.assertAllClose(api.jit(api.grad(f_))(2.), grad_expected,
                        check_dtypes=False)

  def test_remat_vmap(self):
    # remat under vmap, jacfwd and jacrev.
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x))

    x = np.arange(3.)

    self.assertAllClose(api.vmap(g)(x), np.sin(np.sin(x)),
                        check_dtypes=False)

    for jacobian in (api.jacfwd, api.jacrev):
      expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
      self.assertAllClose(jacobian(g)(x), expected, check_dtypes=False)

  def test_remat_higher_order_autodiff(self):
    # Second derivative through remat matches the one without remat.
    def f(x):
      return lax.cos(lax.sin(x))

    g = api.remat(f)

    second_derivative = api.grad(api.grad(g))(3.)
    reference = api.grad(api.grad(f))(3.)
    self.assertAllClose(second_derivative, reference, check_dtypes=False)

  def test_remat_scan(self):
    # A scan whose body is remat-wrapped must match the plain scan in value
    # and gradient, and ' cos ' must appear in the scan's jaxpr after
    # linearize / vjp.
    to_scan = lambda c, x: (jnp.sin(c), None)

    def f_noremat(x):
      y, _ = lax.scan(to_scan, x, np.arange(3.))
      return y

    def f_yesremat(x):
      y, _ = lax.scan(api.remat(to_scan), x, np.arange(3.))
      return y

    self.assertAllClose(f_yesremat(4.), f_noremat(4.), check_dtypes=False)
    self.assertAllClose(api.grad(f_yesremat)(4.), api.grad(f_noremat)(4.),
                        check_dtypes=False)

    for transform in (api.linearize, api.vjp):
      jaxpr = api.make_jaxpr(transform(f_yesremat, 4.)[1])(1.)
      scan_eqn, = jaxpr.jaxpr.eqns
      self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))

  def test_remat_no_redundant_flops(self):
    # see https://github.com/google/jax/pull/1749#issuecomment-558267584
    @api.remat
    def f(x, y):
      return jnp.sin(x) * y

    @api.jit
    def g(x):
      return f(2., x)

    # We swap out sin_p's impl rule to count how many times it's invoked
    called = []
    sin_impl = lax.sin_p.impl

    def counting_sin(x):
      called.append(1)
      return sin_impl(x)

    try:
      lax.sin_p.def_impl(counting_sin)
      api.grad(g)(3.)
    finally:
      lax.sin_p.def_impl(sin_impl)
    self.assertLessEqual(len(called), 1)

  def test_remat_binomial_checkpointing(self):
    # Recursively splits a list of functions in half, wrapping each
    # composition of the two halves in remat; value and gradient must match
    # the plain four-fold sin composition.
    def binom_checkpoint(funs):
      if len(funs) == 1:
        return funs[0]
      mid = len(funs) // 2
      first_half = binom_checkpoint(funs[:mid])
      second_half = binom_checkpoint(funs[mid:])
      return api.remat(lambda x: first_half(second_half(x)))

    f1 = binom_checkpoint([jnp.sin, jnp.sin, jnp.sin, jnp.sin])
    f2 = lambda x: jnp.sin(jnp.sin(jnp.sin(jnp.sin(x))))
    x = 4.
    self.assertAllClose(f1(x), f2(x), check_dtypes=False)
    self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)

  def test_remat_symbolic_zeros(self):
    # code from https://github.com/google/jax/issues/1907
    key = jax.random.PRNGKey(0)
    key, split = jax.random.split(key)
    n = 5

    def func(D0):
      def shift(R, dR, **unused_kwargs):
        return R + dR

      def apply_fn(R):
        return D0 * R

      Rinit = jax.random.uniform(split, (n,3), minval=0.0, maxval=5.0,
                                 dtype=jnp.float32)

      def move(R,i):
        F = apply_fn(R)
        return shift(R, 0.001 * F), jnp.array([0.])

      move = api.remat(move)
      R, temp = lax.scan(move, Rinit, jnp.arange(2))
      return R[0, 0]

    api.grad(func)(5.0)  # doesn't crash

  def test_remat_jit2(self):
    # A zero-argument remat function closing over a value computed inside
    # the surrounding jitted function.
    @api.jit
    def f(x):
      y = 2 * x

      @api.remat
      def g():
        return y

      return g()

    self.assertAllClose(f(3), 6, check_dtypes=False)

  def test_remat_nontrivial_env(self):
    # simplified from https://github.com/google/jax/issues/2030

    @api.remat
    def foo(state, dt=0.5, c=1):
      u, u_t = state
      u_tt = c**2 * u
      u_t = u_t + u_tt * dt
      return (u, u_t)

    @partial(api.jit, static_argnums=(1,))
    def _multi_step(state, count, dt, c):
      f = lambda s, _: (foo(s, dt, c), _)
      return lax.scan(f, state, None, count)

    def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):
      return _multi_step(state, count, dt, c)

    def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1):
      init = (u0, jnp.zeros_like(u0))
      (uf, _), _ = multi_step(init, steps, dt, c)
      return ((uf - target) ** 2).mean()

    target = jnp.zeros((128, 128))
    u0 = jnp.ones_like(target)
    loss(u0, target, 10)  # doesn't crash

  def test_remat_jit3(self):
    # https://github.com/google/jax/issues/2180
    def f_einsum(w, x):
      a = jnp.dot(x, w)
      b = jnp.einsum("btd,bTd->btT", a, a)
      c = jnp.einsum("btT,btd->btd", b, a)
      return jnp.sum(c)

    w = jnp.ones([1, 1])
    x = jnp.ones([1, 1, 1])
    api.grad(api.remat(f_einsum))(w, x)  # doesn't crash

    # Same pattern with a jitted multiply called twice inside remat.
    @api.jit
    def mul(a, b):
      return a * b

    def f_mul(w, x):
      a = mul(w, x)
      b = mul(a, a)
      return b

    w = 1.
    x = 1.
    api.grad(api.remat(f_mul))(w, x)  # doesn't crash

  def test_remat_scan2(self):
    # https://github.com/google/jax/issues/1963

    def scan_bug(x0):
      f = lambda x, _: (x + 1, None)

      def scanned_f(x, _):
        return lax.scan(f, x, xs=None, length=1)[0], None

      x, _ = jax.remat(scanned_f)(x0, None)
      return x

    jax.grad(scan_bug)(1.0)  # doesn't crash

  def test_remat_jit_static_argnum(self):
    # https://github.com/google/jax/issues/2833
    if config.omnistaging_enabled:
      raise unittest.SkipTest("test only works without omnistaging")  # see next test

    def f(a_bool, y):
      if a_bool:
        return y + 1
      else:
        return y

    rematted = api.remat(f, concrete=True)
    api.jit(rematted, static_argnums=0)(True, 1)  # no crash

  def test_remat_jit_static_argnum_omnistaging(self):
    # https://github.com/google/jax/issues/2833
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")  # see previous test

    def named_call(f):
      # Wraps `f` so that it is invoked through core.call_p with its
      # arguments closed over.
      def named_f(*args):
        f_ = lu.wrap_init(lambda: (f(*args),))
        out, = core.call_p.bind(f_)
        return out
      return named_f

    def f(a_bool, y):
      if a_bool:
        return y + 1
      else:
        return y

    api.jit(named_call(f), static_argnums=0)(True, 1)  # no crash

  def test_remat_eval_counter(self):
    # https://github.com/google/jax/issues/2737
    # Defines a primitive whose impl rule bumps a counter, then checks how
    # many evaluations forward and backward passes through remat perform.
    add_one_p = Primitive('add_one')
    add_one = add_one_p.bind

    num_evals = 0

    @contextmanager
    def assertEvals(n):
      # Asserts that exactly `n` impl evaluations happen inside the block.
      start = num_evals
      yield
      assert num_evals - start == n

    def add_one_impl(x):
      nonlocal num_evals
      num_evals += 1
      return x + 1
    add_one_p.def_impl(add_one_impl)

    def add_one_jvp(pin, tin):
      pout = add_one(pin[0])
      return pout, pout * tin[0]
    ad.primitive_jvps[add_one_p] = add_one_jvp

    add_one_p.def_abstract_eval(lambda x: x)

    v = np.zeros((1,))

    f = jax.remat(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, 1 call made while transposing f
    with assertEvals(3):
      vjp(v)

    @jax.util.curry
    def call(f, *args):
      return jax.core.call(
          jax.linear_util.wrap_init(lambda *args: [f(*args)]),
          *args, name='foo')[0]

    f = call(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, no reevaluation for transposition of f
    with assertEvals(2):
      vjp(v)
class JaxprTest(jtu.JaxTestCase):
  """Tests of the jaxprs produced by api.make_jaxpr."""

  def test_scalar_literals(self):
    # Adding the Python scalar 2 must not introduce any constvars.
    closed_jaxpr = api.make_jaxpr(lambda x: x + 2)(42)
    self.assertLen(closed_jaxpr.jaxpr.constvars, 0)

  def test_const(self):
    # The function returns its input, a Python float and a NumPy array; the
    # expected printed jaxpr depends on whether omnistaging is enabled.
    def fun(x):
      return (x, 1., np.zeros(1))

    if not config.omnistaging_enabled:
      expected = """
{ lambda b ; a.
let
in (a, 1.0, b) }
"""
    else:
      expected = """
{ lambda a ; b.
let
in (b, 1.0, a) }
"""

    printed = str(api.make_jaxpr(fun)(0.))
    self.assertMultiLineStrippedEqual(expected, printed)

  def test_cond(self):
    # Printed jaxpr of a lax.cond whose branches close over the argument;
    # the expected text depends on whether omnistaging is enabled.
    def f(x):
      return lax.cond(x >= 0.,
                      x + 1.,
                      lambda xt: xt + x,
                      x + 2.,
                      lambda xf: xf - x)

    if not config.omnistaging_enabled:
      expected = """
{ lambda ; a.
let b = ge a 0.0
c = convert_element_type[ new_dtype=int32
old_dtype=bool ] b
d = add a 1.0
e = add a 2.0
f = cond[ branches=( { lambda ; e_ c a b.
let d = sub b c
in (d,) }
{ lambda ; c f_ a b.
let d = add a c
in (d,) } )
linear=(False, False, False, False) ] c a a d e
in (f,) }
"""
    else:
      expected = """
{ lambda ; a.
let b = ge a 0.0
c = add a 1.0
d = add a 2.0
e = convert_element_type[ new_dtype=int32
old_dtype=bool ] b
f = cond[ branches=( { lambda ; e_ a b c.
let d = sub c a
in (d,) }
{ lambda ; a f_ b c.
let d = add b a
in (d,) } )
linear=(False, False, False, False) ] e a a c d
in (f,) }
"""

    printed = str(api.make_jaxpr(f)(3.))
    self.assertMultiLineStrippedEqual(expected, printed)

  def test_make_jaxpr_static_argnums(self):
    # With argument 1 static, its value 3 must appear in the printed jaxpr.
    def add(x, y):
      return x + y

    make = api.make_jaxpr(add, static_argnums=(1,))
    self.assertIn('3', str(make(2, 3)))
class LazyTest(jtu.JaxTestCase):
@contextmanager
def count_compiles(self):
  # Context manager yielding a one-element list; while the context is
  # active, every call to xb.make_computation_builder increments its entry.
  # The original function is put back on exit, even after an exception.
  original_builder = xb.make_computation_builder
  counter = [0]

  def counting_builder(*args, **kwargs):
    counter[0] += 1
    return original_builder(*args, **kwargs)

  xb.make_computation_builder = counting_builder
  try:
    yield counter
  finally:
    xb.make_computation_builder = original_builder
@jtu.skip_on_devices("tpu")
def test_lazy_jit_closed_over_values(self):
  # A jitted function closes over a very large arange and reads a single
  # element of the sum.
  if not core.skip_checks:
    raise unittest.SkipTest("oom test skipped when core.skip_checks is False")

  num_elements = int(1e12)
  y = jnp.arange(num_elements)  # will likely oom if materialized

  def second_element_of_sum(x):
    return (x + y)[1]

  ans = jit(second_element_of_sum)(1)
  self.assertEqual(ans, 2)
def test_jit_forces_arguments(self):
  # The jitted function is first called with a jnp.arange (asserted to be a
  # device constant) and then with an equivalent NumPy array; the Python
  # body may run only for the first call.
  @api.jit
  def f(x):
    assert python_should_be_executing
    return jnp.sum(x)

  lazy_arg = jnp.arange(10, dtype=jnp.int32)
  assert xla.is_device_constant(lazy_arg)  # lazy iota

  python_should_be_executing = True
  _ = f(lazy_arg)

  python_should_be_executing = False  # should not recompile
  numpy_arg = np.arange(10, dtype=np.int32)
  _ = f(numpy_arg)
@parameterized.parameters(jtu.cases_from_list(range(10000)))
def test_random_lazy_program(self, seed):
  # Seeded randomized test: builds the same random array in NumPy and JAX,
  # applies a random sequence of transpose / broadcast / reshape ops to both,
  # asserts the JAX ops triggered no primitive compiles, then compares the two
  # results through one randomly chosen path and finally re-applies the ops
  # under `jit`.
  #
  # Every `rng` call below advances the shared RandomState, so the order of
  # those calls determines which program a given seed produces.

  def random_array(rng):
    # Returns a (numpy array, jax array) pair built from the same random spec.
    kind = rng.choice(['arr', 'iota', 'eye', 'tri'])
    if kind == 'arr':
      dtype = [np.float32, np.int32][rng.choice(2)]
      dim = rng.randint(4)
      shape = rng.randint(4, size=dim)
      np_x = np.asarray(rng.randn(*shape), dtype=dtype)
      jax_x = jnp.array(np_x, dtype=dtype)
    elif kind == 'iota':
      dtype = [np.float32, np.int32][rng.choice(2)]
      size = rng.randint(5)
      np_x = np.arange(size, dtype=dtype)
      jax_x = lax.iota(dtype, size)
    elif kind == 'eye':
      dtype = [np.float32, np.int32][rng.choice(2)]
      N = rng.randint(2, 5)
      M = None if rng.rand() < 0.5 else rng.randint(2, 5)
      k = rng.choice([-1, 0, 1])
      np_x = np.eye(N, M, k, dtype=dtype)
      jax_x = jnp.eye(N, M, k, dtype=dtype)
    elif kind == 'tri':
      dtype = [np.float32, np.int32][rng.choice(2)]
      N = rng.randint(2, 5)
      M = None if rng.rand() < 0.5 else rng.randint(2, 5)
      k = rng.choice([-1, 0, 1])
      np_x = np.tri(N, M, k, dtype=dtype)
      jax_x = jnp.tri(N, M, k, dtype=dtype)
    else:
      assert False
    assert type(np_x) is np.ndarray and type(jax_x) is xla.DeviceArray

    return np_x, jax_x

  def random_op(rng, shape):
    # Returns an `Op` pairing a NumPy-side function with the equivalent
    # `lax` function, chosen at random for an input of the given shape.
    kind = rng.choice(['transpose', 'broadcast', 'reshape'])
    if kind == 'transpose':
      perm = tuple(rng.permutation(len(shape)))
      return Op(partial(np.transpose, axes=perm),
                partial(lax.transpose, permutation=perm))
    elif kind == 'broadcast':
      n = rng.randint(1, 3)
      # NOTE: this value is overwritten a few lines below before it is read,
      # but the call still advances `rng`.
      new_sizes = rng.randint(1, 4, size=n)
      new_ndim = n + len(shape)
      bcast_dims = tuple(sorted(rng.permutation(new_ndim)[:len(shape)]))
      shape_iter = iter(shape)
      new_sizes = iter(rng.randint(1, 4, size=n))
      # Positions in `bcast_dims` take the next input dimension; every other
      # position takes the next freshly drawn size.
      new_shape = [next(shape_iter) if i in bcast_dims else next(new_sizes)
                   for i in range(new_ndim)]
      return Op(partial(lax_reference.broadcast_in_dim, shape=new_shape,
                        broadcast_dimensions=bcast_dims),
                partial(lax.broadcast_in_dim, shape=new_shape,
                        broadcast_dimensions=bcast_dims))
    elif kind == 'reshape':
      new_shape = list(shape)
      for _ in range(rng.randint(1, 3)):
        # `and` short-circuits to 0 when `new_shape` is empty, in which case
        # `rng.randint` is not called.
        loc = len(new_shape) and rng.randint(len(new_shape))
        new_shape.insert(loc, 1)
      new_shape = tuple(new_shape)
      return Op(partial(np.reshape, newshape=new_shape),
                partial(lax.reshape, new_sizes=new_shape))
    else:
      assert False
  Op = collections.namedtuple('Op', ['np_fn', 'jax_fn'])

  rng = np.random.RandomState(seed)
  # Chained assignment: `orig_x` keeps the unmodified JAX array for the jit
  # checks at the end, while `np_x` / `jax_x` are updated by each op.
  np_x, jax_x = _, orig_x = random_array(rng)
  ops = []
  with jtu.count_primitive_compiles() as count:
    for _ in range(rng.randint(5)):
      op = random_op(rng, np.shape(np_x))
      np_x = op.np_fn(np_x)
      jax_x = op.jax_fn(jax_x)
      ops.append(op)
  self.assertEqual(count[0], 0)

  # Compare the JAX result with the NumPy result through one of four paths.
  kind = rng.choice(['closure', 'npy_value', 'force', 'add'])
  if kind == 'closure':
    result = api.jit(lambda x: x + jax_x)(0)
    self.assertAllClose(np_x, result, check_dtypes=False)
  elif kind == 'npy_value':
    self.assertAllClose(np_x, jax_x, check_dtypes=False)
  elif kind == 'force':
    result = xla._force(jax_x)
    self.assertAllClose(np_x, result, check_dtypes=False)
  elif kind == 'add':
    result = jax_x + np.zeros(jax_x.shape, dtype=jax_x.dtype)
    self.assertAllClose(np_x, result, check_dtypes=False)
  else:
    assert False

  # Same op sequence under jit, with the starting array passed as an argument.
  @jit
  def apply_ops(x):
    for op in ops:
      x = op.jax_fn(x)
    return x

  jit_result = apply_ops(orig_x)
  self.assertAllClose(jit_result, np_x, check_dtypes=False)

  # Same op sequence under jit, with the starting array closed over.
  @jit
  def apply_ops_closure():
    x = orig_x
    for op in ops:
      x = op.jax_fn(x)
    return x

  jit_result = apply_ops_closure()
  self.assertAllClose(jit_result, np_x, check_dtypes=False)
def test_constant_forcing_computations_cached(self):
  # from https://github.com/google/jax/issues/1909
  # After one warm-up call, calling the same pmapped function on the same
  # input again must not create any new computation builder.
  xla._lazy_force_computation.cache_clear()  # clear force compile cache
  # NOTE(review): this is a NumPy array (`np.ones`); confirm whether a lazy
  # `jnp.ones` was intended, as the test name suggests.
  operand = np.ones((api.device_count(), 100))
  doubled = api.pmap(lambda x: 2 * x)
  doubled(operand)  # warm-up call

  with self.count_compiles() as n_compiles:
    doubled(operand)
  self.assertEqual(n_compiles[0], 0)
def test_zeros_ones_compilation(self):
  # With `+` already compiled and the force cache cleared, adding fresh
  # ones/zeros arrays twice is expected to count exactly one compile.
  warmup = jnp.ones(3) + jnp.ones(3)  # ensure + has a cache entry
  warmup.block_until_ready()

  xla._lazy_force_computation.cache_clear()  # clear force compile cache

  with self.count_compiles() as n_compiles:
    ones_plus_zeros = jnp.ones(3) + jnp.zeros(3)
    ones_plus_ones = jnp.ones(3) + jnp.ones(3)

  self.assertEqual(1, n_compiles[0])
  self.assertAllClose(ones_plus_zeros, np.ones(3), check_dtypes=False)
  self.assertAllClose(ones_plus_ones, np.ones(3) + np.ones(3),
                      check_dtypes=False)
class CustomJVPTest(jtu.JaxTestCase):
  """Tests for `custom_jvp` functions and their `defjvp` / `defjvps` rules.

  Most rules here return a tangent that differs from the true derivative of
  the primal function, so the assertions can tell whether the custom rule was
  the one that was applied.
  """

  def test_basic(self):
    # The rule's tangent is 2 * cos(x) * g; `jvp` and `grad` are checked
    # against that, the plain call against sin(x).
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = 3.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(api.jvp(f, (x,), (1.,)),
                        (jnp.sin(x), 2 * jnp.cos(x)))
    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))

  def test_invariance(self):
    # f2 / f3 return only the primal output of a jvp of f / f2; their jvps
    # are asserted equal to the jvp of f itself.
    @api.custom_jvp
    def f(x):
      return jnp.cos(2 * x) / 2.
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return (f(x), 3 * g)
    f.defjvp(f_jvp)
    def f2(x):
      y, _ = api.jvp(f, (x,), (x,))
      return y
    def f3(x):
      y, _ = api.jvp(f2, (x,), (x,))
      return y
    x = 1.
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        api.jvp(f2, (x,), (x,)),
                        check_dtypes=False)
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        api.jvp(f3, (x,), (x,)),
                        check_dtypes=False)

  def test_python_control_flow(self):
    # Both the function and its rule branch on `x > 0` with a Python `if`;
    # each branch is exercised with a positive and a negative input.
    @api.custom_jvp
    def f(x):
      if x > 0:
        return jnp.sin(x)
      else:
        return jnp.cos(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      if x > 0:
        return f(x), 2 * g
      else:
        return f(x), 3 * g
    f.defjvp(f_jvp)
    x = 2.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(f(-x), jnp.cos(-x))
    self.assertAllClose(api.jvp(f, (x,), (1.,)),
                        (jnp.sin(x), 2.),
                        check_dtypes=False)
    self.assertAllClose(api.jvp(f, (-x,), (1.,)),
                        (jnp.cos(-x), 3.),
                        check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False)
    self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False)

  def test_vmap(self):
    # The function and its rule assert that they only ever see rank-0
    # inputs, while every call below goes through `vmap` on 1-D / 2-D arrays.
    @api.custom_jvp
    def f(x):
      assert jnp.ndim(x) == 0
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      assert jnp.ndim(x) == jnp.ndim(g) == 0
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = jnp.arange(3.)
    xx = jnp.arange(6.).reshape(2, 3)

    # vmap of f
    self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
    self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))

    # vmap of jvp of f
    self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
                        (jnp.sin(x), 2 * jnp.cos(x) * x))
    self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

    # jvp of vmap of f
    self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
                        (jnp.sin(x), 2 * jnp.cos(x) * x))
    self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

    # vmap of jvp of vmap of f
    self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

  def test_jit(self):
    # Composes the custom_jvp function with `jit` in both nesting orders.
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = 3.

    # jit
    self.assertAllClose(api.jit(f)(x), jnp.sin(x))
    self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))

    # jit of jvp
    self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
                        (jnp.sin(x), 2 * jnp.cos(x) * x),
                        check_dtypes=False)

    # jvp of jit
    self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)),
                        (jnp.sin(x), 2 * jnp.cos(x) * x),
                        check_dtypes=False)

  def test_pytrees(self):
    # Input, output and tangents are dicts rather than bare arrays.
    @api.custom_jvp
    def f(x):
      return {'b': jnp.sin(x['a'])}
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']}
    f.defjvp(f_jvp)
    x = {'a': 3.}
    self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        ({'b': jnp.sin(x['a'])},
                         {'b': 2 * jnp.cos(x['a']) * x['a']}),
                        check_dtypes=False)

  def test_kwargs(self):
    # from https://github.com/google/jax/issues/1938
    # `c` is passed by keyword at the call site; the rule unpacks it as a
    # third primal / tangent.
    @api.custom_jvp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    def my_jvp(primals, tangents):
      x, y, c = primals
      t_x, t_y, t_c = tangents
      return my_fun(x, y, c), t_c
    my_fun.defjvp(my_jvp)
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.jvp(f, (10., 5.), (1., 1.))  # doesn't crash

  def test_initial_style(self):
    # The custom_jvp function is called from inside a `lax.scan` body; the
    # rule's tangent 2 * g gives gradient 2 and second derivative 0.
    @api.custom_jvp
    def f(x):
      return 3 * x
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * g
    f.defjvp(f_jvp)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = 0.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_initial_style_vmap(self):
    # Same scan-wrapped function as above, under `vmap`; `f` asserts that it
    # only sees rank-0 inputs.
    @api.custom_jvp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * g
    f.defjvp(f_jvp)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(jnp.ones(3))
    expected = 3. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
    expected = 2. * jnp.ones(3)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_closed_over_tracers_error_message(self):
    # `g` and its rule close over `x`, the argument of the function being
    # differentiated; both `jvp` and `grad` must raise UnexpectedTracerError.
    def f(x):
      @api.custom_jvp
      def g(y):
        return x + y
      def g_jvp(primals, tangents):
        return g(x), 2 * primals[0]
      g.defjvp(g_jvp)
      return g(1.)

    self.assertRaises(
        core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaises(
        core.UnexpectedTracerError, lambda: api.grad(f)(3.))

  def test_nondiff_arg(self):
    # Argument 0 (a Python callable) is marked non-differentiable; the rule
    # receives it as a leading positional argument before primals/tangents.
    @partial(api.custom_jvp, nondiff_argnums=(0,))
    def app(f, x):
      return f(x)
    def app_jvp(f, primals, tangents):
      (x,), (t,) = primals, tangents
      return app(f, x), 3 * t
    app.defjvp(app_jvp)

    ans = app(lambda x: 2 * x, 1)
    expected = 2
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))
    expected = (2., 3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_jit_tracer(self):
    # The non-differentiable argument is supplied from inside a jitted
    # function, i.e. as an argument of `g` rather than a Python constant.
    @partial(api.custom_jvp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_jvp(x, primals, tangents):
      (y,), (t_y,) = primals, tangents
      return f(x, y), 5 * t_y
    f.defjvp(f_jvp)

    @jit
    def g(x, y):
      return f(x, y)

    ans = api.jvp(lambda y: g(2., y), (3.,), (1.,))
    expected = (6., 5.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_vmap_axes(self):
    # Placeholder: always skipped.
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_pmap(self):
    # Placeholder: always skipped.
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_missing_jvp_rule_error_message(self):
    # Without a `defjvp` call, a plain call, `jvp` and `grad` must all raise
    # AttributeError with the same message.
    @api.custom_jvp
    def foo(x):
      return x ** 2

    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: foo(2))
    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: api.jvp(foo, (2.,), (1.,)))
    self.assertRaisesRegex(
        AttributeError,
        r"No JVP defined for custom_jvp function foo using defjvp.",
        lambda: api.grad(foo)(2.))

  def test_jvp_rule_inconsistent_pytree_structures_error_message(self):
    # The rule returns a 1-tuple primal but a 2-element list tangent.
    @api.custom_jvp
    def f(x):
      return (x**2,)

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return f(x), [2 * x * t, x]

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce primal and tangent outputs "
            "with equal container (pytree) structures, but got "
            "{} and {} respectively.".format(
                tree_util.tree_structure((1,)),
                tree_util.tree_structure([1, 2]))
        ),
        lambda: api.jvp(f, (2.,), (1.,)))

  def test_primal_tangent_aval_disagreement_error_message(self):
    # The rule reshapes the tangent to shape (1,) while the primal output
    # is computed from a scalar input.
    @api.custom_jvp
    def f(x):
      return x ** 2

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return f(x), jnp.reshape(t, (1,))

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce primal and tangent outputs "
            "with equal shapes and dtypes, but got float32[] and float32[1] "
            "respectively."),
        lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),)))

  def test_jvp_rule_doesnt_return_pair_error_message(self):
    # https://github.com/google/jax/issues/2516
    # The rule returns only the tangent instead of a (primal, tangent) pair.

    @api.custom_jvp
    def f(x):
      return x ** 2

    @f.defjvp
    def foo_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return t

    f(2.)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom JVP rule must produce a pair (list or tuple of length two) "
            "representing primal and tangent outputs, got 1.0"),
        lambda: api.jvp(f, (2.,), (1.,)))

  def test_multiple_rule_invocations(self):
    # `expit` is called inside a 10-step scan whose carry is a 5-element
    # list; each step replaces c[0] with expit(c[0]) and every other entry
    # with the sum of itself and its left neighbour.
    @jax.custom_jvp
    def expit(x):
      return 1 / (1 + lax.exp(-x))

    @expit.defjvp
    def _expit_jvp(primals, tangents):
      (x,), (t,) = primals, tangents
      ans = expit(x)
      t_out = t * ans * (1 - ans)
      return ans, t_out

    def scanned_fun(c, _):
      return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None

    def foo(x):
      c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
      return c[-1]

    # just make sure these don't crash
    foo(3.)
    grad(foo)(3.)
    grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.))

  def test_hard_stuff(self):
    # jit of vmap of `jnp.linalg.det` on a batch of 2x2 matrices.
    arr = jnp.ones((5, 2, 2))
    api.jit(jax.vmap(jnp.linalg.det))(arr)  # doesn't crash

  def test_hard_stuff2(self):
    # The primal output is a NumPy zeros constant passed through
    # `lax.tie_in`; several jit / vmap / grad / jvp compositions are run.
    @jax.custom_jvp
    def f(x):
      return lax.tie_in(x, np.zeros(x.shape, x.dtype))

    @f.defjvp
    def f_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return f(x), t

    # don't crash
    jax.jit(jax.vmap(f))(jnp.arange(3.))
    jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
    jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
    jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
    jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))

  def test_hard_stuff3(self):
    # A custom-JVP relu used inside a 10-step scan, run through the same
    # family of jit / vmap / grad / jvp compositions.
    @jax.custom_jvp
    def relu(x):
      return jnp.maximum(x, 0)

    @relu.defjvp
    def _relu_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))

    def scanned_fun(c, _):
      return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None

    def f(x):
      c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
      return c[-1]

    # don't crash
    jax.jit(jax.vmap(f))(jnp.arange(3.))
    jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
    jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
    jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
    jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),))

  def test_eval_shape(self):
    # `eval_shape` of the custom_jvp function and of its gradient.
    @jax.custom_jvp
    def expit(x):
      return 1 / (1 + lax.exp(-x))

    @expit.defjvp
    def _expit_jvp(primals, tangents):
      (x,), (t,) = primals, tangents
      ans = expit(x)
      t_out = t * ans * (1 - ans)
      return ans, t_out

    # don't crash
    api.eval_shape(expit, jnp.ones((2, 3)))
    api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3)))

  def test_jaxpr_zeros(self):
    # from https://github.com/google/jax/issues/2657
    # Inside the scan, the first argument of `f` is the constant jnp.eye(3)
    # and only the second depends on `theta`.
    @api.custom_jvp
    def f(A, b):
      return A @ b

    def f_jvp(primals, tangents):
      A, b = primals
      dA, db = tangents
      z = f(A, b)
      dz = A @ db + dA @ b
      return z, dz

    f.defjvp(f_jvp)

    def experiment(theta):
      def step(q, _):
        z = f(jnp.eye(3), jnp.ones(3) * theta)
        q += z[0]
        return q, q

      q = 0.
      q, _ = lax.scan(step, q, None, 4)
      return q

    grad(experiment)(1.)  # doesn't crash

  def test_linear_in_scan(self):
    # The rule computes its tangent by calling `f` itself on the input
    # tangent; the function is used inside a scan and differentiated.
    @api.custom_jvp
    def f(x):
      return -x

    @f.defjvp
    def f_jvp(primals, tangents):
      x, = primals
      x_dot, = tangents
      return f(x), f(x_dot)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = -1.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_jvps_first_rule_is_none(self):
    # https://github.com/google/jax/issues/3389
    # `defjvps` is given None for the first argument and a rule only for the
    # second; the gradient is taken with respect to argument 1.
    @api.custom_jvp
    def f(x, y):
      return x ** 2 * y

    f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot)
    ans = grad(f, 1)(2., 3.)  # doesn't crash
    expected = 12.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_concurrent_initial_style(self):
    # https://github.com/google/jax/issues/3843
    # Runs the same grad-of-scan computation once on the calling thread and
    # then from two worker threads, comparing every result to the first.
    def unroll(param, sequence):
      def scan_f(prev_state, inputs):
        return prev_state, jax.nn.sigmoid(param * inputs)
      return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])

    def run():
      return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))

    expected = run()

    # we just don't want this to crash
    n_workers = 2
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
      futures = []
      for _ in range(n_workers):
        futures.append(e.submit(run))
      results = [f.result() for f in futures]
    for ans in results:
      self.assertAllClose(ans, expected)
class CustomVJPTest(jtu.JaxTestCase):
  """Tests for `custom_vjp` functions and their `defvjp` forward/backward rules.

  Most backward rules here return a cotangent that differs from the true
  derivative of the primal function, so the assertions can tell whether the
  custom rule was the one that was applied.
  """

  def test_basic(self):
    # The forward rule saves cos(x) as the residual; the backward rule
    # returns 2 * cos(x) * g.
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
    self.assertAllClose(api.value_and_grad(f)(x),
                        (jnp.sin(x), 2 * jnp.cos(x)))

  def test_invariance(self):
    # f2 / f3 return only the value part of value_and_grad of f / f2; their
    # values and gradients are asserted equal to those of f.
    @api.custom_vjp
    def f(x):
      return jnp.cos(2 * x) / 2.
    def f_fwd(x):
      return (f(x), x)
    def f_rev(x, g):
      return (g * 3,)
    f.defvjp(f_fwd, f_rev)
    def f2(x):
      y, _ = api.value_and_grad(f)(x)
      return y
    def f3(x):
      y, _ = api.value_and_grad(f2)(x)
      return y
    x = 1.
    self.assertAllClose(f(x), f2(x), check_dtypes=False)
    self.assertAllClose(f(x), f3(x), check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
                        check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
                        check_dtypes=False)

  def test_python_control_flow(self):
    # The function and both rules branch on `x > 0` with a Python `if`;
    # each branch is exercised with a positive and a negative input.
    @api.custom_vjp
    def f(x):
      if x > 0:
        return jnp.sin(x)
      else:
        return jnp.cos(x)
    def f_fwd(x):
      if x > 0:
        return f(x), x
      else:
        return f(x), x
    def f_rev(x, g):
      if x > 0:
        return (2 * g,)
      else:
        return (3 * g,)
    f.defvjp(f_fwd, f_rev)
    x = 2.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(f(-x), jnp.cos(-x))
    self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
                        check_dtypes=False)
    self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
                        check_dtypes=False)

  def test_vmap(self):
    # The function and its forward rule assert that they only ever see
    # rank-0 inputs, while every call below goes through `vmap`.
    @api.custom_vjp
    def f(x):
      assert jnp.ndim(x) == 0
      return jnp.sin(x)
    def f_fwd(x):
      assert jnp.ndim(x) == 0
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = jnp.arange(3.)
    xx = jnp.arange(6.).reshape(2, 3)

    # vmap of f
    self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
    self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))

    # vmap of grad of f
    self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
    self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
                        (jnp.sin(x), 2 * jnp.cos(x)))
    self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
    self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx)))

    # grad of vmap of f
    self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
                        2 * jnp.cos(x))
    self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
                        2 * jnp.cos(xx))

    # vmap of grad of vmap of f
    self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
                        2 * jnp.cos(xx))

  def test_jit(self):
    # Composes the custom_vjp function with `jit` in both nesting orders.
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.

    # jit
    self.assertAllClose(api.jit(f)(x), jnp.sin(x))
    self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))

    # jit of grad
    self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
                        check_dtypes=False)

    # grad of jit
    self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x),
                        check_dtypes=False)

  def test_pytrees(self):
    # Input, output, residuals and cotangents are dicts.
    @api.custom_vjp
    def f(x):
      return {'b': jnp.sin(x['a'])}
    def f_fwd(x):
      return f(x), {'r': jnp.cos(x['a'])}
    def f_bwd(res, g):
      cos_x = res['r']
      return ({'a': 2 * cos_x * g['b']},)
    f.defvjp(f_fwd, f_bwd)
    x = {'a': 3.}
    self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
    self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
                        {'a': 2 * jnp.cos(x['a'])})

  def test_jvp_error(self):
    # Forward-mode differentiation of a custom_vjp function must raise
    # TypeError, both directly and through `vmap`.
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))

  def test_kwargs(self):
    # from https://github.com/google/jax/issues/1938
    # `c` is passed by keyword at the call site.
    @api.custom_vjp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    # Forward rule returns the primal output (computed from the same x, y, c
    # it was given) and no residuals; the backward rule returns one cotangent
    # per argument.
    my_fun.defvjp(lambda x, y, c=1.: (my_fun(x, y, c), None),
                  lambda _, g: (g, g, g))
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.grad(f)(10., 5.)  # doesn't crash

  def test_initial_style(self):
    # The custom_vjp function is called from inside a `lax.scan` body and
    # differentiated once and twice.
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = 2. * jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = -2. * jnp.sin(3.)
    self.assertAllClose(ans, expected)

  def test_initial_style_vmap(self):
    # Scan-wrapped custom_vjp function under `vmap`; `f` asserts that it only
    # sees rank-0 inputs.
    @api.custom_vjp
    def f(x):
      assert jnp.ndim(x) == 0
      return 3 * x
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(jnp.arange(3.))
    expected = 3. * jnp.arange(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
    expected = 2. * jnp.cos(jnp.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg(self):
    # Argument 0 (a Python callable) is marked non-differentiable; both rules
    # receive it as their leading positional argument.
    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def app(f, x):
      return f(x)
    def app_fwd(f, x):
      return app(f, x), jnp.cos(x)
    def app_rev(f, cos_x, g):
      return (cos_x * g,)
    app.defvjp(app_fwd, app_rev)

    ans = app(lambda x: 2 * x, 1)
    expected = 2
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
    expected = (2., jnp.cos(1.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer(self):
    # The non-differentiable argument is supplied from inside a jitted
    # function, i.e. as an argument of `g` rather than a Python constant.
    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_fwd(x, y):
      return f(x, y), jnp.cos(y)
    def f_rev(x, cos_y, g):
      return (cos_y * g,)
    f.defvjp(f_fwd, f_rev)

    @jit
    def g(x, y):
      return f(x, y)

    ans = g(2, 3.)
    expected = 6.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g, 1)(2., 3.)
    expected = jnp.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_vmap_axes(self):
    # Placeholder: always skipped.
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_pmap(self):
    # Placeholder: always skipped.
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_missing_vjp_rule_error(self):
    # Without a `defvjp` call, a plain call and `grad` must both raise
    # AttributeError with the same message.
    @api.custom_vjp
    def foo(x):
      return x ** 2

    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: foo(2))
    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: api.grad(foo)(2.))

  def test_vjp_rule_inconsistent_pytree_structures_error(self):
    # The backward rule returns a bare value instead of a 1-tuple.
    @api.custom_vjp
    def f(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(_, g):
      return g

    f.defvjp(foo_fwd, foo_bwd)

    f(2)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.".format(
                tree_util.tree_structure(1),
                tree_util.tree_structure((1,)))
        ),
        lambda: api.grad(f)(2.))

  def test_issue2511(self):
    # jit of vmap (with explicit in_axes) of `jnp.linalg.det`.
    arr = jnp.ones((5, 2, 2))
    foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
    api.jit(foo)(arr)  # doesn't crash

  def test_lowering_out_of_traces(self):
    # https://github.com/google/jax/issues/2578
    # A callable namedtuple is passed as an argument to a jitted function
    # and the result is differentiated with respect to the other argument.

    class F(collections.namedtuple("F", ["a"])):
      def __call__(self, x):
        return jax.nn.relu(self.a) * x

    @jax.jit
    def g(f, x):
      return f(x)

    jax.grad(g, argnums=(1,))(F(2.0), 0.)  # doesn't crash

  def test_nondiff_argnums_stop_gradient(self):
    # https://github.com/google/jax/issues/2784
    # The non-differentiable argument `hi` is computed from the value being
    # differentiated.
    @partial(api.custom_vjp, nondiff_argnums=(0, 1))
    def _clip_gradient(lo, hi, x):
      return x  # identity function

    def clip_gradient_fwd(lo, hi, x):
      # return x, None
      return x, (hi, )

    def clip_gradient_bwd(lo, hi, _, g):
      return (jnp.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -1
      hi = x + 1  # causes things to break
      return _clip_gradient(lo, hi, x)

    jax.grad(clip_gradient)(1.)  # doesn't crash

  def test_nestable_vjp(self):
    # Verify that https://github.com/google/jax/issues/3667 is resolved.
    def f(x):
      return x ** 2

    @api.custom_vjp
    def g(x):
      return f(x)

    # The forward rule stores the function returned by `api.vjp` as the
    # residual; the backward rule simply calls it.
    def g_fwd(x):
      y, f_vjp = api.vjp(f, x)
      return y, f_vjp

    def g_bwd(f_vjp, y_bar):
      return f_vjp(y_bar)

    g.defvjp(g_fwd, g_bwd)

    # Check that VJP can be nested in simple situations.  For this to pass,
    # vjp has to return a PyTree.
    _, g_vjp = api.vjp(g, 1.0)
    y, = g_vjp(1.0)
    self.assertAllClose(y, jnp.array(2.0))

    # Check that VJP can be nested in complex situations.  For this to pass,
    # vjp can't treat the closed-over tracer x as a static argument.
    @jit
    def z(x):
      _, g_vjp = api.vjp(g, x)
      return g_vjp
    y, = z(1.0)(3.0)
    self.assertAllClose(y, jnp.array(6.0))
class InvertibleADTest(jtu.JaxTestCase):
  """Tests for `jax.invertible` and `jax.custom_ivjp`."""

  def test_invertible_basic(self):
    # Compares the printed jaxpr of the VJP of `jax.invertible(f)` against an
    # expected text (one variant per omnistaging setting), then checks that
    # value_and_grad agrees between `f` and its invertible wrapper.
    def f(x):
      return (jnp.exp(x) * 4) * x

    finv = jax.invertible(f)

    x = jnp.ones((5,))

    if config.omnistaging_enabled:
      expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = mul b a
g = div e a
h = mul b g
i = div g 4.0
j = mul f 4.0
_ = log i
k = mul j i
l = add_any h k
in (l,) }
"""
    else:
      expected = """
{ lambda ; a b.
let c = exp a
d = mul c 4.0
e = mul d a
f = div e a
g = mul b f
h = mul b a
i = mul h 4.0
j = div f 4.0
k = mul i j
l = add_any g k
in (l,) }
"""

    jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)
    self.assertMultiLineStrippedEqual(expected, str(jaxpr))

    # Note: the reductions below use NumPy's `np.sum`, not `jnp.sum`.
    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
                        jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
                        check_dtypes=True)

  def test_invertible_blocks(self):
    # NB: This is the reversible ResNet block
    def mk_reversible_block(f, g):
      @jax.custom_ivjp
      def rev_block(x1, x2):
        y1 = f(x2) + x1
        y2 = g(y1) + x2
        return y1, y2

      # The rule receives the outputs and output cotangents and returns the
      # reconstructed inputs together with the input cotangents.
      @rev_block.defivjp
      def rev_block_ivjp(xs, ys, dys):
        (y1, y2) = ys
        (dy1, dy2) = dys

        # Undo y2 = g(y1) + x2.
        dgo, dx2 = dy2, dy2
        go, gvjp = jax.vjp(g, y1)
        dy1 += gvjp(dgo)[0]
        del gvjp
        x2 = y2 - go

        # Undo y1 = f(x2) + x1.
        dfo, dx1 = dy1, dy1
        fo, fvjp = jax.vjp(f, x2)
        dx2 += fvjp(dfo)[0]
        del fvjp
        x1 = y1 - fo

        return (x1, x2), (dx1, dx2)

      return rev_block

    rev_block = mk_reversible_block(jnp.sin, jnp.cos)

    # Two reversible blocks applied in sequence.
    def g(x1, x2):
      for i in range(2):
        x1, x2 = rev_block(x1, x2)
      return x1, x2

    # Scalar-valued wrapper so that value_and_grad can be applied.
    def reduce(f, x1, x2):
      y1, y2 = f(x1, x2)
      return np.sum(y1) + np.sum(y2)

    x = np.ones((1,))
    # FIXME: This breaks when argnums is left as default (i.e. 0), because JVP prunes
    #        zero tangents from call primitives.
    self.assertAllClose(jax.value_and_grad(partial(reduce, jax.invertible(g)), argnums=(0, 1))(x, x + 2),
                        jax.value_and_grad(partial(reduce, g), argnums=(0, 1))(x, x + 2),
                        check_dtypes=True)

  def test_invertible_partial_diff(self):
    # Check that we don't have to differentiate with respect to inputs
    # of the invertible function.
    def f(x, y):
      return (jnp.exp(x) * 4) * x, y + 4

    finv = jax.invertible(f)

    o = np.ones((5,))
    self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x, o)[0]))(o),
                        jax.value_and_grad(lambda x: np.sum(finv(x, o)[0]))(o),
                        check_dtypes=True)
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
  """Tests for `ad.defvjp_all`, `api.custom_transforms`, `api.defvjp`,
  `api.defvjp_all` and `api.custom_gradient`."""

  def test_defvjp_all(self):
    # Registers a VJP for a fresh primitive: value x**2, cotangent
    # 4 * g * sin(x); `foo` doubles the primitive's output.
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * jnp.sin(x),)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(grad_ans, 4 * 2 * np.sin(3.), check_dtypes=False)

  def test_defvjp_all_const(self):
    # The registered VJP ignores its cotangent and returns the constant 12.
    foo_p = Primitive('foo')
    def foo(x): return foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 9., check_dtypes=False)
    self.assertAllClose(grad_ans, 12.)

  def test_defvjp_all_higher_order_revmode(self):
    # Second-order reverse-mode derivative through a registered VJP.
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
    ans = api.grad(api.grad(foo))(3.)
    self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)

  def test_defvjp_all_multiple_arguments(self):
    # also tests passing in symbolic zero tangents b/c we differentiate wrt only
    # the first argument in one case

    foo_p = Primitive('foo')
    def foo(x, y): return foo_p.bind(x, y)

    def vjpfun(x, y):
      out = x**2 + y**3
      vjp = lambda g: (g + x + y, g * x * 9.)
      return out, vjp

    ad.defvjp_all(foo_p, vjpfun)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    ans = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)

  def test_defvjp_all_custom_transforms(self):
    # `api.defvjp_all` applied to a `custom_transforms` function.
    @api.custom_transforms
    def foo(x):
      return jnp.sin(x)

    api.defvjp_all(foo, lambda x: (jnp.sin(x), lambda g: (g * x,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, np.sin(3.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3., check_dtypes=False)

  # TODO(mattjj): add defvjp_all test with pytree arguments

  def test_defvjp(self):
    # Per-argument rules: None for the first argument, g * x * y for the
    # second; the first-argument gradient is asserted to be 0.
    @api.custom_transforms
    def foo(x, y):
      return jnp.sin(x * y)

    api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, np.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 0., check_dtypes=False)

    ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans_0, 0., check_dtypes=False)
    self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)

  def test_defvjp_higher_order(self):
    # Second derivative through a `defvjp` rule, compared with the second
    # derivative of `jnp.sin`.
    @api.custom_transforms
    def foo(x):
      return jnp.sin(2. * x)

    api.defvjp(foo, lambda g, _, x: g * jnp.cos(x))
    ans = api.grad(api.grad(foo))(2.)
    expected = api.grad(api.grad(jnp.sin))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_defvjp_use_ans(self):
    # The rule for the second argument also reads the primal output `ans`.
    @api.custom_transforms
    def foo(x, y):
      return jnp.sin(x * y)

    api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + jnp.cos(ans))
    val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
    self.assertAllClose(val_ans, np.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3. * 4. + np.cos(np.sin(3. * 4)),
                        check_dtypes=False)

  # TODO
  # def test_defjvp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y
  #     api.defjvp(bar, lambda y_dot, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
  #       "Detected differentiation with respect to closed-over values with "
  #       "custom JVP rule, which isn't supported.")

  # TODO
  # def test_defvjp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y
  #     api.defvjp(bar, lambda g, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: grad(foo)(1.,), ValueError,
  #       "Detected differentiation w.r.t. variables from outside "
  #       "the scope of <jax.custom_transforms function bar>, but defvjp and "
  #       "defvjp_all only support differentiation w.r.t. positional arguments.")

  def test_custom_transforms_eval_with_pytrees(self):
    # Tuple input, dict output, evaluated directly.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = f((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})

  def test_custom_transforms_jit_with_pytrees(self):
    # Tuple input, dict output, evaluated under `jit`.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = jit(f)((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})

  def test_custom_transforms_jit_with_pytrees_consts(self):
    # The purpose of this test is to exercise the custom_transforms default
    # translation rule in how it deals with constants that are too large to be
    # treated as literals (at the time of writing).
    z = np.arange(10.)

    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': z * b}

    ans = jit(f)((1, 2))
    self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)

  def test_custom_transforms_jvp_with_pytrees(self):
    # Tuple input, dict output, evaluated under `jvp`.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
    self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})

  def test_custom_transforms_vmap_with_pytrees(self):
    # Always skipped: the unconditional raise makes the body below
    # unreachable; it is kept as the disabled test's reference body.
    raise unittest.SkipTest("Test deprecated custom_transforms")
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = api.vmap(f)((np.arange(3), np.ones((3, 2))))
    expected = {'hi': 2 * np.arange(3), 'bye': 2 * np.ones((3, 2))}
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_transforms_jvp_with_closure(self):
    # `g` closes over `x`, the argument of the function being differentiated.
    def f(x):
      @api.custom_transforms
      def g(y):
        return x * y
      return g(x)

    ans = api.grad(f)(1.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_gradient(self):
    # `custom_gradient` function returning (value, vjp); the vjp g * x
    # yields 3 at x = 3.
    @api.custom_gradient
    def f(x):
      return x ** 2, lambda g: (g * x,)

    self.assertAllClose(f(3.), 9., check_dtypes=False)
    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)

  def test_custom_vjp_zeros(self):
    # Differentiates only the first of two outputs of a function with a
    # `defvjp_all` rule.
    @api.custom_transforms
    def f(x, y):
      return 2 * x, 3 * y

    def f_vjp(x, y):
      return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])

    api.defvjp_all(f, f_vjp, )
    api.grad(lambda x, y: f(x, y)[0])(1., 2.)  # doesn't crash

  def test_custom_transforms_vjp_nones(self):
    # issue raised by jsnoek@ and jumper@
    # This test needs `core.skip_checks` enabled (it fails with checks). The
    # flag is module-global, so the previous value is saved and restored in
    # `finally` to avoid leaking the setting into tests that run afterwards.
    prev_skip_checks = core.skip_checks
    core.skip_checks = True  # Fails with checks
    try:
      @jax.custom_transforms
      def solve(a, b):
        return jnp.dot(jnp.linalg.inv(a), b)
      # print(solve(a, b))

      def solve_vjp(a, b):
        x = solve(a, b)
        def vjp(x_tangent):
          dx = jnp.dot(solve(a, x_tangent), x.T)
          out = (dx, b * 0.)
          return out
        return x, vjp
      jax.defvjp_all(solve, solve_vjp)
      gf = grad(lambda a,b: jnp.sum(solve(a, b)))

      n = 3
      a_in = jnp.linspace(0, 1, n)[:, None]
      a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
      real_x = np.random.RandomState(0).randn(n)
      b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
      print(gf(a, b))  # doesn't crash
    finally:
      core.skip_checks = prev_skip_checks
class BufferDonationTest(jtu.JaxTestCase):
  """Tests for the `donate_argnums` option of `pmap`."""

  # === pmap ===

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_pmap_donate_argnums_invalidates_input(self):
    # After a donating pmap call, the input's buffers must report deleted.
    donating_fn = api.pmap(lambda x: x + x - x, donate_argnums=0)
    n_devices = jax.local_device_count()
    sharded_ones = api.pmap(lambda x: x)(jnp.ones([n_devices]))
    result = donating_fn(sharded_ones)
    self.assertDeleted(sharded_ones)
    np.testing.assert_allclose(result, [1.] * n_devices)

  def test_pmap_nested_donate_ignored(self):
    # A donating pmap nested inside jit must simply run.
    nested_fn = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
    sharded_input = api.pmap(lambda x: x)(jnp.array([1]))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   pmap_fun(a)

    nested_fn(sharded_input)  # doesn't crash

  def assertDeleted(self, x):
    # Asserts that every device buffer of `x` reports deleted.
    return self._assertDeleted(x, True)

  def assertNotDeleted(self, x):
    # Asserts that no device buffer of `x` reports deleted.
    return self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    # Objects exposing a single `device_buffer` are checked through it; all
    # others are checked through each entry of `device_buffers`.
    if hasattr(x, "device_buffer"):
      buffers = [x.device_buffer]
    else:
      buffers = x.device_buffers
    for buf in buffers:
      self.assertEqual(buf.is_deleted(), deleted)
if __name__ == "__main__":
  # When run as a script, hand absltest the JAX-specific test loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)