# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from contextlib import contextmanager
import copy
from functools import partial
import re
import unittest
import types
import warnings
import weakref
import functools

from absl import logging
from absl.testing import absltest, parameterized
import numpy as np

import concurrent.futures

import jax
import jax.numpy as jnp
from jax import jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
from jax.lib import version

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


class CPPJitTest(jtu.JaxTestCase):
  """Shared tests between the Python and the C++ jax.jit implementations.

  Because the Python implementation supports more features, we need to have the
  Python tests that extend the C++ tests (and not the other way around).
  """

  @property
  def jit(self):
    # Right now, the CPP tests also test the Python code-path when jaxlib is
    # too old.
    # TODO(jblespiau,phawkins): Remove this when jaxlib has been released.
    # This is in the future, because we are making a breaking change to
    # Tensorflow.
    if version < (0, 1, 56):
      raise unittest.SkipTest("Disabled because it depends on some future "
                              "release of jax_jit.cc within jaxlib.")
    else:
      return jax.api._cpp_jit

  def test_jit_of_noncallable(self):
    self.assertRaisesRegex(TypeError, "Expected a callable value.*",
                           lambda: self.jit(3))

  def test_jit_of_generator(self):

    def gen(x):
      yield x

    self.assertRaisesRegex(TypeError,
                           "Expected a function, got a generator function.*",
                           lambda: self.jit(gen))

  @parameterized.parameters([
      # Integer support
      (1, 2, 3, 4, 5),
      # Numpy array support
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
          np.asarray(4, np.int32),
          np.asarray(5, np.int32),
      ),
  ])
  def test_jit_static_args(self, one, two, three, four, five):
    side = []
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z, flag=False, flag2=False):
      del flag2  # unused
      assert flag
      side.append(None)
      return 100 * x + 10 * y + z

    f1 = self.jit(f, static_argnums=(3, 4))
    assert f1(one, two, three, True, False) == 123
    assert len(side) == 1
    assert f1(one, two, three, True, False) == 123
    assert len(side) == 1  # Obvious cache hit.
    assert f1(two, one, three, True, False) == 213
    assert len(side) == 1  # Should cache hit because same signature.
    assert f1(two, one, three, True, True) == 213
    assert len(side) == 2

    side[:] = []
    f2 = self.jit(f, static_argnums=(0, 2, 3, 4))
    assert f2(one, two, three, True, False) == 123
    assert len(side) == 1
    assert f2(one, three, three, True, False) == 133
    assert len(side) == 1
    assert f2(two, two, three, True, False) == 223
    assert len(side) == 2
    assert f2(two, four, three, True, False) == 243
    assert len(side) == 2
    assert f2(two, four, three, True, True) == 243
    assert len(side) == 3
    assert f2(two, five, three, True, True) == 253
    assert len(side) == 3

  @parameterized.parameters([
      (1, 2, 3),
      (
          np.asarray(1, np.int32),
          np.asarray(2, np.int32),
          np.asarray(3, np.int32),
      ),
  ])
  def test_jit_kwargs(self, one, two, three):
    side = []
    # For the CPP jit, we need to clear the cache to prevent cache hits between
    # parameterized tests.
    if hasattr(self.jit, "cache_clear"):
      self.jit.cache_clear()

    def f(x, y, z):
      print(x, y, z)
      side.append(None)
      return 100 * x + 10 * y + z

    f = self.jit(f)
    assert f(one, two, three) == 123
    assert len(side) == 1
    assert f(one, two, three) == 123
    assert len(side) == 1

    assert f(one, two, z=three) == 123
    assert len(side) == 2  # actually recompiles from kwarg
    assert f(one, two, z=three) == 123
    assert len(side) == 2  # but should still cache

    f(one, two, z=np.zeros(3))  # doesn't crash
    if FLAGS.jax_enable_x64:
      # In the above call, three is of a new type (int64), thus it should
      # trigger a new compilation.
      assert len(side) == 3

  def test_jit_device(self):
    device = xb.devices()[-1]
    x = self.jit(lambda x: x, device=device)(3.)
    self.assertIsInstance(x, xla.DeviceArray)
    self.assertEqual(x.device_buffer.device(), device)

  def test_complex_support(self):
    self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)

  def test_jit_with_many_args_works(self):

    @self.jit
    def f(args_list):
      return sum(args_list)

    self.assertEqual(f(list(range(500))), sum(range(500)))

  # Jit and Donate arguments

  assertDeleted = lambda self, x: self._assertDeleted(x, True)
  assertNotDeleted = lambda self, x: self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    if hasattr(x, "device_buffer"):
      self.assertEqual(x.device_buffer.is_deleted(), deleted)
    else:
      for buffer in x.device_buffers:
        self.assertEqual(buffer.is_deleted(), deleted)

  def test_jit_donate_argnums_warning_raised(self):
    x = jnp.array([1.0, 2.0], jnp.float32)
    y = jnp.array([1, 2], jnp.int32)
    f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      f(x, y)

    self.assertLen(w, 1)
    self.assertTrue(issubclass(w[-1].category, UserWarning))
    self.assertIn(
        "Some donated buffers were not usable: f32[2]{0}, s32[2]{0}",
        str(w[-1].message))

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_invalidates_input(self):
    # We can't just use `lambda x: x` because JAX simplifies this away to an
    # empty XLA computation.
    move = self.jit(lambda x: x + x - x, donate_argnums=0)
    x = jnp.ones([])
    y = move(x)
    self.assertDeleted(x)
    self.assertEqual(y, 1.)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jit_donate_argnums_static_argnums(self):
    jit_fun = self.jit(
        lambda a, b, c, d: ((a + b + c), (a + b + d)),
        static_argnums=(0, 1),
        donate_argnums=(2, 3))

    a = jnp.array(1)
    b = jnp.array(2)
    c = jax.device_put(jnp.array([1., 1.]))
    d = jax.device_put(jnp.array([1., 1., 1.]))
    e, f = jit_fun(a, b, c, d)
    np.testing.assert_allclose(e, jnp.array([4., 4.]))
    np.testing.assert_allclose(f, jnp.array([4., 4., 4.]))
    self.assertNotDeleted(a)
    self.assertNotDeleted(b)
    self.assertDeleted(c)
    self.assertDeleted(d)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_jnp_array_copy(self):
    # https://github.com/google/jax/issues/3412

    @partial(self.jit, donate_argnums=(0,))
    def _test(array):
      return array.at[0].set(77)

    x = jnp.asarray([0, 1])
    x_copy = jnp.array(x, copy=True)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")
      _test(x)  # donation

    # Gives: RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
    print(x_copy)  # doesn't crash

  def test_jit_global_cache(self):
    def f(x):
      assert python_should_be_executing
      return x

    python_should_be_executing = True
    self.jit(f)(2)
    python_should_be_executing = False
    self.jit(f)(3)

  def test_jit_shallow_copy(self):
    def f(x):
      return copy.copy(x)
    self.jit(f)(1)

  def test_jit_deep_copy(self):
    def f(x):
      return copy.deepcopy(x)
    self.jit(f)(1)

  def test_disable_jit(self):
    effects = []

    @self.jit
    def f(x):
      effects.append(1)
      return x

    with api.disable_jit():
      f(2)
      f(2)
      assert len(effects) == 2

    f(2)
    f(2)
    assert len(effects) == 3

  def test_static_argnum_on_method(self):

    class A:

      @functools.partial(self.jit, static_argnums=(0,))
      def my_func_jit(self, x):
        return x+2

    A().my_func_jit(3)

  def test_static_argnum_on_static_method_is_not_supported(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit, static_argnums=(0,))
        @classmethod
        def my_classmethod_jit(cls, x):
          return x+2

  def test_classmethod_is_not_supported(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):

      class A:

        @functools.partial(self.jit)
        @staticmethod
        def my_staticmethod_jit(x):
          return x + 2

  def test_concurrent_jit(self):
    @self.jit
    def f(x):
      return x + x - 3.

    xs = [np.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [f.result() for f in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x * 2 - 3., y)


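# PythonJitTest reruns every test defined on CPPJitTest, overriding only the
# `jit` property so that the same suite exercises the pure-Python jit
# implementation.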
class PythonJitTest(CPPJitTest):

  @property
  def jit(self):
    return jax.api._python_jit

  def test_jit_reference_dropping(self):
    x = jnp.ones(10)
    f = (lambda x: lambda: x)(x)  # reference to x in f's closure
    g = self.jit(f)
    x = weakref.ref(x)      # no more strong ref to x in this scope
    assert x() is not None  # x is still around
    f()                     # f runs
    g()                     # g runs
    g()                     # g runs a second time
    del f                   # delete the raw callable
    assert x() is not None  # x is still around
    g()                     # g still runs
    del g                   # no more references to x
    assert x() is None      # x is gone

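  # When the jitted computation is trivial (identity-like), jit hands back the
  # original input buffers rather than new ones; the assertions below check
  # object identity for exactly that reason.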
  def test_trivial_computations(self):
    x = jnp.array([1, 2, 3])
    y = self.jit(lambda x: x)(x)
    self.assertIs(x, y)

    z1, z2 = self.jit(lambda x: (x, x))(x)
    self.assertIs(z1, z2)

    x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
    z1, z2, z3 = self.jit(lambda x, y: (y, 1, x))(x1, x2)
    self.assertIs(z1, x2)
    self.assertIs(z3, x1)
    self.assertEqual(z2, 1)

  def test_jit_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: jit(f)("foo"))

  def test_jit_on_all_devices(self):
    # Verifies we can run the same computation on every device present, even
    # if they are, for example, different models of GPU.
    data = np.random.rand(1000).astype(np.float32)
    f = self.jit(jnp.negative)
    for device in jax.local_devices():
      x = device_put(data, device=device)
      np.testing.assert_array_equal(-data, f(x))

  def test_jit_nested_donate_ignored(self):
    jit_fun = self.jit(lambda x: self.jit(lambda y: y**2, donate_argnums=0)(x))
    a = jax.device_put(jnp.array(1))

    # NOTE(mattjj): stopped raising error here and instead just ignored
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   jit_fun(a)

    jit_fun(a)  # doesn't crash


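# APITest covers the transformation and device APIs (grad, value_and_grad,
# jvp, vjp, jacfwd/jacrev, hessian, eval_shape, device_put, and friends).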
class APITest(jtu.JaxTestCase):

  def test_grad_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: grad(f)("foo"))

  def test_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    assert grad(f)(1.0, 1.0, 1.0, flag=True) == 1.0
    assert grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == 2.0
    assert grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (3.0, 1.0)

  def test_value_and_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    y = f(1.0, 1.0, 1.0, flag=True)
    assert api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True) == (y, 1.0)
    assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
    assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))

  def test_grad_of_jit(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    assert grad(f)(1.0) == 2.0
    assert len(side) == 1
    assert grad(f)(2.0) == 4.0
    assert len(side) == 1

  def test_jit_of_grad(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    g = jit(grad(f))
    assert g(1.0) == 2.0
    assert len(side) == 1
    assert g(2.0) == 4.0
    assert len(side) == 1

  def test_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: grad(f)("foo"))

    self.assertRaisesRegex(
        TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
        lambda: jit(f)("foo"))

  def test_grad_tuple_output(self):
    jtu.check_raises(lambda: grad(lambda x: (x, x))(1.0), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_unit_output(self):
    jtu.check_raises(lambda: grad(lambda x: ())(np.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_nonscalar_output(self):
    jtu.check_raises(lambda: grad(lambda x: x)(np.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_unwrapped_numpy(self):
    def f(x):
      return np.exp(x)

    with self.assertRaisesRegex(Exception, "The numpy.ndarray conversion .*"):
      grad(f)(np.zeros(3))

  def test_binop_mismatch(self):
    def f(x, y):
      return x + y

    jtu.check_raises(
        lambda: f(jnp.zeros(3), jnp.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

    jtu.check_raises(
        lambda: grad(f)(np.zeros(3), np.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

  def test_dot_mismatch(self):
    def f(x, y):
      return jnp.dot(x, y)

    self.assertRaisesRegex(
        TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
        lambda: grad(f)(np.zeros(3), np.zeros(4)))

  def test_abstract_error_message(self):
    for castfun in [float, complex, int]:
      def f(x):
        return castfun(x)

      self.assertRaisesRegex(
          TypeError,
          f"[Tt]ry using `x.astype\\({castfun.__name__}\\)`",
          lambda: jit(f)(1.0))

  def test_switch_value_jit(self):
    def f(x):
      y = x > 0
      if y:
        return x
      else:
        return -x

    assert grad(f)(1.0) == 1.0
    assert grad(f)(-1.0) == -1.0
    with self.assertRaisesRegex(core.ConcretizationTypeError,
                                "Abstract tracer value"):
      jit(f)(1)

  def test_range_err(self):
    def f(x, n):
      for i in range(n):
        x = x + i
      return x

    assert jit(f, static_argnums=(1,))(0, 5) == 10
    self.assertRaisesRegex(
        TypeError,
        "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
        "|Abstract value passed to .*)",
        lambda: jit(f)(0, 5))

  def test_casts(self):
    for castfun in [hex, oct, int]:
      f = lambda x: castfun(x)
      self.assertRaisesRegex(
          TypeError,
          "('(?:JaxprTracer|DynamicJaxprTracer)' object cannot be interpreted as an integer"
          "|Abstract tracer value encountered where concrete value is expected.*)",
          lambda: jit(f)(0))

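  # A freshly created Primitive has no rules registered, so evaluation,
  # abstract evaluation, differentiation, and XLA translation should each fail
  # with a descriptive NotImplementedError until the corresponding rule is
  # defined; the test below registers the rules one at a time.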
  def test_unimplemented_interpreter_rules(self):
    foo_p = Primitive('foo')
    def foo(x):
      return foo_p.bind(x)

    jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                     "Evaluation rule for 'foo' not implemented")

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "Abstract evaluation for 'foo' not implemented")

    jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                     "Differentiation rule for 'foo' not implemented")

    foo_p.def_abstract_eval(lambda x: x)

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
"XLA translation rule for primitive 'foo' not found")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
foo_p.def_impl(lambda x: x)
|
2019-06-03 07:17:37 -07:00
|
|
|
ad.defjvp(foo_p, lambda g, x: foo(g))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
|
2020-01-15 15:00:38 -08:00
|
|
|
"Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 18:07:24 -08:00
|
|
|
  def test_device_put_and_get(self):
    x = np.arange(12.).reshape((3, 4)).astype("float32")
    dx = api.device_put(x)
    self.assertIsInstance(dx, xla.DeviceArray)
    x2 = api.device_get(dx)
    self.assertIsInstance(x2, np.ndarray)
    assert np.all(x == x2)

    y = [x, (2 * x, 3 * x)]
    dy = api.device_put(y)
    y2 = api.device_get(dy)
    self.assertIsInstance(y2, list)
    self.assertIsInstance(y2[0], np.ndarray)
    assert np.all(y2[0] == x)
    self.assertIsInstance(y2[1], tuple)
    self.assertIsInstance(y2[1][0], np.ndarray)
    assert np.all(y2[1][0] == 2 * x)
    self.assertIsInstance(y2[1][1], np.ndarray)
    assert np.all(y2[1][1] == 3 * x)

  def test_device_get_scalar(self):
    x = np.arange(12.).reshape((3, 4)).astype("float32")
    x = api.device_put(x)
    self.assertIsInstance(x, xla.DeviceArray)
    y = [x, 2]
    y2 = api.device_get(y)
    self.assertIsInstance(y2, list)
    self.assertIsInstance(y2[0], np.ndarray)
    assert np.all(y2[0] == x)
    self.assertIsInstance(y2[1], int)
    self.assertEqual(y2[1], 2)

  @parameterized.parameters([(3,)], [(2, 0)])
  def test_device_put_across_devices(self, shape):
    if len(api.local_devices()) < 2:
      raise unittest.SkipTest("this test requires multiple devices")
    d1, d2 = api.local_devices()[:2]
    data = np.random.randn(*shape).astype(np.float32)
    x = api.device_put(data, device=d1)
    self.assertEqual(x.device_buffer.device(), d1)
    y = api.device_put(x, device=d2)
    self.assertEqual(y.device_buffer.device(), d2)
    np.testing.assert_array_equal(data, np.array(y))
    # Make sure these don't crash
    api.device_put(x)
    api.device_put(y)

  @jtu.skip_on_devices("cpu")
  def test_device_put_across_platforms(self):
    default_device = jax.devices()[0]
    cpu_device = jax.devices("cpu")[0]

    np_arr = np.array([1, 2, 3])
    scalar = 1
    device_arr = jnp.array([1, 2, 3])
    assert device_arr.device_buffer.device() is default_device

    for val in [np_arr, device_arr, scalar]:
      x = api.device_put(val, device=cpu_device)
      self.assertEqual(x.device_buffer.device(), cpu_device)

  def test_device_put_sharded_array(self):
    devices = api.local_devices()
    n_devices = len(devices)
    x = [np.arange(i, i + 4) for i in range(n_devices)]
    y = api.device_put_sharded(x, devices)
    self.assertEqual(len(y.device_buffers), len(devices))
    self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
    self.assertAllClose(y, jnp.stack(x))

  def test_device_put_sharded_pytree(self):
    devices = api.local_devices()
    n_devices = len(devices)
    x = [(i, np.arange(i, i + 4)) for i in range(n_devices)]
    y1, y2 = api.device_put_sharded(x, devices)
    self.assertAllClose(y1, jnp.array([a for a, _ in x]))
    self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
    self.assertAllClose(y2, jnp.vstack([b for _, b in x]))
    self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))

  @jtu.skip_on_devices("tpu")
  def test_jacobian(self):
    R = np.random.RandomState(0).randn
    A = R(4, 3)
    x = R(3)

    f = lambda x: jnp.dot(A, x)
    assert np.allclose(jacfwd(f)(x), A)
    assert np.allclose(jacrev(f)(x), A)

    f = lambda x: jnp.tanh(jnp.dot(A, x))
    assert np.allclose(jacfwd(f)(x), jacrev(f)(x))

  @jtu.skip_on_devices("tpu")
  def test_hessian(self):
    R = np.random.RandomState(0).randn
    A = R(4, 4)
    x = R(4)

    f = lambda x: jnp.dot(x, jnp.dot(A, x))
    assert np.allclose(hessian(f)(x), A + A.T)

  def test_std_basis(self):
    basis = api._std_basis(jnp.zeros(3))
    assert getattr(basis, "shape", None) == (3, 3)
    assert np.allclose(basis, np.eye(3))

    basis = api._std_basis(jnp.zeros((3, 3)))
    assert getattr(basis, "shape", None) == (9, 3, 3)
    assert np.allclose(basis, np.eye(9).reshape(9, 3, 3))

    basis = api._std_basis([0., (jnp.zeros(3), jnp.zeros((3, 4)))])
    assert isinstance(basis, list) and len(basis) == 2
    assert getattr(basis[0], "shape", None) == (16,)
    assert isinstance(basis[1], tuple) and len(basis[1]) == 2
    assert getattr(basis[1][0], "shape", None) == (16, 3)
    assert getattr(basis[1][1], "shape", None) == (16, 3, 4)

  @jtu.skip_on_devices("tpu")
  def test_jacobian_on_pytrees(self):
    for jacfun in [jacfwd, jacrev]:
      ans = jacfun(lambda x, y: (x, y))(0., 1.)
      expected = (1., 0.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), 1)(0., 1.)
      expected = (0., 1.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), (0, 1))(0., 1.)
      expected = ((1., 0.),
                  (0., 1.),)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x: x[:2])((1., 2., 3.))
      expected = ((1., 0., 0.),
                  (0., 1., 0.))
      self.assertAllClose(ans, expected, check_dtypes=False)

      R = np.random.RandomState(0).randn
      x = R(2)
      y = R(3)
      ans = jacfun(lambda x, y: {'x': x, 'xy': jnp.outer(x, y)})(x, y)
      expected = {'x': np.eye(2),
                  'xy': np.kron(np.eye(2), y[:, None]).reshape(2, 3, 2)}
      self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.skip_on_devices("tpu")
  def test_hessian_on_pytrees(self):
    ans = hessian(lambda x: jnp.array(x)**2)((1., 2.))
    expected = ((np.array([2., 0.]), np.array([0., 0.])),
                (np.array([0., 0.]), np.array([0., 2.])))
    self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.skip_on_devices("tpu")
  def test_issue1372(self):
    def quad(x):
      return jnp.dot(x, x)

    def f(x, u):
      return quad(x) + quad(u)

    x, u = jnp.ones(5), jnp.ones(2)

    rev = jacrev
    fwd = jacfwd

    # Diagonal entries
    self.assertEqual(rev(rev(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(rev(fwd(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(fwd(rev(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(fwd(fwd(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(rev(rev(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(rev(fwd(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(fwd(rev(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(fwd(fwd(f, 1), 1)(x, u).shape, (2, 2))

    # Off-diagonal entries by reverse-mode on the outside
    self.assertEqual(rev(rev(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(rev(fwd(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(rev(rev(f, 0), 1)(x, u).shape, (5, 2))
    self.assertEqual(rev(fwd(f, 0), 1)(x, u).shape, (5, 2))

    # Off-diagonal entries by forward-mode on the outside
    self.assertEqual(fwd(rev(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(fwd(fwd(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
    self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))

  def test_large_device_constant(self):
    ans = jit(lambda x: 2 * x)(jnp.ones(int(2e6)))  # doesn't crash
    self.assertAllClose(ans, np.ones(int(2e6)) * 2., check_dtypes=False)

  def test_grad_and_aux_basic(self):
    g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
    self.assertAllClose(g, grad(lambda x: x**3)(3.))
    self.assertAllClose(aux, [9.], check_dtypes=False)

  def test_grad_and_aux_nested(self):
    def f(x):
      g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
      return aux[0]

    f2 = lambda x: x**3

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

    def f(x):
      g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
      return aux[0] * jnp.sin(x)

    f2 = lambda x: x**3 * jnp.sin(x)

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

  def test_grad_and_aux_constant(self):
    g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
    self.assertEqual(g, grad(lambda x: x**3)(4.))
    self.assertEqual(aux, [4.])

    g, aux = grad(lambda x: (x**3, [x**2, 4.]), has_aux=True)(4.)
    self.assertEqual(g, grad(lambda x: x**3)(4.))
    self.assertEqual(aux, [4.**2, 4.])

  def test_grad_and_aux_no_tracers(self):
    # see https://github.com/google/jax/issues/1950
    def f(x):
      aux = dict(identity=x, p1=x+1)
      return x ** 2, aux

    _, aux = jax.grad(f, has_aux=True)(3.)
    self.assertIsInstance(aux, dict)
    for val in aux.values():
      self.assertNotIsInstance(val, core.Tracer)

  def test_jvp_mismatched_arguments(self):
    self.assertRaisesRegex(
        TypeError,
        ("primal and tangent arguments to jax.jvp must have the same tree "
         "structure"),
        lambda: api.jvp(lambda x, y: x * y, (np.float32(2),), ()))
    # primals and tangents must both be tuples or both be lists
    self.assertRaisesRegex(
        TypeError,
        ("primal and tangent arguments to jax.jvp must have the same tree "
         "structure"),
        lambda: api.jvp(lambda x, y: x * y, (np.float32(2),), [np.float32(2)]))
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must have equal types",
        lambda: api.jvp(lambda x: -x, (np.float16(2),), (np.float32(4),)))

  def test_jvp_non_tuple_arguments(self):
    def f(x, y): return x + y
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
        lambda: api.jvp(f, 0., (1.,)))
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
        lambda: api.jvp(f, (0.,), np.array([1., 2.])))

  def test_vjp_mismatched_arguments(self):
    _, pullback = api.vjp(lambda x, y: x * y, np.float32(3), np.float32(4))
    self.assertRaisesRegex(
        TypeError,
        "Tree structure of cotangent input.*does not match",
        lambda: pullback((np.float32(7), np.float32(100))))
    self.assertRaisesRegex(
        TypeError,
        "Type of cotangent input to vjp pullback.*does not match type",
        lambda: pullback((np.float16(42))))

  def test_jvp_jit_cached(self):
    """Bug in caching in presence of JVP and JIT."""

    def func(x):
      def inner(y):
        return y * x

      # Must have two calls to the inner jit (the second one hits the cache)
      res1 = api.jit(inner)(4.)
      res2 = api.jit(inner)(5.)
      return res1 + res2

    self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)))

  def test_linear_transpose_abstract(self):
    x = types.SimpleNamespace(shape=(3,), dtype=np.float32)
    y = jnp.arange(3, dtype=np.float32)
    transpose_fun = api.linear_transpose(lambda x: 2 * x, x)
    z, = transpose_fun(y)
    self.assertArraysEqual(2 * y, z, check_dtypes=True)

  def test_linear_transpose_error(self):
    with self.assertRaisesRegex(
        TypeError, "linear_transpose only supports float and complex inputs"):
      api.linear_transpose(lambda x: x, 1)

    transpose_fun = api.linear_transpose(lambda x: [x, x], 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent tree does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: jnp.stack([x, x]), 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: 1j * x, 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1.0)

    transpose_fun = api.linear_transpose(lambda x: x, 1.0)
    with self.assertRaisesRegex(TypeError, "cotangent type does not match"):
      transpose_fun(1j)

  def test_linear_transpose_complex(self):
    f = lambda x: (1 + 2j) * x
    transpose = api.linear_transpose(f, 1j)
    actual, = transpose(3 + 4j)
    expected = -5 + 10j
    self.assertEqual(actual, expected)

  def test_complex_grad_raises_error(self):
    self.assertRaises(TypeError, lambda: grad(lambda x: jnp.sin(x))(1 + 2j))

  def test_holomorphic_grad(self):
    out = grad(lambda x: jnp.sin(x), holomorphic=True)(1 + 2j)
    expected = 2.0327230070196656 - 3.0518977991518j
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_nonholomorphic_grad(self):
    zs = 0.5j * np.arange(5) + np.arange(5)

    def f(z):
      return jnp.sum(jnp.cos(jnp.abs(z)))

    ans = grad(f)(zs)
    expected = np.array([ 0.        +0.j,
                         -0.80430663+0.40215331j,
                         -0.70368982+0.35184491j,
                          0.1886467 -0.09432335j,
                          0.86873727-0.43436864j])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        atol=jtu.default_gradient_tolerance,
                        rtol=jtu.default_gradient_tolerance)

  def test_complex_output_jacrev_raises_error(self):
    self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j))

  def test_nonholomorphic_jacrev(self):
    # code based on https://github.com/google/jax/issues/603
    zs = 0.5j * np.arange(5) + np.arange(5)

    def f(z):
      return jnp.cos(jnp.linalg.norm(2 * z))

    ans = jacrev(f)(zs)
    expected = grad(f)(zs)
    self.assertAllClose(ans, expected)

  def test_complex_input_jacfwd_raises_error(self):
    self.assertRaises(TypeError, lambda: jacfwd(lambda x: jnp.sin(x))(1 + 2j))

  def test_legacy_devicearray_repr(self):
    dx = device_put(3.)
    str(dx.item())  # doesn't crash

  def test_devicearray_repr(self):
    x = device_put(jnp.zeros(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

    x = device_put(jnp.ones(3) + 1j * jnp.ones(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

  def test_devicearray_delete(self):
    x = device_put(1.)
    x.delete()
    self.assertRaisesRegex(ValueError, "DeviceArray has been deleted.",
                           lambda: repr(x))

  def test_devicearray_block_until_ready(self):
    x = device_put(1.)
    y = x.block_until_ready()
    # Tests mostly that block_until_ready() does not produce an error.
    self.assertTrue(y is x)

  def test_namedtuple_transparency(self):
    # See https://github.com/google/jax/issues/446
    Point = collections.namedtuple("Point", ["x", "y"])

    def f(pt):
      return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

    pt = Point(1., 2.)

    f(pt)  # doesn't crash
    g = api.grad(f)(pt)
    self.assertIsInstance(g, Point)

    f_jit = api.jit(f)
    self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)

  def test_namedtuple_subclass_transparency(self):
    # See https://github.com/google/jax/issues/806
    Point = collections.namedtuple("Point", ["x", "y"])

    class ZeroPoint(Point):
      def is_zero(self):
        return (self.x == 0) and (self.y == 0)

    pt = ZeroPoint(0., 0.)

    def f(pt):
      return 0. if pt.is_zero() else jnp.sqrt(pt.x ** 2 + pt.y ** 2)

    f(pt)  # doesn't crash
    _ = api.grad(f)(pt)
    self.assertIsInstance(pt, ZeroPoint)

  @parameterized.parameters(1, 2, 3)
  def test_shape_dtype_struct(self, i):
    s = api.ShapeDtypeStruct(shape=(i, 2, 3), dtype=jnp.float32)
    self.assertEqual(s.shape, (i, 2, 3))
    self.assertEqual(s.dtype, jnp.float32)
    self.assertEqual(s.ndim, 3)
    self.assertEqual(s.size, i * 2 * 3)
    self.assertLen(s, i)
    for f in (str, repr):
      self.assertEqual(
          f(s), "ShapeDtypeStruct(shape=({}, 2, 3), dtype=float32)".format(i))

  def test_shape_dtype_struct_scalar(self):
    s = api.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
    self.assertEmpty(s.shape)
    self.assertEqual(s.size, 1)
    self.assertEqual(s.ndim, 0)
    with self.assertRaisesRegex(TypeError, "len[(][)] of unsized object"):
      _ = len(s)

  def test_eval_shape(self):
    def fun(x, y):
      return jnp.tanh(jnp.dot(x, y) + 3.)

    x = jnp.ones((2, 3))
    y = jnp.ones((3, 4))
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_constants(self):
    def fun():
      x = jnp.ones((2, 3))
      y = jnp.ones((3, 4))
      return jnp.tanh(jnp.dot(x, y) + 3.)

    out_shape = api.eval_shape(fun)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_tuple_unpacking(self):
    def fun(x, y):
      a, b = x
      return a + b + y

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_tuple_itemgetting(self):
    def fun(x, y):
      return x[0] + x[1] + y

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_output_dict(self):
    def fun(x, y):
      return {'hi': x[0] + x[1] + y}

    x = (jnp.ones(2), jnp.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)
    out_shape = tree_util.tree_map(np.shape, out_shape)

    self.assertEqual(out_shape, {'hi': (2,)})

  def test_eval_shape_shape_error(self):
    def fun(x, y):
      return jnp.tanh(jnp.dot(x, y) + 3.)

    x = jnp.ones((3, 3))
    y = jnp.ones((4, 4))

    self.assertRaises(TypeError, lambda: api.eval_shape(fun, x, y))

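  # api.eval_shape only requires its arguments to expose `shape` and `dtype`
  # attributes, so duck-typed placeholder objects work as well, as the test
  # below demonstrates.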
  def test_eval_shape_duck_typing(self):
    def fun(A, b, x):
      return jnp.dot(A, x) + b

    class MyArgArray(object):
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    A = MyArgArray((3, 4), jnp.float32)
    b = MyArgArray((5,), jnp.float32)
    x = MyArgArray((4, 5), jnp.float32)
    out_shape = api.eval_shape(fun, A, b, x)

    self.assertEqual(out_shape.shape, (3, 5))

  def test_issue_871(self):
    T = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
    x = jnp.array([1, 2, 3])

    y, f_jvp = api.linearize(jnp.sum, x)
    jtu.check_raises(lambda: f_jvp(T), ValueError,
                     ("linearized function called on tangent values "
                      "inconsistent with the original primal values."))

    y, f_jvp = api.linearize(api.jit(jnp.sum), x)
    jtu.check_raises(lambda: f_jvp(T), ValueError,
                     ("linearized function called on tangent values "
                      "inconsistent with the original primal values."))

def test_partial_eval_lower(self):
|
|
|
|
# this is a simplified model of a bug that arose when we first used @jit in
|
|
|
|
# a jvp rule. it's in this file because we want to use make_jaxpr.
|
2019-12-31 10:38:45 -08:00
|
|
|
|
|
|
|
# NOTE(mattjj): I no longer understand what this was meant to test. My guess
|
|
|
|
# is it was related to staging out the broadcast into a jaxpr to be
|
|
|
|
# transposed, but after #1749 that's no longer a problem. After changing
|
|
|
|
# make_jaxpr (and jit) to stage out sub-calls fully, this test started to
|
|
|
|
# fail; I left it in as skipped because deleting tests feels wrong.
|
|
|
|
raise unittest.SkipTest("obsolete test")
|
|
|
|
|
2019-06-18 21:23:52 -07:00
|
|
|
@api.jit
|
|
|
|
def f(a, b, c):
|
|
|
|
a = lax.broadcast(a, (2,))
|
|
|
|
return lax.select(a, b, c)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
a = np.ones((3, 3), dtype=np.bool_)
|
|
|
|
b = np.ones((2, 3, 3))
|
|
|
|
c = np.ones((2, 3, 3))
|
2019-06-18 21:23:52 -07:00
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(lambda b, c: f(a, b, c))(b, c)
|
2020-02-05 15:38:25 +01:00
|
|
|
subjaxpr = next(eqn.params["call_jaxpr"] for eqn in jaxpr.jaxpr.eqns
|
|
|
|
if "call_jaxpr" in eqn.params)
|
2019-06-18 21:23:52 -07:00
|
|
|
self.assertEqual(len(subjaxpr.eqns), 1)
|
|
|
|
|
2019-06-24 10:45:42 -04:00
|
|
|
def test_grad_of_int_errors(self):
|
|
|
|
dfn = grad(lambda x: x ** 2)
|
2019-11-28 08:48:10 +01:00
|
|
|
self.assertRaisesRegex(
|
2019-11-14 16:00:55 -05:00
|
|
|
TypeError,
|
2020-05-19 15:17:03 -07:00
|
|
|
(r"grad requires real- or complex-valued inputs \(input dtype that is a "
|
|
|
|
r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
|
|
|
|
lambda: dfn(3))
|
|
|
|
|
|
|
|
def test_grad_complex_result_errors(self):
|
|
|
|
dfn = grad(lambda x: x ** 2 + 1j)
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
(r"grad requires real-valued outputs \(output dtype that is a "
|
|
|
|
r"sub-dtype of np.floating\), but got complex.*"),
|
|
|
|
lambda: dfn(3.))
|
|
|
|
|
|
|
|
def test_holomorphic_grad_of_float_errors(self):
|
|
|
|
dfn = grad(lambda x: x ** 2, holomorphic=True)
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
(r"grad with holomorphic=True requires inputs with complex dtype, "
|
|
|
|
r"but got float.*"),
|
|
|
|
lambda: dfn(3.))
|
|
|
|
|
|
|
|
def test_holomorphic_jacrev_of_float_errors(self):
|
|
|
|
dfn = jacrev(lambda x: x ** 2, holomorphic=True)
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
(r"jacrev with holomorphic=True requires inputs with complex dtype, "
|
|
|
|
r"but got float.*"),
|
|
|
|
lambda: dfn(3.))
|
|
|
|
|
|
|
|
def test_holomorphic_jacfwd_of_float_errors(self):
|
|
|
|
dfn = jacfwd(lambda x: x ** 2, holomorphic=True)
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
(r"jacfwd with holomorphic=True requires inputs with complex dtype, "
|
|
|
|
r"but got float.*"),
|
|
|
|
lambda: dfn(3.))
|
|
|
|
|
|
|
|
def test_jacfwd_of_complex_errors(self):
|
|
|
|
dfn = jacfwd(lambda x: x ** 2)
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
(r"jacfwd requires real-valued inputs \(input dtype that is a "
|
|
|
|
r"sub-dtype of np.floating\), but got complex.*"),
|
|
|
|
lambda: dfn(3. + 1j))
|
2019-06-24 10:45:42 -04:00
|
|
|
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is used for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 and also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
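The compositions listed in the message above can be made concrete. Below is a minimal sketch (not part of the test class) of items 1 and 3, using only APIs the tests in this section also exercise (`jax.pmap`, `jax.jit`, `jax.xla_computation` with `axis_env`, and `as_hlo_text`); actually running the pmapped call assumes the mapped axis size fits the available devices.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Item 1: jit-of-pmap -- jit can now stage out the inner pmap.
@jax.jit
def jit_of_pmap(x):
  return jax.pmap(jnp.sin)(x)

jit_of_pmap(jnp.ones((1, 3)))  # mapped axis of size 1, so one device suffices

# Item 3: xla_computation on a function involving a collective; the axis
# name is bound via axis_env rather than by an enclosing pmap.
def f(x):
  return x - lax.psum(x, 'i')

c = jax.xla_computation(f, axis_env=[('i', 4)])(2.)
print(c.as_hlo_text())  # HLO contains an all-reduce over replica group {0,1,2,3}
```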
|
|
|
def test_xla_computation(self):
|
|
|
|
# these tests basically check the examples in the xla_computation docstring
|
|
|
|
|
2020-05-08 14:00:34 -07:00
|
|
|
def e(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(jnp.cos(x))
|
2020-05-08 14:00:34 -07:00
|
|
|
c = api.xla_computation(e)(2.)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertIn('cosine', c.as_hlo_text())
|
|
|
|
self.assertIn('sine', c.as_hlo_text())
|
2019-07-02 13:17:31 -07:00
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return x - lax.psum(x, 'i')
|
|
|
|
axis_env = [('i', 4)]
|
|
|
|
c = api.xla_computation(f, axis_env=axis_env)(2)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertIn('all-reduce', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text())
|
2019-07-02 13:17:31 -07:00
|
|
|
|
|
|
|
def g(x):
|
|
|
|
rowsum = lax.psum(x, 'i')
|
|
|
|
colsum = lax.psum(x, 'j')
|
|
|
|
allsum = lax.psum(x, ('i', 'j'))
|
|
|
|
return rowsum, colsum, allsum
|
|
|
|
axis_env = [('i', 4), ('j', 2)]
|
|
|
|
c = api.xla_computation(g, axis_env=axis_env)(5.)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertIn('all-reduce', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text())
|
2019-07-02 13:17:31 -07:00
|
|
|
|
2020-05-08 14:00:34 -07:00
|
|
|
def h(x):
|
|
|
|
rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])
|
|
|
|
colsum = lax.psum(x, 'j')
|
|
|
|
return rowsum, colsum
|
|
|
|
axis_env = [('i', 4), ('j', 2)]
|
|
|
|
c = api.xla_computation(h, axis_env=axis_env)(5.)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertIn('all-reduce', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text())
|
|
|
|
self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text())
|
2020-05-08 14:00:34 -07:00
|
|
|
|
2019-09-27 17:37:44 -07:00
|
|
|
def test_xla_computation_args(self):
|
|
|
|
def foo(x, y, z):
|
|
|
|
return x + y + z
|
|
|
|
|
|
|
|
c = api.xla_computation(foo)(1., 2., 3.)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
|
2019-09-27 17:37:44 -07:00
|
|
|
|
|
|
|
c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
|
2020-05-11 17:43:55 -04:00
|
|
|
param_shapes = c.program_shape().parameter_shapes()
|
2019-09-27 17:37:44 -07:00
|
|
|
self.assertEqual(len(param_shapes), 1)
|
|
|
|
self.assertEqual(param_shapes[0].xla_element_type(),
|
|
|
|
xb.xla_client.PrimitiveType.TUPLE)
|
|
|
|
|
2020-03-30 11:31:29 -07:00
|
|
|
def test_xla_computation_duck_typing(self):
|
|
|
|
def foo(x, y, z):
|
|
|
|
return x + y + z
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = jax.ShapeDtypeStruct((), np.float32)
|
|
|
|
y = jax.ShapeDtypeStruct((), np.float32)
|
|
|
|
z = jax.ShapeDtypeStruct((), np.float32)
|
2020-03-30 11:31:29 -07:00
|
|
|
|
|
|
|
c = api.xla_computation(foo)(x, y, z)
|
2020-05-11 17:43:55 -04:00
|
|
|
self.assertEqual(len(c.program_shape().parameter_shapes()), 3)
|
2020-03-30 11:31:29 -07:00
|
|
|
|
|
|
|
c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
|
2020-05-11 17:43:55 -04:00
|
|
|
param_shapes = c.program_shape().parameter_shapes()
|
2020-03-30 11:31:29 -07:00
|
|
|
self.assertEqual(len(param_shapes), 1)
|
|
|
|
self.assertEqual(param_shapes[0].xla_element_type(),
|
|
|
|
xb.xla_client.PrimitiveType.TUPLE)
|
|
|
|
|
2019-07-09 15:12:02 -07:00
|
|
|
def test_staging_out_multi_replica(self):
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return api.pmap(jnp.mean)(x)
|
2019-07-09 15:12:02 -07:00
|
|
|
xla_comp = api.xla_computation(f)
|
2020-05-11 17:43:55 -04:00
|
|
|
xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash
|
2019-07-09 15:12:02 -07:00
|
|
|
|
2019-12-04 09:50:29 -08:00
|
|
|
def test_xla_computation_instantiate_constant_outputs(self):
|
|
|
|
def f():
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.zeros((3, 4))
|
2019-12-04 09:50:29 -08:00
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
if config.omnistaging_enabled:
|
|
|
|
xla_comp = api.xla_computation(f)()
|
|
|
|
else:
|
|
|
|
xla_comp = api.xla_computation(f, instantiate_const_outputs=True)()
|
2020-05-11 17:43:55 -04:00
|
|
|
out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
|
2019-12-04 09:50:29 -08:00
|
|
|
self.assertEqual(out_shape.dimensions(), (3, 4))
|
|
|
|
|
2020-04-23 18:07:51 -07:00
|
|
|
def test_xla_computation_static_argnums(self):
|
|
|
|
def f(x, y):
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
|
2020-08-01 00:15:51 +02:00
|
|
|
hlo_text = xla_comp.as_hlo_text()
|
|
|
|
self.assertIn("constant(3)", hlo_text)
|
|
|
|
# The static arguments should be removed from the function being compiled,
|
|
|
|
# thus the function should have only a single argument.
|
|
|
|
self.assertIn("parameter.1", hlo_text)
|
|
|
|
self.assertNotIn("parameter.2", hlo_text)
|
2020-04-23 18:07:51 -07:00
|
|
|
|
2020-07-23 19:38:56 -07:00
|
|
|
def test_xla_computation_return_shape(self):
|
|
|
|
_, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
|
|
|
|
return_shape=True)(np.int32(1))
|
|
|
|
expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32),
|
|
|
|
api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
|
|
|
|
self.assertEqual(shape_tree, expected)
|
|
|
|
|
2020-08-14 13:05:58 -07:00
|
|
|
def test_xla_computation_partitioned(self):
|
|
|
|
def f(x, y):
|
|
|
|
return jnp.dot(x, y) + 1
|
|
|
|
|
|
|
|
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
|
|
|
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
|
|
|
xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None),
|
|
|
|
out_parts=P(4, 1))(x, y)
|
|
|
|
hlo_text = xla_comp.as_hlo_text()
|
|
|
|
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
|
|
|
self.assertIn('sharding={replicated}', hlo_text)
|
|
|
|
self.assertIn('sharding={{devices=[4,1]0,1,2,3}}', hlo_text)
|
|
|
|
|
|
|
|
def test_xla_computation_replicated_and_partitioned(self):
|
|
|
|
def f(x, y):
|
|
|
|
return jnp.dot(x, y), lax.psum(x, 'i')
|
|
|
|
|
|
|
|
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
|
|
|
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
|
|
|
axis_env = [('i', 4)]
|
|
|
|
xla_comp = api.xla_computation(f, axis_env=axis_env,
|
|
|
|
in_parts=(P(2, 2), None),
|
|
|
|
out_parts=(P(4, 1), None))(x, y)
|
|
|
|
hlo_text = xla_comp.as_hlo_text()
|
|
|
|
self.assertIn('all-reduce', hlo_text)
|
|
|
|
self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)
|
|
|
|
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
|
|
|
self.assertIn('sharding={replicated}', hlo_text)
|
|
|
|
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)
|
|
|
|
|
2020-08-16 20:00:40 -07:00
|
|
|
def test_xla_computation_psum_constant(self):
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
raise unittest.SkipTest("test requires omnistaging")
|
|
|
|
f = lambda: jax.lax.psum(1, "i")
|
|
|
|
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
|
|
|
|
|
2020-09-18 19:54:37 -07:00
|
|
|
@jtu.skip_on_devices("cpu", "gpu")
|
|
|
|
def test_xla_computation_donate_argnums(self):
|
|
|
|
api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash
|
|
|
|
|
2019-08-09 13:12:44 -04:00
|
|
|
def test_concurrent_device_get_and_put(self):
|
|
|
|
def f(x):
|
|
|
|
for _ in range(100):
|
|
|
|
y = jax.device_put(x)
|
|
|
|
x = jax.device_get(y)
|
|
|
|
return x
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
xs = [np.random.randn(i) for i in range(10)]
|
2019-08-09 13:12:44 -04:00
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
|
futures = [executor.submit(partial(f, x)) for x in xs]
|
|
|
|
ys = [f.result() for f in futures]
|
|
|
|
for x, y in zip(xs, ys):
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(x, y)
|
2019-08-09 13:12:44 -04:00
|
|
|
|
2019-08-24 12:34:44 -07:00
|
|
|
def test_dtype_warning(self):
|
|
|
|
# cf. issue #1230
|
2019-08-22 09:22:57 -07:00
|
|
|
if FLAGS.jax_enable_x64:
|
|
|
|
return # test only applies when x64 is disabled
|
|
|
|
|
|
|
|
def check_warning(warn, nowarn):
|
|
|
|
with warnings.catch_warnings(record=True) as w:
|
|
|
|
warnings.simplefilter("always")
|
|
|
|
|
|
|
|
nowarn() # get rid of extra startup warning
|
|
|
|
|
|
|
|
prev_len = len(w)
|
|
|
|
nowarn()
|
|
|
|
assert len(w) == prev_len
|
|
|
|
|
|
|
|
warn()
|
|
|
|
assert len(w) > 0
|
|
|
|
msg = str(w[-1].message)
|
|
|
|
expected_prefix = "Explicitly requested dtype "
|
|
|
|
self.assertEqual(expected_prefix, msg[:len(expected_prefix)])
|
|
|
|
|
|
|
|
prev_len = len(w)
|
|
|
|
nowarn()
|
|
|
|
assert len(w) == prev_len
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
check_warning(lambda: jnp.array([1, 2, 3], dtype="float64"),
|
|
|
|
lambda: jnp.array([1, 2, 3], dtype="float32"),)
|
|
|
|
check_warning(lambda: jnp.ones(3, dtype=np.float64),
|
|
|
|
lambda: jnp.ones(3))
|
|
|
|
check_warning(lambda: jnp.ones_like(3, dtype=np.int64),
|
|
|
|
lambda: jnp.ones_like(3, dtype=np.int32))
|
|
|
|
check_warning(lambda: jnp.zeros(3, dtype="int64"),
|
|
|
|
lambda: jnp.zeros(3, dtype="int32"))
|
|
|
|
check_warning(lambda: jnp.zeros_like(3, dtype="float64"),
|
|
|
|
lambda: jnp.zeros_like(3, dtype="float32"))
|
|
|
|
check_warning(lambda: jnp.full((2, 3), 1, dtype="int64"),
|
|
|
|
lambda: jnp.full((2, 3), 1))
|
|
|
|
check_warning(lambda: jnp.ones(3).astype("float64"),
|
|
|
|
lambda: jnp.ones(3).astype("float32"))
|
|
|
|
check_warning(lambda: jnp.eye(3, dtype=np.float64),
|
|
|
|
lambda: jnp.eye(3))
|
|
|
|
check_warning(lambda: jnp.arange(3, dtype=np.float64),
|
|
|
|
lambda: jnp.arange(3, dtype=np.float32))
|
|
|
|
check_warning(lambda: jnp.linspace(0, 3, dtype=np.float64),
|
|
|
|
lambda: jnp.linspace(0, 3, dtype=np.float32))
|
|
|
|
check_warning(lambda: jnp.tri(2, dtype="float64"),
|
|
|
|
lambda: jnp.tri(2, dtype="float32"))
|
2019-08-22 09:22:57 -07:00
|
|
|
|
2020-06-30 05:16:02 +01:00
|
|
|
def test_vmap_preserves_docstr(self):
|
|
|
|
def superfun(a):
|
|
|
|
"""Does things with stuff."""
|
|
|
|
pass
|
|
|
|
|
|
|
|
self.assertRegex(api.vmap(superfun).__doc__, "\n".join([
|
|
|
|
"Vectorized version of superfun.*",
|
|
|
|
"",
|
|
|
|
"Original documentation:",
|
|
|
|
"",
|
|
|
|
superfun.__doc__,
|
|
|
|
]))
|
|
|
|
|
2020-03-22 19:50:06 +01:00
|
|
|
def test_vmap_in_axes_list(self):
|
|
|
|
# https://github.com/google/jax/issues/2367
|
2020-05-05 14:59:16 -04:00
|
|
|
dictionary = {'a': 5., 'b': jnp.ones(2)}
|
|
|
|
x = jnp.zeros(3)
|
|
|
|
y = jnp.arange(3.)
|
2020-03-22 19:50:06 +01:00
|
|
|
|
|
|
|
|
|
|
|
def f(dct, x, y):
|
|
|
|
return dct['a'] + dct['b'] + x + y
|
|
|
|
|
|
|
|
out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
|
|
|
|
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(out1, out2)
|
2020-03-22 19:50:06 +01:00
|
|
|
|
2019-10-28 14:03:52 -07:00
|
|
|
def test_vmap_in_axes_tree_prefix_error(self):
|
|
|
|
# https://github.com/google/jax/issues/795
|
2019-11-28 08:48:10 +01:00
|
|
|
self.assertRaisesRegex(
|
2019-10-31 13:04:12 -07:00
|
|
|
ValueError,
|
2020-06-30 22:19:16 -07:00
|
|
|
"vmap in_axes specification must be a tree prefix of the corresponding "
|
|
|
|
r"value, got specification \(0, 0\) for value tree "
|
2019-11-14 16:00:55 -05:00
|
|
|
r"PyTreeDef\(tuple, \[\*\]\).",
|
2020-05-05 14:59:16 -04:00
|
|
|
lambda: api.vmap(lambda x: x, in_axes=(0, 0))(jnp.ones(3))
|
2019-10-31 13:04:12 -07:00
|
|
|
)
|
2019-10-28 14:03:52 -07:00
|
|
|
|
2020-05-21 08:00:18 -07:00
|
|
|
def test_vmap_in_axes_leaf_types(self):
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError, r"vmap in_axes must be an int, None, or .*"):
|
|
|
|
api.vmap(lambda x: x, in_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.]))
|
|
|
|
|
|
|
|
def test_vmap_out_axes_leaf_types(self):
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
TypeError, r"vmap out_axes must be an int, None, or .*"):
|
|
|
|
api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.]))
|
|
|
|
|
2019-10-31 11:57:37 -07:00
|
|
|
def test_vmap_unbatched_object_passthrough_issue_183(self):
|
2019-10-28 15:20:49 -07:00
|
|
|
# https://github.com/google/jax/issues/183
|
|
|
|
fun = lambda f, x: f(x)
|
|
|
|
vfun = api.vmap(fun, (None, 0))
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = vfun(lambda x: x + 1, jnp.arange(3))
|
|
|
|
self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False)
|
2019-10-28 15:20:49 -07:00
|
|
|
|
2019-10-31 11:57:37 -07:00
|
|
|
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
|
2019-10-30 17:31:37 -07:00
|
|
|
# https://github.com/google/jax/issues/705
|
|
|
|
def h(a, b):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sum(a) + jnp.sum(b)
|
2019-10-30 17:31:37 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
X = np.random.randn(10, 4)
|
|
|
|
U = np.random.randn(10, 2)
|
2019-10-31 13:20:32 -07:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
2019-10-30 17:31:37 -07:00
|
|
|
ValueError,
|
|
|
|
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
2019-10-31 12:01:37 -07:00
|
|
|
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
|
|
|
r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
|
2019-10-30 17:31:37 -07:00
|
|
|
"so\n"
|
|
|
|
"arg 0 has an axis to be mapped of size 10\n"
|
2020-03-28 16:50:31 +01:00
|
|
|
"arg 1 has an axis to be mapped of size 2"):
|
|
|
|
api.vmap(h, in_axes=(0, 1))(X, U)
|
2019-10-30 17:31:37 -07:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
2019-10-31 13:20:32 -07:00
|
|
|
ValueError,
|
|
|
|
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
|
|
|
r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
|
|
|
r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
|
|
|
|
r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
|
|
|
|
"so\n"
|
|
|
|
"args 0, 2 have axes to be mapped of size 10\n"
|
2020-03-28 16:50:31 +01:00
|
|
|
"arg 1 has an axis to be mapped of size 2"):
|
|
|
|
api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)
|
2019-10-31 13:20:32 -07:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
2019-10-30 17:31:37 -07:00
|
|
|
ValueError,
|
|
|
|
"vmap got inconsistent sizes for array axes to be mapped:\n"
|
2019-10-31 11:57:37 -07:00
|
|
|
"the tree of axis sizes is:\n"
|
2020-03-28 16:50:31 +01:00
|
|
|
r"\(10, \[2, 2\]\)"):
|
|
|
|
api.vmap(h, in_axes=(0, 1))(X, [U, U])
|
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
|
2020-03-28 16:50:31 +01:00
|
|
|
# The mapped inputs cannot be scalars
|
|
|
|
api.vmap(lambda x: x)(1.)
|
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
2020-05-08 17:58:02 -07:00
|
|
|
ValueError, "vmap must have at least one non-None value in in_axes"):
|
2020-03-28 16:50:31 +01:00
|
|
|
# If the output is mapped, there must be a non-None in_axes
|
2020-05-05 14:59:16 -04:00
|
|
|
api.vmap(lambda x: x, in_axes=None)(jnp.array([1., 2.]))
|
2020-03-28 16:50:31 +01:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
|
2020-05-05 14:59:16 -04:00
|
|
|
api.vmap(lambda x: x, in_axes=1)(jnp.array([1., 2.]))
|
2020-03-28 16:50:31 +01:00
|
|
|
|
|
|
|
# Error is: TypeError: only integer scalar arrays can be converted to a scalar index
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
2020-06-30 22:19:16 -07:00
|
|
|
ValueError,
|
|
|
|
"vmap out_axes specification must be a tree prefix of the "
|
|
|
|
"corresponding value.*"):
|
2020-05-05 14:59:16 -04:00
|
|
|
api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(jnp.array([1., 2.]))
|
2020-03-28 16:50:31 +01:00
|
|
|
|
2020-04-12 15:35:35 -04:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "vmap has mapped output but out_axes is None"):
|
2020-03-28 16:50:31 +01:00
|
|
|
# If the output is mapped, then there must be some out_axes specified
|
2020-05-05 14:59:16 -04:00
|
|
|
api.vmap(lambda x: x, out_axes=None)(jnp.array([1., 2.]))
|
2020-03-28 16:50:31 +01:00
|
|
|
|
2019-10-31 14:09:12 -07:00
|
|
|
def test_vmap_structured_in_axes(self):
|
|
|
|
|
|
|
|
A, B, C, D = 2, 3, 4, 5
|
|
|
|
K = 6 # batch size
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.ones((K, A, B)) # batch axis in different locations
|
|
|
|
y = np.ones((B, K, C))
|
|
|
|
z = np.ones((C, D, K))
|
2019-10-31 14:09:12 -07:00
|
|
|
|
|
|
|
def foo(tree_arg):
|
|
|
|
x, (y, z) = tree_arg
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.dot(x, jnp.dot(y, z))
|
2019-10-31 14:09:12 -07:00
|
|
|
|
|
|
|
tree = (x, (y, z))
|
|
|
|
vfoo = api.vmap(foo, in_axes=((0, (1, 2)),))
|
|
|
|
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
|
|
|
|
|
|
|
|
Point = collections.namedtuple("Point", ["x", "y"])
|
|
|
|
tree = (x, Point(y, z))
|
|
|
|
vfoo = api.vmap(foo, in_axes=((0, Point(1, 2)),))
|
|
|
|
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
|
|
|
|
|
|
|
|
def foo(tree_arg):
|
|
|
|
x, dct = tree_arg
|
|
|
|
y, z = dct['a'], dct['b']
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.dot(x, jnp.dot(y, z))
|
2019-10-31 14:09:12 -07:00
|
|
|
|
2020-08-19 18:39:25 +02:00
|
|
|
tree = (x, {'a': y, 'b': z})
|
|
|
|
vfoo = api.vmap(foo, in_axes=((0, {'a': 1, 'b': 2}),))
|
2019-10-31 14:09:12 -07:00
|
|
|
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
|
|
|
|
|
|
|
|
tree = (x, collections.OrderedDict([('a', y), ('b', z)]))
|
|
|
|
vfoo = api.vmap(
|
|
|
|
foo, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),))
|
|
|
|
self.assertEqual(vfoo(tree).shape, (6, 2, 5))
|
|
|
|
|
2019-10-30 14:57:00 -07:00
|
|
|
def test_pmap_global_cache(self):
|
|
|
|
def f(x):
|
|
|
|
assert python_should_be_executing
|
|
|
|
return x
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.ones(1)
|
2019-10-30 14:57:00 -07:00
|
|
|
|
|
|
|
python_should_be_executing = True
|
|
|
|
api.pmap(f)(x)
|
|
|
|
python_should_be_executing = False
|
|
|
|
api.pmap(f)(x)
|
|
|
|
|
|
|
|
python_should_be_executing = True
|
|
|
|
api.pmap(f, 'i')(x)
|
|
|
|
python_should_be_executing = False
|
|
|
|
api.pmap(f, 'i')(x)
|
|
|
|
|
2020-01-03 15:46:19 -08:00
|
|
|
def test_device_array_repr(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
rep = repr(jnp.ones(()) + 1.)
|
2019-11-12 06:18:43 -08:00
|
|
|
self.assertStartsWith(rep, 'DeviceArray')
|
2019-06-01 09:34:33 -07:00
|
|
|
|
2019-11-14 21:18:23 -08:00
|
|
|
def test_grad_without_enough_args_error_message(self):
|
|
|
|
# https://github.com/google/jax/issues/1696
|
|
|
|
def f(x, y): return x + y
|
|
|
|
df = api.grad(f, argnums=0)
|
2019-11-28 08:48:10 +01:00
|
|
|
self.assertRaisesRegex(
|
2019-11-14 21:18:23 -08:00
|
|
|
TypeError,
|
|
|
|
"differentiating with respect to argnums=0 requires at least 1 "
|
|
|
|
"positional arguments to be passed by the caller, but got only 0 "
|
|
|
|
"positional arguments.",
|
|
|
|
lambda: partial(df, x=0.)(y=1.))
|
|
|
|
|
2019-11-26 07:56:48 -08:00
|
|
|
def test_grad_of_jit_compilation_caching(self):
|
|
|
|
if not hasattr(self, "assertLogs"):
|
|
|
|
raise unittest.SkipTest("test requires assertLogs (python 3)")
|
|
|
|
|
|
|
|
lax.add(1, 2) # make sure some initial warnings are already printed
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
sin = api.jit(jnp.sin)
|
2019-11-26 07:56:48 -08:00
|
|
|
|
2019-11-26 17:06:57 -08:00
|
|
|
prev_level = logging.get_verbosity()
|
|
|
|
try:
|
|
|
|
logging.set_verbosity('DEBUG')
|
|
|
|
with self.assertLogs(level=logging.DEBUG) as l:
|
|
|
|
ans1 = api.grad(sin)(2.)
|
|
|
|
ans2 = api.grad(sin)(3.)
|
|
|
|
finally:
|
|
|
|
logging.set_verbosity(prev_level)
|
2019-11-26 07:56:48 -08:00
|
|
|
self.assertLen(l.output, 2)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
|
|
|
|
self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)
|
2019-11-26 07:56:48 -08:00
|
|
|
|
2020-06-15 18:42:53 -07:00
|
|
|
def test_trivial_computations(self):
|
|
|
|
x = jnp.array([1, 2, 3])
|
|
|
|
y = api.jit(lambda x: x)(x)
|
|
|
|
self.assertIs(x, y)
|
|
|
|
|
|
|
|
z1, z2 = api.jit(lambda x: (x, x))(x)
|
|
|
|
self.assertIs(z1, z2)
|
|
|
|
|
|
|
|
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
|
|
|
|
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
|
|
|
|
self.assertIs(z1, x2)
|
|
|
|
self.assertIs(z3, x1)
|
|
|
|
self.assertEqual(z2, 1)
|
|
|
|
|
|
|
|
def test_nested_jit_hoisting(self):
|
|
|
|
@api.jit
|
|
|
|
def f(x, y):
|
|
|
|
z = 2 * x
|
|
|
|
return y + z, 3
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def g(x):
|
|
|
|
return f(2, x)
|
|
|
|
|
|
|
|
jaxpr_subcomp = xla.jaxpr_subcomp
|
|
|
|
|
|
|
|
jaxprs = []
|
|
|
|
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
|
|
|
|
jaxprs.append(jaxpr)
|
|
|
|
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
|
|
|
|
|
|
|
|
try:
|
|
|
|
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
|
|
|
|
ans = g(3)
|
|
|
|
finally:
|
|
|
|
xla.jaxpr_subcomp = jaxpr_subcomp
|
|
|
|
|
|
|
|
self.assertEqual(ans, (7, 3))
|
|
|
|
self.assertLen(jaxprs, 2)
|
|
|
|
outer_jaxpr, inner_jaxpr = jaxprs
|
|
|
|
|
|
|
|
self.assertLen(outer_jaxpr.eqns, 1)
|
|
|
|
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
|
|
|
|
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
|
|
|
|
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
|
|
|
|
self.assertLen(inner_jaxpr.eqns, 2)
|
|
|
|
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
|
|
|
|
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
|
|
|
|
|
|
|
|
def test_primitive_compilation_cache(self):
|
|
|
|
with jtu.count_primitive_compiles() as count:
|
|
|
|
lax.add(1, 2)
|
|
|
|
lax.add(2, 3)
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
|
|
|
|
def test_arange_jit(self):
|
|
|
|
# see https://github.com/google/jax/issues/553
|
|
|
|
def fun(x):
|
|
|
|
r = jnp.arange(x.shape[0])[x]
|
|
|
|
return r
|
|
|
|
|
|
|
|
jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash
|
|
|
|
|
|
|
|
def helper_save_tracer(self, x):
|
|
|
|
self._saved_tracer = x
|
|
|
|
return x
|
|
|
|
|
2020-09-16 23:59:58 -07:00
|
|
|
def test_escaped_tracers_different_top_level_traces(self):
|
2020-06-15 18:42:53 -07:00
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
2020-09-16 23:59:58 -07:00
|
|
|
core.UnexpectedTracerError, "Encountered an unexpected tracer"):
|
2020-06-15 18:42:53 -07:00
|
|
|
api.jit(lambda x: self._saved_tracer)(0.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_cant_lift_sublevels(self):
|
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile(
|
2020-09-16 15:59:50 -07:00
|
|
|
"Encountered an unexpected tracer",
|
2020-06-15 18:42:53 -07:00
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(lambda x: x)(self._saved_tracer)
|
|
|
|
|
|
|
|
def test_escaped_tracers_tracer_from_higher_level(self):
|
|
|
|
api.grad(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile(
|
|
|
|
"Encountered an unexpected tracer.*Tracer from a higher level",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.grad(lambda x: x)(self._saved_tracer)
|
|
|
|
|
|
|
|
def test_escaped_tracers_incompatible_sublevel(self):
|
|
|
|
def func1(x):
|
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
# Use the tracer
|
|
|
|
return x + self._saved_tracer
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.UnexpectedTracerError,
|
2020-07-30 12:59:36 -07:00
|
|
|
re.compile("Encountered an unexpected tracer",
|
2020-06-15 18:42:53 -07:00
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(func1)(2.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_cant_lift(self):
|
|
|
|
def func1(x):
|
|
|
|
api.grad(self.helper_save_tracer)(0.)
|
|
|
|
return x + self._saved_tracer
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile("Encountered an unexpected tracer.*Can't lift",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.grad(func1)(2.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_not_among_input_tracers(self):
|
|
|
|
def func1(x):
|
|
|
|
api.grad(self.helper_save_tracer)(x)
|
|
|
|
# Use the tracer
|
|
|
|
return x + self._saved_tracer
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile(
|
|
|
|
"Encountered an unexpected tracer.*Tracer not among input tracers",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(func1)(2.)
|
|
|
|
|
2020-09-16 15:59:50 -07:00
|
|
|
def test_escaped_tracer_omnistaging(self):
|
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
raise unittest.SkipTest("test is omnistaging-specific")
|
|
|
|
|
|
|
|
count = 1
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def f():
|
|
|
|
nonlocal count
|
|
|
|
count = jnp.add(count, 1)
|
|
|
|
f() # leaked a tracer! but currently undetected
|
|
|
|
|
|
|
|
def f(x, c):
|
|
|
|
jnp.add(count, 1)
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def g():
|
|
|
|
lax.scan(f, None, None, length=2)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
|
|
|
"tracer created on line"):
|
|
|
|
g()
|
|
|
|
|
2020-06-15 18:42:53 -07:00
|
|
|
def test_pmap_static_kwarg_error_message(self):
|
|
|
|
# https://github.com/google/jax/issues/3007
|
|
|
|
def f(a, b):
|
|
|
|
return a + b
|
|
|
|
|
|
|
|
g = jax.pmap(f, static_broadcasted_argnums=(1,))
|
|
|
|
|
|
|
|
msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
|
|
|
|
r"called with only 1 positional argument. All static broadcasted "
|
|
|
|
r"arguments must be passed positionally.")
|
|
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
|
|
g(jnp.ones((1, 1)), b=1)
|
|
|
|
|
|
|
|
def test_vmap_unmapped_last(self):
|
2020-09-10 09:38:14 -04:00
|
|
|
@partial(jax.vmap, out_axes=-1)
|
2020-06-15 18:42:53 -07:00
|
|
|
def f(x):
|
|
|
|
return np.zeros((2,))
|
|
|
|
f(np.zeros((5,)))
|
|
|
|
|
2020-06-11 17:15:23 -07:00
|
|
|
def test_xla_constant_dedup(self):
|
|
|
|
y = np.array([7, 14], dtype=np.float32)
|
|
|
|
def f(x):
|
|
|
|
return x + y + y
|
|
|
|
|
|
|
|
x = np.array([1, 2], dtype=np.float32)
|
|
|
|
hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n')
|
|
|
|
hlo_lines = set([s.strip() for s in hlo_lines])
|
|
|
|
self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
|
|
|
|
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def test_omnistaging_flag(self):
|
|
|
|
if FLAGS.jax_omnistaging:
|
|
|
|
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
|
|
|
|
self.assertLen(jaxpr.jaxpr.eqns, 1)
|
|
|
|
else:
|
|
|
|
# omnistaging can be enabled programmatically without setting the flag,
|
|
|
|
# but that shouldn't happen in tests
|
|
|
|
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
|
|
|
|
self.assertLen(jaxpr.jaxpr.eqns, 0)
|
|
|
|
|
2020-09-17 09:57:43 -07:00
|
|
|
def test_eval_context(self):
|
|
|
|
@jit
|
|
|
|
def f():
|
|
|
|
with core.eval_context():
|
|
|
|
assert jnp.add(1, 1) == 2
|
|
|
|
|
|
|
|
f() # doesn't crash
|
|
|
|
|
2020-09-21 17:55:30 -07:00
|
|
|
def test_xla_computation_zeros_doesnt_device_put(self):
|
2020-09-21 19:33:14 -07:00
|
|
|
if not config.omnistaging_enabled:
|
|
|
|
raise unittest.SkipTest("test is omnistaging-specific")
|
|
|
|
|
2020-09-21 17:55:30 -07:00
|
|
|
count = 0
|
|
|
|
def device_put_and_count(*args, **kwargs):
|
|
|
|
nonlocal count
|
|
|
|
count += 1
|
|
|
|
return orig_device_put(*args, **kwargs)
|
|
|
|
orig_device_put, xla.device_put = xla.device_put, device_put_and_count
|
|
|
|
try:
|
|
|
|
api.xla_computation(lambda: jnp.zeros(3))()
|
|
|
|
finally:
|
|
|
|
xla.device_put = orig_device_put
|
|
|
|
self.assertEqual(count, 0)
|
|
|
|
|
2020-06-15 18:42:53 -07:00
|
|
|
|
|
|
|
class RematTest(jtu.JaxTestCase):
|
|
|
|
|
2019-11-22 10:53:11 -08:00
|
|
|
def test_remat_basic(self):
|
|
|
|
@api.remat
|
|
|
|
def g(x):
|
2019-11-27 14:28:13 -08:00
|
|
|
return lax.sin(lax.sin(x)), 3.
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
def f(x):
|
|
|
|
x, _ = g(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
ans = f(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.sin(np.sin(2.))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans, f_lin = api.linearize(f, 2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.sin(np.sin(2.))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = f_lin(3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.cos(np.sin(2.)) * np.cos(2.) * 3.
|
2019-11-27 14:28:13 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
sin_calls = []
|
|
|
|
cos_calls = []
|
|
|
|
sin_impl = lax.sin_p.impl
|
|
|
|
cos_impl = lax.cos_p.impl
|
|
|
|
try:
|
|
|
|
lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
|
|
|
|
lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
|
|
|
|
f_lin(3.)
|
|
|
|
finally:
|
|
|
|
lax.sin_p.def_impl(sin_impl)
|
|
|
|
lax.cos_p.def_impl(cos_impl)
|
|
|
|
self.assertEqual(len(sin_calls), 1)
|
|
|
|
self.assertEqual(len(cos_calls), 2)
|
|
|
|
|
|
|
|
def test_remat_freevars(self):
|
|
|
|
def f1(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
y = 2 * jnp.sin(x)
|
|
|
|
z = jnp.cos(x) * jnp.sin(y)
|
2019-11-27 14:28:13 -08:00
|
|
|
return z
|
|
|
|
|
|
|
|
def f2(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
y = 2 * jnp.sin(x)
|
|
|
|
z = api.remat(lambda x: jnp.cos(x) * jnp.sin(y))(x)
|
2019-11-27 14:28:13 -08:00
|
|
|
return z
|
|
|
|
|
|
|
|
ans, f_lin = api.linearize(f2, 2.)
|
|
|
|
expected, f_lin_expected = api.linearize(f1, 2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-11-27 14:28:13 -08:00
|
|
|
ans = f_lin(3.)
|
|
|
|
expected = f_lin_expected(3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
def test_remat_grad_python_control_flow(self):
|
|
|
|
@partial(api.remat, concrete=True)
|
|
|
|
def g(x):
|
|
|
|
if x > 0:
|
|
|
|
return lax.sin(x), 3.
|
|
|
|
else:
|
|
|
|
return lax.cos(x), 4.
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
x, _ = g(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
ans = f(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.sin(2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f)(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.cos(2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_jit(self):
|
|
|
|
@api.remat
|
|
|
|
def g(x):
|
|
|
|
return lax.sin(lax.sin(x))
|
|
|
|
|
|
|
|
def f_(x):
|
|
|
|
return g(x)
|
|
|
|
f = api.jit(f_)
|
|
|
|
|
|
|
|
ans = f(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.sin(np.sin(2.))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f)(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.cos(np.sin(2.)) * np.cos(2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jit(api.grad(f_))(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.cos(np.sin(2.)) * np.cos(2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_vmap(self):
|
|
|
|
@api.remat
|
|
|
|
def g(x):
|
|
|
|
return lax.sin(lax.sin(x))
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(3.)
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
ans = api.vmap(g)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.sin(np.sin(x))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jacfwd(g)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jacrev(g)(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_higher_order_autodiff(self):
|
|
|
|
def f(x):
|
|
|
|
return lax.cos(lax.sin(x))
|
|
|
|
g = api.remat(f)
|
|
|
|
|
|
|
|
ans = api.grad(api.grad(g))(3.)
|
|
|
|
expected = api.grad(api.grad(f))(3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_scan(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
to_scan = lambda c, x: (jnp.sin(c), None)
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
def f_noremat(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
y, _ = lax.scan(to_scan, x, np.arange(3.))
|
2019-11-22 10:53:11 -08:00
|
|
|
return y
|
|
|
|
|
|
|
|
def f_yesremat(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
y, _ = lax.scan(api.remat(to_scan), x, np.arange(3.))
|
2019-11-22 10:53:11 -08:00
|
|
|
return y
|
|
|
|
|
|
|
|
ans = f_yesremat(4.)
|
|
|
|
expected = f_noremat(4.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f_yesremat)(4.)
|
|
|
|
expected = api.grad(f_noremat)(4.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
2019-11-28 09:00:55 +01:00
|
|
|
scan_eqn, = jaxpr.jaxpr.eqns
|
2019-11-27 15:25:49 -08:00
|
|
|
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
2019-11-28 09:00:55 +01:00
|
|
|
scan_eqn, = jaxpr.jaxpr.eqns
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
|
|
|
|
|
|
|
def test_remat_no_redundant_flops(self):
|
|
|
|
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def g(x):
|
|
|
|
return f(2., x)
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def f(x, y):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x) * y
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
# We swap out sin_p's impl rule to count how many times it's invoked
|
|
|
|
called = []
|
|
|
|
sin_impl = lax.sin_p.impl
|
|
|
|
try:
|
|
|
|
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
|
|
|
|
api.grad(g)(3.)
|
|
|
|
finally:
|
|
|
|
lax.sin_p.def_impl(sin_impl)
|
|
|
|
num_calls = len(called)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.assertLessEqual(num_calls, 1)
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
def test_remat_binomial_checkpointing(self):
|
|
|
|
def binom_checkpoint(funs):
|
|
|
|
if len(funs) == 1:
|
|
|
|
return funs[0]
|
|
|
|
else:
|
|
|
|
f1 = binom_checkpoint(funs[:len(funs)//2])
|
|
|
|
f2 = binom_checkpoint(funs[len(funs)//2:])
|
|
|
|
return api.remat(lambda x: f1(f2(x)))
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
f1 = binom_checkpoint([jnp.sin, jnp.sin, jnp.sin, jnp.sin])
|
|
|
|
f2 = lambda x: jnp.sin(jnp.sin(jnp.sin(jnp.sin(x))))
|
2019-11-22 10:53:11 -08:00
|
|
|
x = 4.
|
|
|
|
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
|
|
|
|
|
2019-12-23 11:49:01 -08:00
|
|
|
def test_remat_symbolic_zeros(self):
|
|
|
|
# code from https://github.com/google/jax/issues/1907
|
|
|
|
|
|
|
|
key = jax.random.PRNGKey(0)
|
|
|
|
key, split = jax.random.split(key)
|
|
|
|
n = 5
|
|
|
|
|
|
|
|
def func(D0):
|
|
|
|
def shift(R, dR, **unused_kwargs):
|
|
|
|
return R + dR
|
|
|
|
|
|
|
|
def apply_fn(R):
|
|
|
|
return D0 * R
|
|
|
|
|
|
|
|
Rinit = jax.random.uniform(split, (n,3), minval=0.0, maxval=5.0,
|
2020-05-05 14:59:16 -04:00
|
|
|
dtype=jnp.float32)
|
2019-12-23 11:49:01 -08:00
|
|
|
|
|
|
|
def move(R,i):
|
|
|
|
F = apply_fn(R)
|
2020-05-05 14:59:16 -04:00
|
|
|
return shift(R, 0.001 * F), jnp.array([0.])
|
2019-12-23 11:49:01 -08:00
|
|
|
|
|
|
|
move = api.remat(move)
|
2020-05-05 14:59:16 -04:00
|
|
|
R, temp = lax.scan(move, Rinit, jnp.arange(2))
|
2019-12-23 11:49:01 -08:00
|
|
|
return R[0, 0]
|
|
|
|
|
|
|
|
api.grad(func)(5.0) # doesn't crash
|
|
|
|
|
2020-01-31 23:47:30 -08:00
|
|
|
def test_remat_jit2(self):
|
|
|
|
@api.jit
|
|
|
|
def f(x):
|
|
|
|
y = 2 * x
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def g():
|
|
|
|
return y
|
|
|
|
|
|
|
|
return g()
|
|
|
|
|
|
|
|
self.assertAllClose(f(3), 6, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_nontrivial_env(self):
|
|
|
|
# simplified from https://github.com/google/jax/issues/2030
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def foo(state, dt=0.5, c=1):
|
|
|
|
u, u_t = state
|
|
|
|
u_tt = c**2 * u
|
|
|
|
u_t = u_t + u_tt * dt
|
|
|
|
return (u, u_t)
|
|
|
|
|
|
|
|
@partial(api.jit, static_argnums=(1,))
|
|
|
|
def _multi_step(state, count, dt, c):
|
|
|
|
f = lambda s, _: (foo(s, dt, c), _)
|
|
|
|
return lax.scan(f, state, None, count)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):
|
2020-01-31 23:47:30 -08:00
|
|
|
return _multi_step(state, count, dt, c)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1):
|
|
|
|
init = (u0, jnp.zeros_like(u0))
|
2020-01-31 23:47:30 -08:00
|
|
|
(uf, _), _ = multi_step(init, steps, dt, c)
|
|
|
|
return ((uf - target) ** 2).mean()
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
target = jnp.zeros((128, 128))
|
|
|
|
u0 = jnp.ones_like(target)
|
2020-01-31 23:47:30 -08:00
|
|
|
loss(u0, target, 10) # doesn't crash
|
|
|
|
|
2020-02-11 15:56:53 -08:00
|
|
|
def test_remat_jit3(self):
|
|
|
|
# https://github.com/google/jax/issues/2180
|
|
|
|
def f(w, x):
|
2020-05-05 14:59:16 -04:00
|
|
|
a = jnp.dot(x, w)
|
|
|
|
b = jnp.einsum("btd,bTd->btT", a, a)
|
|
|
|
c = jnp.einsum("btT,btd->btd", b, a)
|
|
|
|
return jnp.sum(c)
|
2020-02-11 15:56:53 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
w = jnp.ones([1, 1])
|
|
|
|
x = jnp.ones([1, 1, 1])
|
2020-02-11 15:56:53 -08:00
|
|
|
f = api.remat(f)
|
|
|
|
api.grad(f)(w, x) # doesn't crash
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def mul(a, b):
|
|
|
|
return a * b
|
|
|
|
|
|
|
|
def f(w, x):
|
|
|
|
a = mul(w, x)
|
|
|
|
b = mul(a, a)
|
|
|
|
return b
|
|
|
|
|
|
|
|
w = 1.
|
|
|
|
x = 1.
|
|
|
|
f = api.remat(f)
|
|
|
|
api.grad(f)(w, x) # doesn't crash
|
|
|
|
|
|
|
|
def test_remat_scan2(self):
|
|
|
|
# https://github.com/google/jax/issues/1963
|
|
|
|
|
|
|
|
def scan_bug(x0):
|
|
|
|
f = lambda x, _: (x + 1, None)
|
|
|
|
def scanned_f(x, _):
|
|
|
|
return lax.scan(f, x, xs=None, length=1)[0], None
|
|
|
|
x, _ = jax.remat(scanned_f)(x0, None)
|
|
|
|
return x
|
|
|
|
|
|
|
|
jax.grad(scan_bug)(1.0) # doesn't crash
|
|
|
|
|
2020-04-24 18:19:24 -07:00
|
|
|
  def test_remat_jit_static_argnum(self):
    # https://github.com/google/jax/issues/2833
    if config.omnistaging_enabled:
      raise unittest.SkipTest("test only works without omnistaging")  # see next test

    def f(a_bool, y):
      if a_bool:
        return y + 1
      else:
        return y

    api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1)  # no crash

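  # Note on the test above (issue #2833): `remat(f, concrete=True)` asks remat to
  # trace with concrete values where possible, so the Python-level `if a_bool:`
  # branch can be resolved during tracing; the regression being guarded against
  # is that this combination with `jit(..., static_argnums=0)` used to crash.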
  def test_remat_jit_static_argnum_omnistaging(self):
    # https://github.com/google/jax/issues/2833
    if not config.omnistaging_enabled:
      raise unittest.SkipTest("test only works with omnistaging")  # see previous test

    def named_call(f):
      def named_f(*args):
        f_ = lu.wrap_init(lambda: (f(*args),))
        out, = core.call_p.bind(f_)
        return out
      return named_f

    def f(a_bool, y):
      if a_bool:
        return y + 1
      else:
        return y

    api.jit(named_call(f), static_argnums=0)(True, 1)  # no crash

  def test_remat_eval_counter(self):
    # https://github.com/google/jax/issues/2737
    add_one_p = Primitive('add_one')
    add_one = add_one_p.bind

    num_evals = 0

    @contextmanager
    def assertEvals(n):
      start = num_evals
      yield
      assert num_evals - start == n

    def add_one_impl(x):
      nonlocal num_evals
      num_evals += 1
      return x + 1
    add_one_p.def_impl(add_one_impl)

    def add_one_jvp(pin, tin):
      pout = add_one(pin[0])
      return pout, pout * tin[0]
    ad.primitive_jvps[add_one_p] = add_one_jvp

    add_one_p.def_abstract_eval(lambda x: x)

    v = np.zeros((1,))

    f = jax.remat(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, 1 call made while transposing f
    with assertEvals(3):
      vjp(v)

    @jax.util.curry
    def call(f, *args):
      return jax.core.call(
          jax.linear_util.wrap_init(lambda *args: [f(*args)]),
          *args, name='foo')[0]

    f = call(add_one)
    g = jax.remat(lambda x: add_one(f(x)))

    # 2 calls needed to evaluate g
    with assertEvals(2):
      _, vjp = jax.vjp(g, v)
    # 2 calls made while transposing g, no reevaluation for transposition of f
    with assertEvals(2):
      vjp(v)

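  # Note on the test above (issue #2737): the assertions pin down how often a
  # remat'd primal is re-evaluated. Computing the VJP of `g` evaluates `add_one`
  # twice (once directly, once through `f`); running the VJP re-evaluates the
  # remat'd bodies before transposing (three evals in the nested-remat case),
  # while in the second half, where `f` is an ordinary call primitive, its
  # primal is not re-evaluated during transposition (two evals).
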

class JaxprTest(jtu.JaxTestCase):

  def test_scalar_literals(self):
    jaxpr = api.make_jaxpr(lambda x: x + 2)(42)
    self.assertLen(jaxpr.jaxpr.constvars, 0)

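  # Note on the test above: scalar Python literals such as the `2` in
  # `lambda x: x + 2` are expected to be inlined into the jaxpr as literals
  # rather than hoisted into `constvars`, whereas a closed-over numpy array
  # shows up as a constvar (see test_const below). A quick way to compare the
  # two cases interactively (illustrative only, not part of the test):
  #
  #   print(api.make_jaxpr(lambda x: x + 2)(42))
  #   print(api.make_jaxpr(lambda x: x + np.zeros(3))(jnp.zeros(3)))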
  def test_const(self):
    def fun(x):
      return (x, 1., np.zeros(1))

    if config.omnistaging_enabled:
      expected = """
      { lambda a ; b.
        let
        in (b, 1.0, a) }
      """
    else:
      expected = """
      { lambda b ; a.
        let
        in (a, 1.0, b) }
      """

    jaxpr = api.make_jaxpr(fun)(0.)
    self.assertMultiLineStrippedEqual(expected, str(jaxpr))

  def test_cond(self):
    def f(x):
      return lax.cond(x >= 0.,
                      x + 1.,
                      lambda xt: xt + x,
                      x + 2.,
                      lambda xf: xf - x)

    if config.omnistaging_enabled:
      expected = """
      { lambda ; a.
        let b = ge a 0.0
            c = add a 1.0
            d = add a 2.0
            e = convert_element_type[ new_dtype=int32
                                      old_dtype=bool ] b
            f = cond[ branches=( { lambda ; e_ a b c.
                                   let d = sub c a
                                   in (d,) }
                                 { lambda ; a f_ b c.
                                   let d = add b a
                                   in (d,) } )
                      linear=(False, False, False, False) ] e a a c d
        in (f,) }
      """
    else:
      expected = """
      { lambda ; a.
        let b = ge a 0.0
            c = convert_element_type[ new_dtype=int32
                                      old_dtype=bool ] b
            d = add a 1.0
            e = add a 2.0
            f = cond[ branches=( { lambda ; e_ c a b.
                                   let d = sub b c
                                   in (d,) }
                                 { lambda ; c f_ a b.
                                   let d = add a c
                                   in (d,) } )
                      linear=(False, False, False, False) ] c a a d e
        in (f,) }
      """

    jaxpr = api.make_jaxpr(f)(3.)
    self.assertMultiLineStrippedEqual(expected, str(jaxpr))

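  # Note on the two tests above: the expected jaxpr text is compared with
  # `assertMultiLineStrippedEqual`, so the exact indentation of the strings is
  # not significant, but the variable naming and equation order do depend on
  # whether omnistaging is enabled, hence the two hard-coded variants.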
  def test_make_jaxpr_static_argnums(self):
    def f(x, y):
      return x + y

    jaxpr = api.make_jaxpr(f, static_argnums=(1,))(2, 3)
    self.assertIn('3', str(jaxpr))


class LazyTest(jtu.JaxTestCase):

  @contextmanager
  def count_compiles(self):

    make_computation_builder = xb.make_computation_builder
    count = [0]

    def make_computation_builder_and_count(*args, **kwargs):
      count[0] += 1
      return make_computation_builder(*args, **kwargs)

    xb.make_computation_builder = make_computation_builder_and_count
    try:
      yield count
    finally:
      xb.make_computation_builder = make_computation_builder

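  # Note on the helper above: `count_compiles` counts XLA compilations by
  # temporarily monkey-patching `xb.make_computation_builder`, on the assumption
  # that every compiled executable goes through exactly one builder; the
  # `try/finally` restores the original so a failing test body cannot leave the
  # patch in place.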
  @jtu.skip_on_devices("tpu")
  def test_lazy_jit_closed_over_values(self):
    if not core.skip_checks:
      raise unittest.SkipTest("oom test skipped when core.skip_checks is False")

    y = jnp.arange(int(1e12))  # will likely oom if materialized
    ans = jit(lambda x: (x + y)[1])(1)
    self.assertEqual(ans, 2)

  def test_jit_forces_arguments(self):

    @api.jit
    def f(x):
      assert python_should_be_executing
      return jnp.sum(x)

    x = jnp.arange(10, dtype=jnp.int32)
    assert xla.is_device_constant(x)  # lazy iota

    python_should_be_executing = True
    _ = f(x)

    python_should_be_executing = False  # should not recompile
    x = np.arange(10, dtype=np.int32)
    _ = f(x)

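  # Note on the test above: the first call compiles `f` (its Python body runs,
  # exercising the `assert`), and because the jit cache is keyed on the abstract
  # shape/dtype of the arguments rather than on whether they are lazy device
  # constants or plain numpy arrays, the second call with a numpy `arange` of
  # the same shape and dtype must hit the cache and skip the Python body.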
  @parameterized.parameters(jtu.cases_from_list(range(10000)))
  def test_random_lazy_program(self, seed):

    def random_array(rng):
      kind = rng.choice(['arr', 'iota', 'eye', 'tri'])
      if kind == 'arr':
        dtype = [np.float32, np.int32][rng.choice(2)]
        dim = rng.randint(4)
        shape = rng.randint(4, size=dim)
        np_x = np.asarray(rng.randn(*shape), dtype=dtype)
        jax_x = jnp.array(np_x, dtype=dtype)
      elif kind == 'iota':
        dtype = [np.float32, np.int32][rng.choice(2)]
        size = rng.randint(5)
        np_x = np.arange(size, dtype=dtype)
        jax_x = lax.iota(dtype, size)
      elif kind == 'eye':
        dtype = [np.float32, np.int32][rng.choice(2)]
        N = rng.randint(2, 5)
        M = None if rng.rand() < 0.5 else rng.randint(2, 5)
        k = rng.choice([-1, 0, 1])
        np_x = np.eye(N, M, k, dtype=dtype)
        jax_x = jnp.eye(N, M, k, dtype=dtype)
      elif kind == 'tri':
        dtype = [np.float32, np.int32][rng.choice(2)]
        N = rng.randint(2, 5)
        M = None if rng.rand() < 0.5 else rng.randint(2, 5)
        k = rng.choice([-1, 0, 1])
        np_x = np.tri(N, M, k, dtype=dtype)
        jax_x = jnp.tri(N, M, k, dtype=dtype)
      else:
        assert False
      assert type(np_x) is np.ndarray and type(jax_x) is xla.DeviceArray
      return np_x, jax_x

    def random_op(rng, shape):
      kind = rng.choice(['transpose', 'broadcast', 'reshape'])
      if kind == 'transpose':
        perm = tuple(rng.permutation(len(shape)))
        return Op(partial(np.transpose, axes=perm),
                  partial(lax.transpose, permutation=perm))
      elif kind == 'broadcast':
        n = rng.randint(1, 3)
        new_sizes = rng.randint(1, 4, size=n)
        new_ndim = n + len(shape)
        bcast_dims = tuple(sorted(rng.permutation(new_ndim)[:len(shape)]))
        shape_iter = iter(shape)
        new_sizes = iter(rng.randint(1, 4, size=n))
        new_shape = [next(shape_iter) if i in bcast_dims else next(new_sizes)
                     for i in range(new_ndim)]
        return Op(partial(lax_reference.broadcast_in_dim, shape=new_shape,
                          broadcast_dimensions=bcast_dims),
                  partial(lax.broadcast_in_dim, shape=new_shape,
                          broadcast_dimensions=bcast_dims))
      elif kind == 'reshape':
        new_shape = list(shape)
        for _ in range(rng.randint(1, 3)):
          loc = len(new_shape) and rng.randint(len(new_shape))
          new_shape.insert(loc, 1)
        new_shape = tuple(new_shape)
        return Op(partial(np.reshape, newshape=new_shape),
                  partial(lax.reshape, new_sizes=new_shape))
      else:
        assert False

    Op = collections.namedtuple('Op', ['np_fn', 'jax_fn'])

    rng = np.random.RandomState(seed)
    np_x, jax_x = _, orig_x = random_array(rng)
    ops = []
    with jtu.count_primitive_compiles() as count:
      for _ in range(rng.randint(5)):
        op = random_op(rng, np.shape(np_x))
        np_x = op.np_fn(np_x)
        jax_x = op.jax_fn(jax_x)
        ops.append(op)
    self.assertEqual(count[0], 0)

    kind = rng.choice(['closure', 'npy_value', 'force', 'add'])
    if kind == 'closure':
      result = api.jit(lambda x: x + jax_x)(0)
      self.assertAllClose(np_x, result, check_dtypes=False)
    elif kind == 'npy_value':
      self.assertAllClose(np_x, jax_x, check_dtypes=False)
    elif kind == 'force':
      result = xla._force(jax_x)
      self.assertAllClose(np_x, result, check_dtypes=False)
    elif kind == 'add':
      result = jax_x + np.zeros(jax_x.shape, dtype=jax_x.dtype)
      self.assertAllClose(np_x, result, check_dtypes=False)
    else:
      assert False

    @jit
    def apply_ops(x):
      for op in ops:
        x = op.jax_fn(x)
      return x

    jit_result = apply_ops(orig_x)
    self.assertAllClose(jit_result, np_x, check_dtypes=False)

    @jit
    def apply_ops_closure():
      x = orig_x
      for op in ops:
        x = op.jax_fn(x)
      return x

    jit_result = apply_ops_closure()
    self.assertAllClose(jit_result, np_x, check_dtypes=False)

  def test_constant_forcing_computations_cached(self):
    # from https://github.com/google/jax/issues/1909
    xla._lazy_force_computation.cache_clear()  # clear force compile cache
    big_lazy_x = np.ones((api.device_count(), 100))
    f = api.pmap(lambda x: 2 * x)
    _ = f(big_lazy_x)

    with self.count_compiles() as count:
      _ = f(big_lazy_x)
    self.assertEqual(count[0], 0)

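  # Note on the test above (issue #1909): the force-compile cache
  # (`xla._lazy_force_computation`) is cleared up front, `f` is compiled once on
  # the warm-up call, and the `count_compiles` check then asserts that applying
  # the same `pmap`'d function to the same argument again triggers no further
  # compilations.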
  def test_zeros_ones_compilation(self):
    w = jnp.ones(3) + jnp.ones(3)  # ensure + has a cache entry
    w.block_until_ready()

    xla._lazy_force_computation.cache_clear()  # clear force compile cache

    with self.count_compiles() as count:
      x = jnp.ones(3) + jnp.zeros(3)
      y = jnp.ones(3) + jnp.ones(3)

    self.assertEqual(1, count[0])
    self.assertAllClose(x, np.ones(3), check_dtypes=False)
    self.assertAllClose(y, np.ones(3) + np.ones(3), check_dtypes=False)


class CustomJVPTest(jtu.JaxTestCase):

  def test_basic(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = 3.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(api.jvp(f, (x,), (1.,)),
                        (jnp.sin(x), 2 * jnp.cos(x)))
    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))

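  # Note on the test above: a `custom_jvp` rule receives the primals and
  # tangents as tuples (one entry per positional argument) and must return a
  # (primal output, tangent output) pair; the deliberately "wrong" factor of 2
  # in `f_jvp` is how these tests distinguish the custom rule from the JVP that
  # JAX would derive automatically for `jnp.sin`.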
  def test_invariance(self):
    @api.custom_jvp
    def f(x):
      return jnp.cos(2 * x) / 2.
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return (f(x), 3 * g)
    f.defjvp(f_jvp)
    def f2(x):
      y, _ = api.jvp(f, (x,), (x,))
      return y
    def f3(x):
      y, _ = api.jvp(f2, (x,), (x,))
      return y
    x = 1.
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        api.jvp(f2, (x,), (x,)),
                        check_dtypes=False)
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        api.jvp(f3, (x,), (x,)),
                        check_dtypes=False)

  def test_python_control_flow(self):
    @api.custom_jvp
    def f(x):
      if x > 0:
        return jnp.sin(x)
      else:
        return jnp.cos(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      if x > 0:
        return f(x), 2 * g
      else:
        return f(x), 3 * g
    f.defjvp(f_jvp)
    x = 2.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(f(-x), jnp.cos(-x))
    self.assertAllClose(api.jvp(f, (x,), (1.,)),
                        (jnp.sin(x), 2.),
                        check_dtypes=False)
    self.assertAllClose(api.jvp(f, (-x,), (1.,)),
                        (jnp.cos(-x), 3.),
                        check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False)
    self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False)

  def test_vmap(self):
    @api.custom_jvp
    def f(x):
      assert jnp.ndim(x) == 0
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      assert jnp.ndim(x) == jnp.ndim(g) == 0
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = jnp.arange(3.)
    xx = jnp.arange(6.).reshape(2, 3)

    # vmap of f
    self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
    self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))

    # vmap of jvp of f
    self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
                        (jnp.sin(x), 2 * jnp.cos(x) * x))
    self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

    # jvp of vmap of f
    self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
                        (jnp.sin(x), 2 * jnp.cos(x) * x))
    self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

    # vmap of jvp of vmap of f
    self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
                        (jnp.sin(xx), 2 * jnp.cos(xx) * xx))

  def test_jit(self):
    @api.custom_jvp
    def f(x):
      return jnp.sin(x)
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), 2 * jnp.cos(x) * g
    f.defjvp(f_jvp)

    x = 3.

    # jit
    self.assertAllClose(api.jit(f)(x), jnp.sin(x))
    self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))

    # jit of jvp
    self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
                        (jnp.sin(x), 2 * jnp.cos(x) * x),
                        check_dtypes=False)

    # jvp of jit
    self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)),
                        (jnp.sin(x), 2 * jnp.cos(x) * x),
                        check_dtypes=False)

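  # Note (added for exposition): the jit/vmap compositions above work because a
  # custom_jvp function is traced and evaluated like any other JAX callable;
  # only its differentiation rule is overridden, not its evaluation semantics.
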
  def test_pytrees(self):
    @api.custom_jvp
    def f(x):
      return {'b': jnp.sin(x['a'])}
    def f_jvp(primals, tangents):
      x, = primals
      g, = tangents
      return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']}
    f.defjvp(f_jvp)
    x = {'a': 3.}
    self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
    self.assertAllClose(api.jvp(f, (x,), (x,)),
                        ({'b': jnp.sin(x['a'])},
                         {'b': 2 * jnp.cos(x['a']) * x['a']}),
                        check_dtypes=False)

  def test_kwargs(self):
    # from https://github.com/google/jax/issues/1938
    @api.custom_jvp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    def my_jvp(primals, tangents):
      x, y, c = primals
      t_x, t_y, t_c = tangents
      return my_fun(x, y, c), t_c
    my_fun.defjvp(my_jvp)
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.jvp(f, (10., 5.), (1., 1.))  # doesn't crash

def test_initial_style(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return 3 * x
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
ans = api.grad(foo)(3.)
|
|
|
|
expected = 2.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(api.grad(foo))(3.)
|
|
|
|
expected = 0.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_initial_style_vmap(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
assert jnp.ndim(x) == 0
|
2020-01-15 15:00:38 -08:00
|
|
|
return 3 * x
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = api.vmap(foo)(jnp.ones(3))
|
|
|
|
expected = 3. * jnp.ones(3)
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3))
|
|
|
|
expected = 2. * jnp.ones(3)
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_closed_over_tracers_error_message(self):
|
|
|
|
def f(x):
|
|
|
|
@api.custom_jvp
|
|
|
|
def g(y):
|
|
|
|
return x + y
|
|
|
|
def g_jvp(primals, tangents):
|
2020-06-02 19:25:47 -07:00
|
|
|
return g(x), 2 * primals[0]
|
2020-01-15 15:00:38 -08:00
|
|
|
g.defjvp(g_jvp)
|
|
|
|
return g(1.)
|
|
|
|
|
|
|
|
self.assertRaises(
|
|
|
|
core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
|
|
|
|
self.assertRaises(
|
|
|
|
core.UnexpectedTracerError, lambda: api.grad(f)(3.))
|
|
|
|
|
|
|
|
def test_nondiff_arg(self):
|
|
|
|
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
|
|
|
def app(f, x):
|
|
|
|
return f(x)
|
|
|
|
def app_jvp(f, primals, tangents):
|
|
|
|
(x,), (t,) = primals, tangents
|
|
|
|
return app(f, x), 3 * t
|
|
|
|
app.defjvp(app_jvp)
|
|
|
|
|
|
|
|
ans = app(lambda x: 2 * x, 1)
|
|
|
|
expected = 2
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))
|
|
|
|
expected = (2., 3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-07-30 12:59:36 -07:00
|
|
|
def test_nondiff_arg_jit_tracer(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
|
|
|
def f(x, y):
|
|
|
|
return x * y
|
|
|
|
def f_jvp(x, primals, tangents):
|
|
|
|
(y,), (t_y,) = primals, tangents
|
|
|
|
return f(x, y), 5 * t_y
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def g(x, y):
|
|
|
|
return f(x, y)
|
|
|
|
|
|
|
|
ans = api.jvp(lambda y: g(2., y), (3.,), (1.,))
|
|
|
|
expected = (6., 5.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_vmap_axes(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
|
|
|
def test_pmap(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_missing_jvp_rule_error_message(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
@api.custom_jvp
|
|
|
|
def foo(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: foo(2))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: api.jvp(foo, (2.,), (1.,)))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: api.grad(foo)(2.))
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_jvp_rule_inconsistent_pytree_structures_error_message(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return (x**2,)
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return f(x), [2 * x * t, x]
|
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce primal and tangent outputs "
|
|
|
|
"with equal container (pytree) structures, but got "
|
2020-03-24 20:43:33 -07:00
|
|
|
"{} and {} respectively.".format(
|
2020-01-15 15:00:38 -08:00
|
|
|
tree_util.tree_structure((1,)),
|
|
|
|
tree_util.tree_structure([1, 2]))
|
|
|
|
),
|
|
|
|
lambda: api.jvp(f, (2.,), (1.,)))
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_primal_tangent_aval_disagreement_error_message(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), jnp.reshape(t, (1,))
|
2020-03-24 20:43:33 -07:00
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce primal and tangent outputs "
|
|
|
|
"with equal shapes and dtypes, but got float32[] and float32[1] "
|
|
|
|
"respectively."),
|
2020-05-05 14:59:16 -04:00
|
|
|
lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),)))
|
2020-03-24 20:43:33 -07:00
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def test_jvp_rule_doesnt_return_pair_error_message(self):
|
|
|
|
# https://github.com/google/jax/issues/2516
|
|
|
|
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return t
|
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce a pair (list or tuple of length two) "
|
|
|
|
"representing primal and tangent outputs, got 1.0"),
|
|
|
|
lambda: api.jvp(f, (2.,), (1.,)))
|
|
|
|
|
2020-03-28 13:52:40 -07:00
|
|
|
def test_multiple_rule_invocations(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def expit(x):
|
|
|
|
return 1 / (1 + lax.exp(-x))
|
|
|
|
|
|
|
|
@expit.defjvp
|
|
|
|
def _expit_jvp(primals, tangents):
|
|
|
|
(x,), (t,) = primals, tangents
|
|
|
|
ans = expit(x)
|
|
|
|
t_out = t * ans * (1 - ans)
|
|
|
|
return ans, t_out
|
|
|
|
|
|
|
|
def scanned_fun(c, _):
|
|
|
|
return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
|
|
|
return c[-1]
|
|
|
|
|
|
|
|
# just make sure these don't crash
|
|
|
|
foo(3.)
|
|
|
|
grad(foo)(3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.))
|
2020-03-28 13:52:40 -07:00
|
|
|
|
|
|
|
def test_hard_stuff(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
arr = jnp.ones((5, 2, 2))
|
|
|
|
api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash
|
2020-03-28 13:52:40 -07:00
|
|
|
|
|
|
|
def test_hard_stuff2(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return lax.tie_in(x, np.zeros(x.shape, x.dtype))
|
2020-03-28 13:52:40 -07:00
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return f(x), t
|
|
|
|
|
|
|
|
# don't crash
|
2020-05-05 14:59:16 -04:00
|
|
|
jax.jit(jax.vmap(f))(jnp.arange(3.))
|
|
|
|
jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
|
|
|
|
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
|
|
|
|
jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
|
|
|
|
jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))
|
2020-03-28 13:52:40 -07:00
|
|
|
|
|
|
|
def test_hard_stuff3(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def relu(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.maximum(x, 0)
|
2020-03-28 13:52:40 -07:00
|
|
|
|
|
|
|
@relu.defjvp
|
|
|
|
def _relu_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))
|
|
|
|
|
|
|
|
def scanned_fun(c, _):
|
|
|
|
return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
|
|
|
return c[-1]
|
|
|
|
|
|
|
|
# don't crash
|
2020-05-05 14:59:16 -04:00
|
|
|
jax.jit(jax.vmap(f))(jnp.arange(3.))
|
|
|
|
jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.))
|
|
|
|
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.))
|
|
|
|
jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.))
|
|
|
|
jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def test_eval_shape(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def expit(x):
|
|
|
|
return 1 / (1 + lax.exp(-x))
|
|
|
|
|
|
|
|
@expit.defjvp
|
|
|
|
def _expit_jvp(primals, tangents):
|
|
|
|
(x,), (t,) = primals, tangents
|
|
|
|
ans = expit(x)
|
|
|
|
t_out = t * ans * (1 - ans)
|
|
|
|
return ans, t_out
|
|
|
|
|
|
|
|
# don't crash
|
2020-05-05 14:59:16 -04:00
|
|
|
api.eval_shape(expit, jnp.ones((2, 3)))
|
|
|
|
api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3)))
|
2020-03-29 20:51:51 -07:00
|
|
|
|
2020-04-10 11:45:33 -07:00
|
|
|
def test_jaxpr_zeros(self):
|
|
|
|
# from https://github.com/google/jax/issues/2657
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(A, b):
|
2020-08-19 18:39:25 +02:00
|
|
|
return A @ b
|
2020-04-10 11:45:33 -07:00
|
|
|
|
|
|
|
def f_jvp(primals, tangents):
|
2020-08-19 18:39:25 +02:00
|
|
|
A, b = primals
|
|
|
|
dA, db = tangents
|
|
|
|
z = f(A, b)
|
|
|
|
dz = A @ db + dA @ b
|
|
|
|
return z, dz
|
2020-04-10 11:45:33 -07:00
|
|
|
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
def experiment(theta):
|
2020-08-19 18:39:25 +02:00
|
|
|
def step(q, _):
|
|
|
|
z = f(jnp.eye(3), jnp.ones(3) * theta)
|
|
|
|
q += z[0]
|
|
|
|
return q, q
|
2020-04-10 11:45:33 -07:00
|
|
|
|
2020-08-19 18:39:25 +02:00
|
|
|
q = 0.
|
|
|
|
q, _ = lax.scan(step, q, None, 4)
|
|
|
|
return q
|
2020-04-10 11:45:33 -07:00
|
|
|
|
|
|
|
grad(experiment)(1.) # doesn't crash
|
|
|
|
|
2020-05-28 10:20:36 -07:00
|
|
|
def test_linear_in_scan(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return -x
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
x_dot, = tangents
|
|
|
|
return f(x), f(x_dot)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
ans = api.grad(foo)(3.)
|
|
|
|
expected = -1.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-06-09 15:19:53 -07:00
|
|
|
def test_custom_jvps_first_rule_is_none(self):
|
|
|
|
# https://github.com/google/jax/issues/3389
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x, y):
|
|
|
|
return x ** 2 * y
|
|
|
|
|
|
|
|
f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot)
|
|
|
|
ans = grad(f, 1)(2., 3.) # doesn't crash
|
|
|
|
expected = 12.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-07-23 19:49:04 -07:00
|
|
|
def test_concurrent_initial_style(self):
|
|
|
|
# https://github.com/google/jax/issues/3843
|
|
|
|
def unroll(param, sequence):
|
|
|
|
def scan_f(prev_state, inputs):
|
|
|
|
return prev_state, jax.nn.sigmoid(param * inputs)
|
|
|
|
return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])
|
|
|
|
|
|
|
|
def run():
|
|
|
|
return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
|
|
|
|
|
2020-07-23 20:59:12 -07:00
|
|
|
expected = run()
|
|
|
|
|
2020-07-23 19:49:04 -07:00
|
|
|
# we just don't want this to crash
|
2020-07-30 12:59:36 -07:00
|
|
|
n_workers = 2
|
2020-07-23 19:49:04 -07:00
|
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
|
|
|
|
futures = []
|
|
|
|
for _ in range(n_workers):
|
|
|
|
futures.append(e.submit(run))
|
2020-07-23 20:59:12 -07:00
|
|
|
results = [f.result() for f in futures]
|
|
|
|
for ans in results:
|
|
|
|
self.assertAllClose(ans, expected)
|
2020-07-23 19:49:04 -07:00
|
|
|
|
2020-04-10 11:45:33 -07:00
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
class CustomVJPTest(jtu.JaxTestCase):

  def test_basic(self):
    @api.custom_vjp
    def f(x):
      return jnp.sin(x)
    def f_fwd(x):
      return f(x), jnp.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.
    self.assertAllClose(f(x), jnp.sin(x))
    self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x))
    self.assertAllClose(api.value_and_grad(f)(x),
                        (jnp.sin(x), 2 * jnp.cos(x)))

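  # Illustrative sketch (added for exposition; not part of the original suite,
  # and the names g/g_fwd/g_bwd are hypothetical): with custom_vjp, the forward
  # rule returns (primal_out, residuals) and the backward rule maps
  # (residuals, cotangent) to a tuple with one cotangent per primal argument:
  #
  #   @api.custom_vjp
  #   def g(x):
  #     return x * x
  #
  #   def g_fwd(x):
  #     return g(x), x          # save x as the residual
  #
  #   def g_bwd(x, ct):
  #     return (2 * x * ct,)    # one cotangent per primal argument
  #
  #   g.defvjp(g_fwd, g_bwd)
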
def test_invariance(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.cos(2 * x) / 2.
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
|
|
|
return (f(x), x)
|
|
|
|
def f_rev(x, g):
|
|
|
|
return (g * 3,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
def f2(x):
|
|
|
|
y, _ = api.value_and_grad(f)(x)
|
|
|
|
return y
|
|
|
|
def f3(x):
|
|
|
|
y, _ = api.value_and_grad(f2)(x)
|
|
|
|
return y
|
|
|
|
x = 1.
|
|
|
|
self.assertAllClose(f(x), f2(x), check_dtypes=False)
|
|
|
|
self.assertAllClose(f(x), f3(x), check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
|
|
|
|
check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_python_control_flow(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
|
|
|
if x > 0:
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
else:
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
|
|
|
if x > 0:
|
|
|
|
return f(x), x
|
|
|
|
else:
|
|
|
|
return f(x), x
|
|
|
|
def f_rev(x, g):
|
|
|
|
if x > 0:
|
|
|
|
return (2 * g,)
|
|
|
|
else:
|
|
|
|
return (3 * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
x = 2.
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(f(x), jnp.sin(x))
|
|
|
|
self.assertAllClose(f(-x), jnp.cos(-x))
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.),
|
2020-01-15 15:00:38 -08:00
|
|
|
check_dtypes=False)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.),
|
2020-01-15 15:00:38 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_vmap(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
assert jnp.ndim(x) == 0
|
|
|
|
return jnp.sin(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
assert jnp.ndim(x) == 0
|
|
|
|
return f(x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_rev(cos_x, g):
|
|
|
|
return (2 * cos_x * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
x = jnp.arange(3.)
|
|
|
|
xx = jnp.arange(6.).reshape(2, 3)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
# vmap of f
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(api.vmap(f)(x), jnp.sin(x))
|
|
|
|
self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
# vmap of grad of f
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
|
2020-06-01 17:19:23 -04:00
|
|
|
(jnp.sin(x), 2 * jnp.cos(x)))
|
|
|
|
self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
|
2020-06-01 17:19:23 -04:00
|
|
|
(jnp.sin(xx), 2 * jnp.cos(xx)))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
# grad of vmap of f
|
|
|
|
self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
|
2020-06-01 17:19:23 -04:00
|
|
|
2 * jnp.cos(x))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
|
2020-06-01 17:19:23 -04:00
|
|
|
2 * jnp.cos(xx))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
# vmap of grad of vmap of f
|
|
|
|
self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
|
2020-06-01 17:19:23 -04:00
|
|
|
2 * jnp.cos(xx))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
def test_jit(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_rev(cos_x, g):
|
|
|
|
return (2 * cos_x * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
|
|
|
x = 3.
|
|
|
|
|
|
|
|
# jit
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(api.jit(f)(x), jnp.sin(x))
|
|
|
|
self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
# jit of grad
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x),
|
2020-01-15 15:00:38 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
# grad of jit
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x),
|
2020-01-15 15:00:38 -08:00
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_pytrees(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return {'b': jnp.sin(x['a'])}
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), {'r': jnp.cos(x['a'])}
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_bwd(res, g):
|
|
|
|
cos_x = res['r']
|
|
|
|
return ({'a': 2 * cos_x * g['b']},)
|
|
|
|
f.defvjp(f_fwd, f_bwd)
|
|
|
|
x = {'a': 3.}
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(f(x)['b'], jnp.sin(x['a']))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
|
2020-06-01 17:19:23 -04:00
|
|
|
{'a': 2 * jnp.cos(x['a'])})
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
def test_jvp_error(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_rev(cos_x, g):
|
|
|
|
return (2 * cos_x * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
|
|
|
|
lambda: api.jvp(f, (3.,), (1.,)))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
|
2020-05-05 14:59:16 -04:00
|
|
|
lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)))
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
def test_kwargs(self):
|
|
|
|
# from https://github.com/google/jax/issues/1938
|
|
|
|
@api.custom_vjp
|
|
|
|
def my_fun(x, y, c=1.):
|
|
|
|
return c * (x + y)
|
|
|
|
my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
|
|
|
|
lambda _, g: (g, g, g))
|
2020-05-05 14:59:16 -04:00
|
|
|
f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
|
2020-01-15 15:00:38 -08:00
|
|
|
f(10., 5.) # doesn't crash
|
|
|
|
api.grad(f)(10., 5.) # doesn't crash
|
|
|
|
|
|
|
|
def test_initial_style(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_rev(cos_x, g):
|
|
|
|
return (2 * cos_x * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
ans = api.grad(foo)(3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = 2. * jnp.cos(3.)
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(api.grad(foo))(3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = -2. * jnp.sin(3.)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
def test_initial_style_vmap(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
assert jnp.ndim(x) == 0
|
2020-01-15 15:00:38 -08:00
|
|
|
return 3 * x
|
|
|
|
def f_fwd(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def f_rev(cos_x, g):
|
|
|
|
return (2 * cos_x * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = api.vmap(foo)(jnp.arange(3.))
|
|
|
|
expected = 3. * jnp.arange(3.)
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.))
|
|
|
|
expected = 2. * jnp.cos(jnp.arange(3.))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_nondiff_arg(self):
|
|
|
|
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
|
|
|
def app(f, x):
|
|
|
|
return f(x)
|
|
|
|
def app_fwd(f, x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return app(f, x), jnp.cos(x)
|
2020-01-15 15:00:38 -08:00
|
|
|
def app_rev(f, cos_x, g):
|
|
|
|
return (cos_x * g,)
|
|
|
|
app.defvjp(app_fwd, app_rev)
|
|
|
|
|
|
|
|
ans = app(lambda x: 2 * x, 1)
|
|
|
|
expected = 2
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = (2., jnp.cos(1.))
|
2020-01-15 15:00:38 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def test_nondiff_arg_tracer(self):
|
|
|
|
@partial(api.custom_vjp, nondiff_argnums=(0,))
|
|
|
|
def f(x, y):
|
|
|
|
return x * y
|
|
|
|
def f_fwd(x, y):
|
2020-05-05 14:59:16 -04:00
|
|
|
return f(x, y), jnp.cos(y)
|
2020-03-28 14:15:46 -07:00
|
|
|
def f_rev(x, cos_y, g):
|
|
|
|
return (cos_y * g,)
|
|
|
|
f.defvjp(f_fwd, f_rev)
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def g(x, y):
|
|
|
|
return f(x, y)
|
|
|
|
|
|
|
|
ans = g(2, 3.)
|
|
|
|
expected = 6.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(g, 1)(2., 3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = jnp.cos(3.)
|
2020-03-28 14:15:46 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def test_vmap_axes(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
|
|
|
def test_pmap(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
|
|
|
def test_missing_vjp_rule_error(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def foo(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No VJP defined for custom_vjp function foo using defvjp.",
|
|
|
|
lambda: foo(2))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No VJP defined for custom_vjp function foo using defvjp.",
|
|
|
|
lambda: api.grad(foo)(2.))
|
|
|
|
|
|
|
|
def test_vjp_rule_inconsistent_pytree_structures_error(self):
|
|
|
|
@api.custom_vjp
|
|
|
|
def f(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
def foo_fwd(x):
|
|
|
|
return x, None
|
|
|
|
|
|
|
|
def foo_bwd(_, g):
|
|
|
|
return g
|
|
|
|
|
|
|
|
f.defvjp(foo_fwd, foo_bwd)
|
|
|
|
|
|
|
|
f(2) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom VJP rule must produce an output with the same container "
|
|
|
|
"(pytree) structure as the args tuple of the primal function, "
|
|
|
|
"and in particular must produce a tuple of length equal to the "
|
|
|
|
"number of arguments to the primal function, but got VJP output "
|
|
|
|
"structure {} for primal input structure {}.".format(
|
|
|
|
tree_util.tree_structure(1),
|
|
|
|
tree_util.tree_structure((1,)))
|
|
|
|
),
|
|
|
|
lambda: api.grad(f)(2.))
|
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def test_issue2511(self):
|
2020-05-05 14:59:16 -04:00
|
|
|
arr = jnp.ones((5, 2, 2))
|
|
|
|
foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
|
2020-03-29 20:51:51 -07:00
|
|
|
api.jit(foo)(arr) # doesn't crash
|
|
|
|
|
2020-04-02 22:52:07 -07:00
|
|
|
def test_lowering_out_of_traces(self):
|
|
|
|
# https://github.com/google/jax/issues/2578
|
|
|
|
|
|
|
|
class F(collections.namedtuple("F", ["a"])):
|
|
|
|
def __call__(self, x):
|
|
|
|
return jax.nn.relu(self.a) * x
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
def g(f, x):
|
|
|
|
return f(x)
|
|
|
|
|
|
|
|
jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash
|
|
|
|
|
2020-04-22 23:29:32 -07:00
|
|
|
def test_nondiff_argnums_stop_gradient(self):
|
|
|
|
# https://github.com/google/jax/issues/2784
|
|
|
|
@partial(api.custom_vjp, nondiff_argnums=(0, 1))
|
|
|
|
def _clip_gradient(lo, hi, x):
|
|
|
|
return x # identity function
|
|
|
|
|
|
|
|
def clip_gradient_fwd(lo, hi, x):
|
2020-07-30 12:59:36 -07:00
|
|
|
# return x, None
|
|
|
|
return x, (hi, )
|
2020-04-22 23:29:32 -07:00
|
|
|
|
|
|
|
def clip_gradient_bwd(lo, hi, _, g):
|
2020-07-30 12:59:36 -07:00
|
|
|
return (jnp.clip(g, lo, hi),)
|
2020-04-22 23:29:32 -07:00
|
|
|
|
|
|
|
_clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
|
|
|
|
|
|
|
|
def clip_gradient(x):
|
2020-07-30 12:59:36 -07:00
|
|
|
lo = -1
|
|
|
|
hi = x + 1 # causes things to break
|
|
|
|
return _clip_gradient(lo, hi, x)
|
2020-04-22 23:29:32 -07:00
|
|
|
|
|
|
|
jax.grad(clip_gradient)(1.) # doesn't crash
|
|
|
|
|
2020-07-09 14:13:45 -04:00
|
|
|
def test_nestable_vjp(self):
|
|
|
|
# Verify that https://github.com/google/jax/issues/3667 is resolved.
|
|
|
|
def f(x):
|
2020-08-19 18:39:25 +02:00
|
|
|
return x ** 2
|
2020-07-09 14:13:45 -04:00
|
|
|
|
|
|
|
@api.custom_vjp
|
|
|
|
def g(x):
|
2020-08-19 18:39:25 +02:00
|
|
|
return f(x)
|
2020-07-09 14:13:45 -04:00
|
|
|
|
|
|
|
def g_fwd(x):
|
2020-08-19 18:39:25 +02:00
|
|
|
y, f_vjp = api.vjp(f, x)
|
|
|
|
return y, f_vjp
|
2020-07-09 14:13:45 -04:00
|
|
|
|
|
|
|
def g_bwd(f_vjp, y_bar):
|
2020-08-19 18:39:25 +02:00
|
|
|
return f_vjp(y_bar)
|
2020-07-09 14:13:45 -04:00
|
|
|
|
|
|
|
g.defvjp(g_fwd, g_bwd)
|
|
|
|
|
|
|
|
# Check that VJP can be nested in simple situations. For this to pass,
|
|
|
|
# vjp has to return a PyTree.
|
|
|
|
_, g_vjp = api.vjp(g, 1.0)
|
|
|
|
y, = g_vjp(1.0)
|
|
|
|
self.assertAllClose(y, jnp.array(2.0))
|
|
|
|
|
|
|
|
# Check that VJP can be nested in complex situations. For this to pass,
|
|
|
|
# vjp can't treat the closed-over tracer x as a static argument.
|
|
|
|
@jit
|
|
|
|
def z(x):
|
2020-08-19 18:39:25 +02:00
|
|
|
_, g_vjp = api.vjp(g, x)
|
|
|
|
return g_vjp
|
2020-07-09 14:13:45 -04:00
|
|
|
y, = z(1.0)(3.0)
|
|
|
|
self.assertAllClose(y, jnp.array(6.0))


class InvertibleADTest(jtu.JaxTestCase):

  def test_invertible_basic(self):
    def f(x):
      return (jnp.exp(x) * 4) * x

    finv = jax.invertible(f)
    x = jnp.ones((5,))
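    # jax.invertible marks f as invertible so that reverse-mode AD can
    # reconstruct intermediate primal values during the backward pass instead
    # of storing them; the checks below compare the resulting jaxpr and the
    # gradients of finv against those of the plain function f.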
if config.omnistaging_enabled:
|
|
|
|
expected = """
|
|
|
|
{ lambda ; a b.
|
|
|
|
let c = exp a
|
|
|
|
d = mul c 4.0
|
|
|
|
e = mul d a
|
|
|
|
f = mul b a
|
|
|
|
g = div e a
|
|
|
|
h = mul b g
|
|
|
|
i = div g 4.0
|
|
|
|
j = mul f 4.0
|
|
|
|
_ = log i
|
|
|
|
k = mul j i
|
|
|
|
l = add_any h k
|
|
|
|
in (l,) }
|
|
|
|
"""
|
|
|
|
else:
|
|
|
|
expected = """
|
|
|
|
{ lambda ; a b.
|
|
|
|
let c = exp a
|
|
|
|
d = mul c 4.0
|
|
|
|
e = mul d a
|
|
|
|
f = div e a
|
|
|
|
g = mul b f
|
|
|
|
h = mul b a
|
|
|
|
i = mul h 4.0
|
|
|
|
j = div f 4.0
|
|
|
|
k = mul i j
|
|
|
|
l = add_any g k
|
|
|
|
in (l,) }
|
|
|
|
"""
jaxpr = jax.make_jaxpr(lambda p, ct: jax.vjp(finv, p)[1](ct))(x, x)
|
2020-07-30 12:59:36 -07:00
|
|
|
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x)))(x),
|
|
|
|
jax.value_and_grad(lambda x: np.sum(finv(x)))(x),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
|
|
|
def test_invertible_blocks(self):
|
|
|
|
# NB: This is the reversible ResNet block
|
|
|
|
def mk_reversible_block(f, g):
|
|
|
|
@jax.custom_ivjp
|
|
|
|
def rev_block(x1, x2):
|
|
|
|
y1 = f(x2) + x1
|
|
|
|
y2 = g(y1) + x2
|
|
|
|
return y1, y2
|
|
|
|
|
|
|
|
@rev_block.defivjp
|
|
|
|
def rev_block_ivjp(xs, ys, dys):
|
|
|
|
(y1, y2) = ys
|
|
|
|
(dy1, dy2) = dys
|
|
|
|
|
|
|
|
dgo, dx2 = dy2, dy2
|
|
|
|
go, gvjp = jax.vjp(g, y1)
|
|
|
|
dy1 += gvjp(dgo)[0]
|
|
|
|
del gvjp
|
|
|
|
x2 = y2 - go
|
|
|
|
|
|
|
|
dfo, dx1 = dy1, dy1
|
|
|
|
fo, fvjp = jax.vjp(f, x2)
|
|
|
|
dx2 += fvjp(dfo)[0]
|
|
|
|
del fvjp
|
|
|
|
x1 = y1 - fo
|
|
|
|
|
|
|
|
return (x1, x2), (dx1, dx2)
|
|
|
|
|
|
|
|
return rev_block
|
|
|
|
|
|
|
|
rev_block = mk_reversible_block(jnp.sin, jnp.cos)
|
|
|
|
|
|
|
|
def g(x1, x2):
|
|
|
|
for i in range(2):
|
|
|
|
x1, x2 = rev_block(x1, x2)
|
|
|
|
return x1, x2
|
|
|
|
|
|
|
|
def reduce(f, x1, x2):
|
|
|
|
y1, y2 = f(x1, x2)
|
|
|
|
return np.sum(y1) + np.sum(y2)
|
|
|
|
|
|
|
|
x = np.ones((1,))
|
|
|
|
# FIXME: This breaks when argnums is left as default (i.e. 0), because JVP prunes
|
|
|
|
# zero tangents from call primitives.
|
|
|
|
self.assertAllClose(jax.value_and_grad(partial(reduce, jax.invertible(g)), argnums=(0, 1))(x, x + 2),
|
|
|
|
jax.value_and_grad(partial(reduce, g), argnums=(0, 1))(x, x + 2),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
2020-08-11 11:45:58 +02:00
|
|
|
def test_invertible_partial_diff(self):
|
|
|
|
# Check that we don't have to differentiate with respect to inputs
|
|
|
|
# of the invertible function.
|
|
|
|
def f(x, y):
|
|
|
|
return (jnp.exp(x) * 4) * x, y + 4
|
|
|
|
|
|
|
|
finv = jax.invertible(f)
|
|
|
|
o = np.ones((5,))
|
|
|
|
self.assertAllClose(jax.value_and_grad(lambda x: np.sum(f(x, o)[0]))(o),
|
|
|
|
jax.value_and_grad(lambda x: np.sum(finv(x, o)[0]))(o),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
|
|
|
|
2020-03-23 14:29:22 -07:00
|
|
|
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def test_defvjp_all(self):
|
|
|
|
foo_p = Primitive('foo')
|
|
|
|
def foo(x): return 2. * foo_p.bind(x)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * jnp.sin(x),)))
|
2020-03-23 14:29:22 -07:00
|
|
|
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
|
|
|
self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(grad_ans, 4 * 2 * np.sin(3.), check_dtypes=False)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
def test_defvjp_all_const(self):
|
|
|
|
foo_p = Primitive('foo')
|
|
|
|
def foo(x): return foo_p.bind(x)
|
|
|
|
|
|
|
|
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
|
|
|
|
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
|
|
|
self.assertAllClose(val_ans, 9., check_dtypes=False)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(grad_ans, 12.)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
def test_defvjp_all_higher_order_revmode(self):
|
|
|
|
foo_p = Primitive('foo')
|
|
|
|
def foo(x): return 2. * foo_p.bind(x)
|
|
|
|
|
|
|
|
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
|
|
|
|
ans = api.grad(api.grad(foo))(3.)
|
|
|
|
self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
|
|
|
|
|
|
|
|
def test_defvjp_all_multiple_arguments(self):
|
|
|
|
# also tests passing in symbolic zero tangents b/c we differentiate wrt only
|
|
|
|
# the first argument in one case
|
|
|
|
|
|
|
|
foo_p = Primitive('foo')
|
|
|
|
def foo(x, y): return foo_p.bind(x, y)
|
|
|
|
|
|
|
|
def vjpfun(x, y):
|
|
|
|
out = x**2 + y**3
|
|
|
|
vjp = lambda g: (g + x + y, g * x * 9.)
|
|
|
|
return out, vjp
|
|
|
|
|
|
|
|
ad.defvjp_all(foo_p, vjpfun)
|
|
|
|
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
|
|
|
|
self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
|
|
|
|
self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(foo, (0, 1))(3., 4.)
|
|
|
|
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
|
|
|
|
|
|
|
|
def test_defvjp_all_custom_transforms(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def foo(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
api.defvjp_all(foo, lambda x: (jnp.sin(x), lambda g: (g * x,)))
|
2020-03-23 14:29:22 -07:00
|
|
|
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(val_ans, np.sin(3.), check_dtypes=False)
|
2020-03-23 14:29:22 -07:00
|
|
|
self.assertAllClose(grad_ans, 3., check_dtypes=False)
|
|
|
|
|
|
|
|
# TODO(mattjj): add defvjp_all test with pytree arguments
|
|
|
|
|
|
|
|
def test_defvjp(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def foo(x, y):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x * y)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
|
|
|
|
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(val_ans, np.sin(3. * 4.), check_dtypes=False)
|
2020-03-23 14:29:22 -07:00
|
|
|
self.assertAllClose(grad_ans, 0., check_dtypes=False)
|
|
|
|
|
|
|
|
ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
|
|
|
|
self.assertAllClose(ans_0, 0., check_dtypes=False)
|
|
|
|
self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
|
|
|
|
|
|
|
|
def test_defvjp_higher_order(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def foo(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(2. * x)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
api.defvjp(foo, lambda g, _, x: g * jnp.cos(x))
|
2020-03-23 14:29:22 -07:00
|
|
|
ans = api.grad(api.grad(foo))(2.)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = api.grad(api.grad(jnp.sin))(2.)
|
2020-03-23 14:29:22 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_defvjp_use_ans(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def foo(x, y):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x * y)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + jnp.cos(ans))
|
2020-03-23 14:29:22 -07:00
|
|
|
val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(val_ans, np.sin(3. * 4.), check_dtypes=False)
|
|
|
|
self.assertAllClose(grad_ans, 3. * 4. + np.cos(np.sin(3. * 4)),
|
2020-03-23 14:29:22 -07:00
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
# TODO
|
|
|
|
# def test_defjvp_closure_error(self):
|
|
|
|
# def foo(x):
|
|
|
|
# @api.custom_transforms
|
|
|
|
# def bar(y):
|
|
|
|
# return x * y
|
|
|
|
|
|
|
|
# api.defjvp(bar, lambda y_dot, ans, y: x * y)
|
|
|
|
# return bar(x)
|
|
|
|
# jtu.check_raises(
|
|
|
|
# lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
|
|
|
|
# "Detected differentiation with respect to closed-over values with "
|
|
|
|
# "custom JVP rule, which isn't supported.")
|
|
|
|
|
|
|
|
# TODO
|
|
|
|
# def test_defvjp_closure_error(self):
|
|
|
|
# def foo(x):
|
|
|
|
# @api.custom_transforms
|
|
|
|
# def bar(y):
|
|
|
|
# return x * y
|
|
|
|
|
|
|
|
# api.defvjp(bar, lambda g, ans, y: x * y)
|
|
|
|
# return bar(x)
|
|
|
|
# jtu.check_raises(
|
|
|
|
# lambda: grad(foo)(1.,), ValueError,
|
|
|
|
# "Detected differentiation w.r.t. variables from outside "
|
|
|
|
# "the scope of <jax.custom_transforms function bar>, but defvjp and "
|
|
|
|
# "defvjp_all only support differentiation w.r.t. positional arguments.")
|
|
|
|
|
|
|
|
def test_custom_transforms_eval_with_pytrees(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x):
|
|
|
|
a, b = x[0], x[1]
|
|
|
|
return {'hi': 2 * a, 'bye': 2 * b}
|
|
|
|
|
|
|
|
ans = f((1, 2))
|
|
|
|
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
|
|
|
|
|
|
|
|
def test_custom_transforms_jit_with_pytrees(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x):
|
|
|
|
a, b = x[0], x[1]
|
|
|
|
return {'hi': 2 * a, 'bye': 2 * b}
|
|
|
|
|
|
|
|
ans = jit(f)((1, 2))
|
|
|
|
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
|
|
|
|
|
|
|
|
def test_custom_transforms_jit_with_pytrees_consts(self):
|
|
|
|
# The purpose of this test is to exercise the custom_transforms default
|
|
|
|
# translation rule in how it deals with constants that are too large to be
|
|
|
|
# treated as literals (at the time of writing).
|
2020-05-05 14:59:16 -04:00
|
|
|
z = np.arange(10.)
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x):
|
|
|
|
a, b = x[0], x[1]
|
|
|
|
return {'hi': 2 * a, 'bye': z * b}
|
|
|
|
|
|
|
|
ans = jit(f)((1, 2))
|
|
|
|
self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_custom_transforms_jvp_with_pytrees(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x):
|
|
|
|
a, b = x[0], x[1]
|
|
|
|
return {'hi': 2 * a, 'bye': 2 * b}
|
|
|
|
|
|
|
|
ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
|
|
|
|
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
|
|
|
|
self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})
|
|
|
|
|
|
|
|
def test_custom_transforms_vmap_with_pytrees(self):
|
|
|
|
raise unittest.SkipTest("Test deprecated custom_transforms")
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x):
|
|
|
|
a, b = x[0], x[1]
|
|
|
|
return {'hi': 2 * a, 'bye': 2 * b}
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = api.vmap(f)((np.arange(3), np.ones((3, 2))))
|
|
|
|
expected = {'hi': 2 * np.arange(3), 'bye': 2 * np.ones((3, 2))}
|
2020-03-23 14:29:22 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_custom_transforms_jvp_with_closure(self):
|
|
|
|
def f(x):
|
|
|
|
@api.custom_transforms
|
|
|
|
def g(y):
|
|
|
|
return x * y
|
|
|
|
return g(x)
|
|
|
|
|
|
|
|
ans = api.grad(f)(1.)
|
|
|
|
expected = 2.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_custom_gradient(self):
|
|
|
|
@api.custom_gradient
|
|
|
|
def f(x):
|
|
|
|
return x ** 2, lambda g: (g * x,)
|
|
|
|
|
|
|
|
self.assertAllClose(f(3.), 9., check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
|
|
|
|
|
|
|
|
def test_custom_vjp_zeros(self):
|
|
|
|
@api.custom_transforms
|
|
|
|
def f(x, y):
|
|
|
|
return 2 * x, 3 * y
|
|
|
|
|
|
|
|
def f_vjp(x, y):
|
|
|
|
return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])
|
|
|
|
|
|
|
|
api.defvjp_all(f, f_vjp, )
|
|
|
|
api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash
|
|
|
|
|
|
|
|
def test_custom_transforms_vjp_nones(self):
|
2020-05-01 09:16:31 +03:00
|
|
|
core.skip_checks = True # Fails with checks
|
|
|
|
# issue raised by jsnoek@ and jumper@
|
2020-03-23 14:29:22 -07:00
|
|
|
@jax.custom_transforms
|
|
|
|
def solve(a, b):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.dot(jnp.linalg.inv(a), b)
|
2020-03-23 14:29:22 -07:00
|
|
|
# print(solve(a, b))
|
|
|
|
|
|
|
|
def solve_vjp(a, b):
|
|
|
|
x = solve(a, b)
|
|
|
|
def vjp(x_tangent):
|
2020-05-05 14:59:16 -04:00
|
|
|
dx = jnp.dot(solve(a, x_tangent), x.T)
|
2020-03-23 14:29:22 -07:00
|
|
|
out = (dx, b * 0.)
|
|
|
|
return out
|
|
|
|
return x, vjp
|
|
|
|
jax.defvjp_all(solve, solve_vjp)
|
2020-05-05 14:59:16 -04:00
|
|
|
gf = grad(lambda a,b: jnp.sum(solve(a, b)))
|
2020-03-23 14:29:22 -07:00
|
|
|
|
|
|
|
n = 3
|
2020-05-05 14:59:16 -04:00
|
|
|
a_in = jnp.linspace(0, 1, n)[:, None]
|
|
|
|
a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1
|
|
|
|
real_x = np.random.RandomState(0).randn(n)
|
|
|
|
b = jnp.dot(a + jnp.eye(a.shape[0]), real_x)
|
2020-03-23 14:29:22 -07:00
|
|
|
print(gf(a, b)) # doesn't crash


class BufferDonationTest(jtu.JaxTestCase):

  # === pmap ===

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def test_pmap_donate_argnums_invalidates_input(self):
    move = api.pmap(lambda x: x + x - x, donate_argnums=0)
    n = jax.local_device_count()
    x = api.pmap(lambda x: x)(jnp.ones([n]))
    y = move(x)
    self.assertDeleted(x)
    np.testing.assert_allclose(y, [1.] * n)

def test_pmap_nested_donate_ignored(self):
    pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x))
    a = api.pmap(lambda x: x)(jnp.array([1]))

    # NOTE(mattjj): stopped raising an error here; nested donation is now just
    # ignored.
    # with self.assertRaisesRegex(ValueError, "nested.*not supported"):
    #   pmap_fun(a)

    pmap_fun(a)  # doesn't crash

  assertDeleted = lambda self, x: self._assertDeleted(x, True)
  assertNotDeleted = lambda self, x: self._assertDeleted(x, False)

  def _assertDeleted(self, x, deleted):
    # A single-device result exposes one `device_buffer`; a pmap result carries
    # a list of per-device `device_buffers`.
    if hasattr(x, "device_buffer"):
      self.assertEqual(x.device_buffer.is_deleted(), deleted)
    else:
      for buffer in x.device_buffers:
        self.assertEqual(buffer.is_deleted(), deleted)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())