# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from contextlib import contextmanager
import copy
from functools import partial
import re
import unittest
import warnings
import weakref

from absl import logging
from absl.testing import absltest, parameterized
import numpy as onp

import concurrent.futures

import jax
import jax.numpy as np
from jax import jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
from jax.lib import xla_bridge as xb
from jax import test_util as jtu
from jax import tree_util

from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

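# The flags parsed above (e.g. --jax_enable_x64) are consulted by some of the
# tests below, such as test_dtype_warning, via the FLAGS object.
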
class APITest(jtu.JaxTestCase):

  def test_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    assert grad(f)(1.0, 1.0, 1.0, flag=True) == 1.0
    assert grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == 2.0
    assert grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (3.0, 1.0)

  def test_value_and_grad_argnums(self):
    def f(x, y, z, flag=False):
      assert flag
      return 1.0 * x + 2.0 * y + 3.0 * z

    y = f(1.0, 1.0, 1.0, flag=True)
    assert api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True) == (y, 1.0)
    assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
    assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))

  def test_jit_static_args(self):
    side = []

    def f(x, y, z, flag=False, flag2=False):
      assert flag
      side.append(None)
      return 100*x + 10*y + z

    f1 = jit(f, static_argnums=(3, 4))
    assert f1(1, 2, 3, True, False) == 123
    assert len(side) == 1
    assert f1(2, 1, 3, True, False) == 213
    assert len(side) == 1
    assert f1(2, 1, 3, True, True) == 213
    assert len(side) == 2

    side[:] = []
    f2 = jit(f, static_argnums=(0, 2, 3, 4))
    assert f2(1, 2, 3, True, False) == 123
    assert len(side) == 1
    assert f2(1, 3, 3, True, False) == 133
    assert len(side) == 1
    assert f2(2, 2, 3, True, False) == 223
    assert len(side) == 2
    assert f2(2, 4, 3, True, False) == 243
    assert len(side) == 2
    assert f2(2, 4, 3, True, True) == 243
    assert len(side) == 3
    assert f2(2, 5, 3, True, True) == 253
    assert len(side) == 3

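  # Note on the test above: arguments named in static_argnums become part of
  # the compilation cache key, so flipping flag2 forces a retrace (visible in
  # the growing `side` list) even though the traced computation never uses it.
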
  def test_jit_kwargs(self):
    side = []

    def f(x, y, z):
      side.append(None)
      return 100*x + 10*y + z

    f = jit(f)
    assert f(1, 2, 3) == 123
    assert len(side) == 1
    assert f(1, 2, 3) == 123
    assert len(side) == 1

    assert f(1, 2, z=3) == 123
    assert len(side) == 2  # actually recompiles from kwarg
    assert f(1, 2, z=3) == 123
    assert len(side) == 2  # but should still cache

    f(1, 2, z=onp.zeros(3))  # doesn't crash

  def test_jit_with_many_args_works(self):
    @jit
    def f(args_list):
      return sum(args_list)

    self.assertEqual(f(list(range(500))), sum(range(500)))

  def test_grad_of_jit(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    assert grad(f)(1.0) == 2.0
    assert len(side) == 1
    assert grad(f)(2.0) == 4.0
    assert len(side) == 1

  def test_jit_of_grad(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x * x

    g = jit(grad(f))
    assert g(1.0) == 2.0
    assert len(side) == 1
    assert g(2.0) == 4.0
    assert len(side) == 1

  def test_bad_input(self):
    def f(x):
      return x

    self.assertRaisesRegex(
      TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
      lambda: grad(f)("foo"))

    self.assertRaisesRegex(
      TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type",
      lambda: jit(f)("foo"))

  def test_grad_tuple_output(self):
    jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_unit_output(self):
    jtu.check_raises(lambda: grad(lambda x: ())(onp.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_grad_nonscalar_output(self):
    jtu.check_raises(lambda: grad(lambda x: x)(onp.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")

  def test_unwrapped_numpy(self):
    def f(x):
      return onp.exp(x)

    jtu.check_raises(lambda: grad(f)(onp.zeros(3)), Exception,
                     "Tracer can't be used with raw numpy functions. "
                     "You might have\n  import numpy as np\ninstead of\n"
                     "  import jax.numpy as np")

  def test_binop_mismatch(self):
    def f(x, y):
      return x + y

    jtu.check_raises(
        lambda: f(np.zeros(3), np.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

    jtu.check_raises(
        lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
        TypeError,
        "add got incompatible shapes for broadcasting: (3,), (4,).")

  def test_dot_mismatch(self):
    def f(x, y):
      return np.dot(x, y)

    self.assertRaisesRegex(
      TypeError, "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).",
      lambda: grad(f)(onp.zeros(3), onp.zeros(4)))

  def test_abstract_error_message(self):
    for castfun in [float, complex, int]:
      def f(x):
        return castfun(x)

      self.assertRaisesRegex(
          TypeError,
          f"Try using `x.astype\\({castfun.__name__}\\)` instead.",
          lambda: jit(f)(1.0))

  def test_switch_value_jit(self):
    def f(x):
      y = x > 0
      if y:
        return x
      else:
        return -x

    assert grad(f)(1.0) == 1.0
    assert grad(f)(-1.0) == -1.0
    with self.assertRaisesRegex(core.ConcretizationTypeError,
                                "Abstract tracer value encountered where concrete value is expected"):
      jit(f)(1)

  def test_range_err(self):
    def f(x, n):
      for i in range(n):
        x = x + i
      return x

    assert jit(f, static_argnums=(1,))(0, 5) == 10
    self.assertRaisesRegex(
      TypeError,
      "('JaxprTracer' object cannot be interpreted as an integer"
      "|Abstract value passed to .*)",
      lambda: jit(f)(0, 5))

  def test_casts(self):
    for castfun in [hex, oct, int]:
      f = lambda x: castfun(x)
      self.assertRaisesRegex(
        TypeError,
        "('JaxprTracer' object cannot be interpreted as an integer"
        "|Abstract tracer value encountered where concrete value is expected .*)", lambda: jit(f)(0))

  def test_unimplemented_interpreter_rules(self):
    foo_p = Primitive('foo')
    def foo(x):
      return foo_p.bind(x)

    jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                     "Evaluation rule for 'foo' not implemented")

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "Abstract evaluation for 'foo' not implemented")

    jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                     "Forward-mode differentiation rule for 'foo' not implemented")

    foo_p.def_abstract_eval(lambda x: x)

    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "XLA translation rule for primitive 'foo' not found")

    foo_p.def_impl(lambda x: x)
    ad.defjvp(foo_p, lambda g, x: foo(g))

    jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
                     "Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")

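  # As exercised above, a freshly created core.Primitive has no rules at all:
  # an impl/abstract-eval rule, an XLA translation rule, a JVP rule, and a
  # transpose rule each have to be registered before eval, jit, and grad work.
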
  def test_device_put_and_get(self):
    x = onp.arange(12.).reshape((3, 4)).astype("float32")
    dx = api.device_put(x)
    self.assertIsInstance(dx, xla.DeviceArray)
    x2 = api.device_get(dx)
    self.assertIsInstance(x2, onp.ndarray)
    assert onp.all(x == x2)

    y = [x, (2 * x, 3 * x)]
    dy = api.device_put(y)
    y2 = api.device_get(dy)
    self.assertIsInstance(y2, list)
    self.assertIsInstance(y2[0], onp.ndarray)
    assert onp.all(y2[0] == x)
    self.assertIsInstance(y2[1], tuple)
    self.assertIsInstance(y2[1][0], onp.ndarray)
    assert onp.all(y2[1][0] == 2 * x)
    self.assertIsInstance(y2[1][1], onp.ndarray)
    assert onp.all(y2[1][1] == 3 * x)

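  # device_put and device_get map over pytrees, so the nested list/tuple above
  # comes back with the same structure and numpy ndarray leaves.
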
  @parameterized.parameters([(3,)], [(2, 0)])
  def test_device_put_across_devices(self, shape):
    if len(api.local_devices()) < 2:
      raise unittest.SkipTest("this test requires multiple devices")
    d1, d2 = api.local_devices()[:2]
    data = onp.random.randn(*shape).astype(onp.float32)
    x = api.device_put(data, device=d1)
    self.assertEqual(x.device_buffer.device(), d1)
    y = api.device_put(x, device=d2)
    self.assertEqual(y.device_buffer.device(), d2)
    onp.testing.assert_array_equal(data, onp.array(y))
    # Make sure these don't crash
    api.device_put(x)
    api.device_put(y)

  @jtu.skip_on_devices("cpu")
  def test_device_put_across_platforms(self):
    default_device = jax.devices()[0]
    cpu_device = jax.devices("cpu")[0]

    onp_arr = onp.array([1,2,3])
    scalar = 1
    device_arr = np.array([1,2,3])
    assert device_arr.device_buffer.device() is default_device

    for val in [onp_arr, device_arr, scalar]:
      x = api.device_put(val, device=cpu_device)
      self.assertEqual(x.device_buffer.device(), cpu_device)

  def test_jit_on_all_devices(self):
    # Verifies we can run the same computation on every device present, even
    # if they are, for example, different models of GPU.
    data = onp.random.rand(1000).astype(onp.float32)
    f = api.jit(np.negative)
    for device in jax.local_devices():
      x = device_put(data, device=device)
      onp.testing.assert_array_equal(-data, f(x))

  @jtu.skip_on_devices("tpu")
  def test_jacobian(self):
    R = onp.random.RandomState(0).randn
    A = R(4, 3)
    x = R(3)

    f = lambda x: np.dot(A, x)
    assert onp.allclose(jacfwd(f)(x), A)
    assert onp.allclose(jacrev(f)(x), A)

    f = lambda x: np.tanh(np.dot(A, x))
    assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))

  @jtu.skip_on_devices("tpu")
  def test_hessian(self):
    R = onp.random.RandomState(0).randn
    A = R(4, 4)
    x = R(4)

    f = lambda x: np.dot(x, np.dot(A, x))
    assert onp.allclose(hessian(f)(x), A + A.T)

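  # For the quadratic form f(x) = x^T A x used above, grad(f)(x) = (A + A^T) x,
  # so the Hessian is the constant matrix A + A^T, which is what the test checks.
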
  def test_std_basis(self):
    basis = api._std_basis(np.zeros(3))
    assert getattr(basis, "shape", None) == (3, 3)
    assert onp.allclose(basis, onp.eye(3))

    basis = api._std_basis(np.zeros((3, 3)))
    assert getattr(basis, "shape", None) == (9, 3, 3)
    assert onp.allclose(basis, onp.eye(9).reshape(9, 3, 3))

    basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
    assert isinstance(basis, list) and len(basis) == 2
    assert getattr(basis[0], "shape", None) == (16,)
    assert isinstance(basis[1], tuple) and len(basis[1]) == 2
    assert getattr(basis[1][0], "shape", None) == (16, 3)
    assert getattr(basis[1][1], "shape", None) == (16, 3, 4)

  @jtu.skip_on_devices("tpu")
  def test_jacobian_on_pytrees(self):
    for jacfun in [jacfwd, jacrev]:
      ans = jacfun(lambda x, y: (x, y))(0., 1.)
      expected = (1., 0.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), 1)(0., 1.)
      expected = (0., 1.)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x, y: (x, y), (0, 1))(0., 1.)
      expected = ((1., 0.),
                  (0., 1.),)
      self.assertAllClose(ans, expected, check_dtypes=False)

      ans = jacfun(lambda x: x[:2])((1., 2., 3.))
      expected = ((1., 0., 0.),
                  (0., 1., 0.))
      self.assertAllClose(ans, expected, check_dtypes=False)

      R = onp.random.RandomState(0).randn
      x = R(2)
      y = R(3)
      ans = jacfun(lambda x, y: {'x': x, 'xy': np.outer(x, y)})(x, y)
      expected = {'x': onp.eye(2),
                  'xy': onp.kron(onp.eye(2), y[:, None]).reshape(2, 3, 2)}
      self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.skip_on_devices("tpu")
  def test_hessian_on_pytrees(self):
    ans = hessian(lambda x: np.array(x)**2)((1., 2.))
    expected = ((onp.array([2., 0.]), onp.array([0., 0.])),
                (onp.array([0., 0.]), onp.array([0., 2.])))
    self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.skip_on_devices("tpu")
  def test_issue1372(self):
    def quad(x):
      return np.dot(x, x)

    def f(x, u):
      return quad(x) + quad(u)

    x, u = np.ones(5), np.ones(2)

    rev = jacrev
    fwd = jacfwd

    # Diagonal entries
    self.assertEqual(rev(rev(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(rev(fwd(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(fwd(rev(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(fwd(fwd(f, 0), 0)(x, u).shape, (5, 5))
    self.assertEqual(rev(rev(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(rev(fwd(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(fwd(rev(f, 1), 1)(x, u).shape, (2, 2))
    self.assertEqual(fwd(fwd(f, 1), 1)(x, u).shape, (2, 2))

    # Off-diagonal entries by reverse-mode on the outside
    self.assertEqual(rev(rev(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(rev(fwd(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(rev(rev(f, 0), 1)(x, u).shape, (5, 2))
    self.assertEqual(rev(fwd(f, 0), 1)(x, u).shape, (5, 2))

    # Off-diagonal entries by forward-mode on the outside
    self.assertEqual(fwd(rev(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(fwd(fwd(f, 1), 0)(x, u).shape, (2, 5))
    self.assertEqual(fwd(rev(f, 0), 1)(x, u).shape, (5, 2))
    self.assertEqual(fwd(fwd(f, 0), 1)(x, u).shape, (5, 2))

  def test_disable_jit(self):
    effects = []

    @api.jit
    def f(x):
      effects.append(1)
      return x

    with api.disable_jit():
      f(2)
      f(2)
      assert len(effects) == 2

    f(2)
    f(2)
    assert len(effects) == 3

  def test_large_device_constant(self):
    ans = jit(lambda x: 2 * x)(np.ones(int(2e6)))  # doesn't crash
    self.assertAllClose(ans, onp.ones(int(2e6)) * 2., check_dtypes=False)

  def test_grad_and_aux_basic(self):
    g, aux = grad(lambda x: (x**3, [x**2]), has_aux=True)(3.)
    self.assertAllClose(g, grad(lambda x: x**3)(3.), check_dtypes=True)
    self.assertAllClose(aux, [9.], check_dtypes=False)

  def test_grad_and_aux_nested(self):
    def f(x):
      g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
      return aux[0]

    f2 = lambda x: x**3

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

    def f(x):
      g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
      return aux[0] * np.sin(x)

    f2 = lambda x: x**3 * np.sin(x)

    self.assertEqual(grad(f)(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(f))(4.), grad(f2)(4.))
    self.assertEqual(jit(grad(jit(f)))(4.), grad(f2)(4.))

  def test_grad_and_aux_constant(self):
    g, aux = grad(lambda x: (x**3, [4.]), has_aux=True)(4.)
    self.assertEqual(g, grad(lambda x: x**3)(4.))
    self.assertEqual(aux, [4.])

    g, aux = grad(lambda x: (x**3, [x**2, 4.]), has_aux=True)(4.)
    self.assertEqual(g, grad(lambda x: x**3)(4.))
    self.assertEqual(aux, [4.**2, 4.])

  def test_grad_and_aux_no_tracers(self):
    # see https://github.com/google/jax/issues/1950
    def f(x):
      aux = dict(identity=x, p1=x+1)
      return x ** 2, aux

    _, aux = jax.grad(f, has_aux=True)(3.)
    self.assertIsInstance(aux, dict)
    for val in aux.values():
      self.assertNotIsInstance(val, core.Tracer)

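  # With has_aux=True, the differentiated function returns (output, aux) and
  # grad returns (gradient, aux); the aux values come back as concrete values
  # rather than tracers, as test_grad_and_aux_no_tracers verifies.
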
  def test_jvp_mismatched_arguments(self):
    self.assertRaisesRegex(
      TypeError,
      ("primal and tangent arguments to jax.jvp must have the same tree "
       "structure"),
      lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), ()))
    # Primals and tangents must both be tuples or both be lists
    self.assertRaisesRegex(
      TypeError,
      ("primal and tangent arguments to jax.jvp must have the same tree "
       "structure"),
      lambda: api.jvp(lambda x, y: x * y, (onp.float32(2),), [onp.float32(2)]))
    self.assertRaisesRegex(
      TypeError,
      "primal and tangent arguments to jax.jvp must have equal types",
      lambda: api.jvp(lambda x: -x, (onp.float16(2),), (onp.float32(4),)))

  def test_jvp_non_tuple_arguments(self):
    def f(x, y): return x + y
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found float and tuple.",
        lambda: api.jvp(f, 0., (1.,)))
    self.assertRaisesRegex(
        TypeError,
        "primal and tangent arguments to jax.jvp must be tuples or lists; found tuple and ndarray.",
        lambda: api.jvp(f, (0.,), onp.array([1., 2.])))

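  # api.jvp(f, primals, tangents) expects the primals and tangents to be
  # packaged as tuples or lists with matching tree structure and dtypes; the
  # error messages checked above enforce exactly that.
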
  def test_vjp_mismatched_arguments(self):
    _, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
    self.assertRaisesRegex(
      TypeError,
      "Tree structure of cotangent input.*does not match",
      lambda: pullback((onp.float32(7), onp.float32(100))))
    self.assertRaisesRegex(
      TypeError,
      "Type of cotangent input to vjp pullback.*does not match type",
      lambda: pullback((onp.float16(42))))

  def test_jvp_jit_cached(self):
    """Bug in caching in presence of JVP and JIT."""

    def func(x):
      def inner(y):
        return y * x

      # Must have two calls to the inner jit (the second one hits the cache)
      res1 = api.jit(inner)(4.)
      res2 = api.jit(inner)(5.)
      return res1 + res2

    self.assertAllClose((45., 9.), api.jvp(func, (5.,), (1.,)), check_dtypes=True)

  def test_complex_grad_raises_error(self):
    self.assertRaises(TypeError, lambda: grad(lambda x: np.sin(x))(1 + 2j))

  def test_holomorphic_grad(self):
    out = grad(lambda x: np.sin(x), holomorphic=True)(1 + 2j)
    expected = 2.0327230070196656 - 3.0518977991518j
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_nonholomorphic_grad(self):
    zs = 0.5j * onp.arange(5) + onp.arange(5)

    def f(z):
      return np.sum(np.cos(np.abs(z)))

    ans = grad(f)(zs)
    expected = onp.array([ 0.        +0.j,
                          -0.80430663+0.40215331j,
                          -0.70368982+0.35184491j,
                           0.1886467 -0.09432335j,
                           0.86873727-0.43436864j])
    self.assertAllClose(ans, expected, check_dtypes=False,
                        atol=jtu.default_gradient_tolerance,
                        rtol=jtu.default_gradient_tolerance)

  def test_complex_output_jacrev_raises_error(self):
    self.assertRaises(TypeError, lambda: jacrev(lambda x: np.sin(x))(1 + 2j))

  def test_nonholomorphic_jacrev(self):
    # code based on https://github.com/google/jax/issues/603
    zs = 0.5j * onp.arange(5) + onp.arange(5)

    def f(z):
      return np.cos(np.linalg.norm(2 * z))

    ans = jacrev(f)(zs)
    expected = grad(f)(zs)
    self.assertAllClose(ans, expected, check_dtypes=True)

  def test_complex_input_jacfwd_raises_error(self):
    self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))

  def test_legacy_devicearray_repr(self):
    dx = device_put(3.)
    str(dx.item())  # doesn't crash

  def test_devicearray_repr(self):
    x = device_put(np.zeros(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

    x = device_put(np.ones(3) + 1j * np.ones(3))
    self.assertIsInstance(x, xla.DeviceArray)
    repr(x)  # doesn't crash

  def test_devicearray_delete(self):
    x = device_put(1.)
    x.delete()
    self.assertRaisesRegex(ValueError, "DeviceValue has been deleted.",
                           lambda: repr(x))

  def test_devicearray_block_until_ready(self):
    x = device_put(1.)
    y = x.block_until_ready()
    # Tests mostly that block_until_ready() does not produce an error.
    self.assertTrue(y is x)

  def test_namedtuple_transparency(self):
    # See https://github.com/google/jax/issues/446
    Point = collections.namedtuple("Point", ["x", "y"])

    def f(pt):
      return np.sqrt(pt.x ** 2 + pt.y ** 2)

    pt = Point(1., 2.)

    f(pt)  # doesn't crash
    g = api.grad(f)(pt)
    self.assertIsInstance(g, Point)

    f_jit = api.jit(f)
    self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False)

  def test_namedtuple_subclass_transparency(self):
    # See https://github.com/google/jax/issues/806
    Point = collections.namedtuple("Point", ["x", "y"])

    class ZeroPoint(Point):
      def is_zero(self):
        return (self.x == 0) and (self.y == 0)

    pt = ZeroPoint(0., 0.)

    def f(pt):
      return 0. if pt.is_zero() else np.sqrt(pt.x ** 2 + pt.y ** 2)

    f(pt)  # doesn't crash
    g = api.grad(f)(pt)
    self.assertIsInstance(pt, ZeroPoint)

  @parameterized.parameters(1, 2, 3)
  def test_shape_dtype_struct(self, i):
    s = api.ShapeDtypeStruct(shape=(i, 2, 3), dtype=np.float32)
    self.assertEqual(s.shape, (i, 2, 3))
    self.assertEqual(s.dtype, np.float32)
    self.assertEqual(s.ndim, 3)
    self.assertEqual(s.size, i * 2 * 3)
    self.assertLen(s, i)
    for f in (str, repr):
      self.assertEqual(
          f(s), "ShapeDtypeStruct(shape=({}, 2, 3), dtype=float32)".format(i))

  def test_shape_dtype_struct_scalar(self):
    s = api.ShapeDtypeStruct(shape=(), dtype=np.float32)
    self.assertEmpty(s.shape)
    self.assertEqual(s.size, 1)
    self.assertEqual(s.ndim, 0)
    with self.assertRaisesRegex(TypeError, "len[(][)] of unsized object"):
      _ = len(s)

  def test_eval_shape(self):
    def fun(x, y):
      return np.tanh(np.dot(x, y) + 3.)

    x = np.ones((2, 3))
    y = np.ones((3, 4))
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_constants(self):
    def fun():
      x = np.ones((2, 3))
      y = np.ones((3, 4))
      return np.tanh(np.dot(x, y) + 3.)

    out_shape = api.eval_shape(fun)

    self.assertEqual(out_shape.shape, (2, 4))

  def test_eval_shape_tuple_unpacking(self):
    def fun(x, y):
      a, b = x
      return a + b + y

    x = (np.ones(2), np.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_tuple_itemgetting(self):
    def fun(x, y):
      return x[0] + x[1] + y

    x = (np.ones(2), np.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)

    self.assertEqual(out_shape.shape, (2,))

  def test_eval_shape_output_dict(self):
    def fun(x, y):
      return {'hi': x[0] + x[1] + y}

    x = (np.ones(2), np.ones(2))
    y = 3.
    out_shape = api.eval_shape(fun, x, y)
    out_shape = tree_util.tree_map(onp.shape, out_shape)

    self.assertEqual(out_shape, {'hi': (2,)})

  def test_eval_shape_shape_error(self):
    def fun(x, y):
      return np.tanh(np.dot(x, y) + 3.)

    x = np.ones((3, 3))
    y = np.ones((4, 4))

    self.assertRaises(TypeError, lambda: api.eval_shape(fun, x, y))

  def test_eval_shape_duck_typing(self):
    def fun(A, b, x):
      return np.dot(A, x) + b

    class MyArgArray(object):
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    A = MyArgArray((3, 4), np.float32)
    b = MyArgArray((5,), np.float32)
    x = MyArgArray((4, 5), np.float32)
    out_shape = api.eval_shape(fun, A, b, x)

    self.assertEqual(out_shape.shape, (3, 5))

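  # api.eval_shape only performs abstract evaluation, so it accepts any object
  # exposing .shape and .dtype (as in the duck-typing test above) and never
  # materializes array values.
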
  def test_issue_871(self):
    T = np.array([[1., 2.], [3., 4.], [5., 6.]])
    x = np.array([1, 2, 3])

    y, f_jvp = api.linearize(np.sum, x)
    jtu.check_raises(lambda: f_jvp(T), ValueError,
                     ("linearized function called on tangent values "
                      "inconsistent with the original primal values."))

    y, f_jvp = api.linearize(api.jit(np.sum), x)
    jtu.check_raises(lambda: f_jvp(T), ValueError,
                     ("linearized function called on tangent values "
                      "inconsistent with the original primal values."))

  def test_partial_eval_lower(self):
    # this is a simplified model of a bug that arose when we first used @jit in
    # a jvp rule. it's in this file because we want to use make_jaxpr.

    # NOTE(mattjj): I no longer understand what this was meant to test. My guess
    # is it was related to staging out the broadcast into a jaxpr to be
    # transposed, but after #1749 that's no longer a problem. After changing
    # make_jaxpr (and jit) to stage out sub-calls fully, this test started to
    # fail; I left it in as skipped because deleting tests feels wrong.
    raise unittest.SkipTest("obsolete test")

    @api.jit
    def f(a, b, c):
      a = lax.broadcast(a, (2,))
      return lax.select(a, b, c)

    a = onp.ones((3, 3), dtype=onp.bool_)
    b = onp.ones((2, 3, 3))
    c = onp.ones((2, 3, 3))

    jaxpr = api.make_jaxpr(lambda b, c: f(a, b, c))(b, c)
    subjaxpr = next(eqn.params["call_jaxpr"] for eqn in jaxpr.jaxpr.eqns
                    if "call_jaxpr" in eqn.params)
    self.assertEqual(len(subjaxpr.eqns), 1)

  def test_grad_of_int_errors(self):
    dfn = grad(lambda x: x ** 2)
    self.assertRaisesRegex(
      TypeError,
      "Primal inputs to reverse-mode differentiation must be of float or "
      "complex type, got type int..", lambda: dfn(3))

  def test_xla_computation(self):
    # these tests basically check the examples in the xla_computation docstring

    def h(x):
      return np.sin(np.cos(x))
    c = api.xla_computation(h)(2.)
    self.assertIn('cosine', c.GetHloText())
    self.assertIn('sine', c.GetHloText())

    def f(x):
      return x - lax.psum(x, 'i')
    axis_env = [('i', 4)]
    c = api.xla_computation(f, axis_env=axis_env)(2)
    self.assertIn('all-reduce', c.GetHloText())
    self.assertIn('replica_groups={{0,1,2,3}}', c.GetHloText())

    def g(x):
      rowsum = lax.psum(x, 'i')
      colsum = lax.psum(x, 'j')
      allsum = lax.psum(x, ('i', 'j'))
      return rowsum, colsum, allsum
    axis_env = [('i', 4), ('j', 2)]
    c = api.xla_computation(g, axis_env=axis_env)(5.)
    self.assertIn('all-reduce', c.GetHloText())
    self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.GetHloText())
    self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.GetHloText())
    self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.GetHloText())

  def test_xla_computation_args(self):
    def foo(x, y, z):
      return x + y + z

    c = api.xla_computation(foo)(1., 2., 3.)
    self.assertEqual(len(c.GetProgramShape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.GetProgramShape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xb.xla_client.PrimitiveType.TUPLE)

  def test_xla_computation_duck_typing(self):
    def foo(x, y, z):
      return x + y + z

    x = jax.ShapeDtypeStruct((), onp.float32)
    y = jax.ShapeDtypeStruct((), onp.float32)
    z = jax.ShapeDtypeStruct((), onp.float32)

    c = api.xla_computation(foo)(x, y, z)
    self.assertEqual(len(c.GetProgramShape().parameter_shapes()), 3)

    c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    param_shapes = c.GetProgramShape().parameter_shapes()
    self.assertEqual(len(param_shapes), 1)
    self.assertEqual(param_shapes[0].xla_element_type(),
                     xb.xla_client.PrimitiveType.TUPLE)

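  # With tuple_args=True, xla_computation bundles all arguments into a single
  # XLA tuple parameter, which is why the program shape above reports exactly
  # one parameter of tuple type.
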
  def test_staging_out_multi_replica(self):
    def f(x):
      return api.pmap(np.mean)(x)
    xla_comp = api.xla_computation(f)
    xla_comp(np.arange(8)).GetHloText()  # doesn't crash

  def test_xla_computation_instantiate_constant_outputs(self):
    def f():
      return np.zeros((3, 4))

    xla_comp = api.xla_computation(f, instantiate_const_outputs=True)()
    out_shape, = xla_comp.GetProgramShape().result_shape().tuple_shapes()
    self.assertEqual(out_shape.dimensions(), (3, 4))

  def test_xla_computation_static_argnums(self):
    def f(x, y):
      return x + y

    xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3)
    self.assertIn('constant(3)', xla_comp.GetHloText())

  def test_jit_device(self):
    device = xb.devices()[-1]
    x = api.jit(lambda x: x, device=device)(3.)
    self.assertIsInstance(x, xla.DeviceArray)
    self.assertEqual(x.device_buffer.device(), device)

  def test_jit_of_noncallable(self):
    self.assertRaisesRegex(TypeError, "Expected a callable value.*",
                           lambda: api.jit(3))

  def test_jit_of_generator(self):
    def gen(x):
      yield x
    self.assertRaisesRegex(TypeError, "Expected a function, got a generator function.*",
                           lambda: api.jit(gen))

  def test_issue_1062(self):
    # code from https://github.com/google/jax/issues/1062 @shoyer
    # this tests, among other things, whether ShardedDeviceTuple constants work
    device_count = xb.device_count()

    @jit
    def multi_step(state, count):
      return lax.fori_loop(0, count, lambda i, s: s, state)

    @jit
    def multi_step_pmap(state, count=2):
      @partial(api.pmap, axis_name='x')
      def pmapped_multi_step(state):
        return multi_step(state, count)

      return pmapped_multi_step(state)

    u = np.ones((device_count, 100))
    u_final = multi_step_pmap(u)  # doesn't crash

  def test_concurrent_device_get_and_put(self):
    def f(x):
      for _ in range(100):
        y = jax.device_put(x)
        x = jax.device_get(y)
      return x

    xs = [onp.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [f.result() for f in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x, y, check_dtypes=True)

  def test_concurrent_jit(self):
    @jit
    def f(x):
      return x + x - 3.

    xs = [onp.random.randn(i) for i in range(10)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
      futures = [executor.submit(partial(f, x)) for x in xs]
      ys = [f.result() for f in futures]
    for x, y in zip(xs, ys):
      self.assertAllClose(x * 2 - 3., y, check_dtypes=True)

  def test_dtype_warning(self):
    # cf. issue #1230
    if FLAGS.jax_enable_x64:
      return  # test only applies when x64 is disabled

    def check_warning(warn, nowarn):
      with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        nowarn()  # get rid of extra startup warning

        prev_len = len(w)
        nowarn()
        assert len(w) == prev_len

        warn()
        assert len(w) > 0
        msg = str(w[-1].message)
        expected_prefix = "Explicitly requested dtype "
        self.assertEqual(expected_prefix, msg[:len(expected_prefix)])

        prev_len = len(w)
        nowarn()
        assert len(w) == prev_len

    check_warning(lambda: np.array([1, 2, 3], dtype="float64"),
                  lambda: np.array([1, 2, 3], dtype="float32"),)
    check_warning(lambda: np.ones(3, dtype=onp.float64),
                  lambda: np.ones(3))
    check_warning(lambda: np.ones_like(3, dtype=onp.int64),
                  lambda: np.ones_like(3, dtype=onp.int32))
    check_warning(lambda: np.zeros(3, dtype="int64"),
                  lambda: np.zeros(3, dtype="int32"))
    check_warning(lambda: np.zeros_like(3, dtype="float64"),
                  lambda: np.zeros_like(3, dtype="float32"))
    check_warning(lambda: np.full((2, 3), 1, dtype="int64"),
                  lambda: np.full((2, 3), 1))
    check_warning(lambda: np.ones(3).astype("float64"),
                  lambda: np.ones(3).astype("float32"))
    check_warning(lambda: np.eye(3, dtype=onp.float64),
                  lambda: np.eye(3))
    check_warning(lambda: np.arange(3, dtype=onp.float64),
                  lambda: np.arange(3, dtype=onp.float32))
    check_warning(lambda: np.linspace(0, 3, dtype=onp.float64),
                  lambda: np.linspace(0, 3, dtype=onp.float32))
    check_warning(lambda: np.tri(2, dtype="float64"),
                  lambda: np.tri(2, dtype="float32"))

  def test_vmap_in_axes_list(self):
    # https://github.com/google/jax/issues/2367
    dictionary = {'a': 5., 'b': np.ones(2)}
    x = np.zeros(3)
    y = np.arange(3.)

    def f(dct, x, y):
      return dct['a'] + dct['b'] + x + y

    out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
    out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
    self.assertAllClose(out1, out2, check_dtypes=True)

  def test_vmap_in_axes_tree_prefix_error(self):
    # https://github.com/google/jax/issues/795
    self.assertRaisesRegex(
        ValueError,
        "axes specification must be a tree prefix of the corresponding "
        r"value, got specification \(0, 0\) for value "
        r"PyTreeDef\(tuple, \[\*\]\).",
        lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3))
    )

  def test_vmap_unbatched_object_passthrough_issue_183(self):
    # https://github.com/google/jax/issues/183
    fun = lambda f, x: f(x)
    vfun = api.vmap(fun, (None, 0))
    ans = vfun(lambda x: x + 1, np.arange(3))
    self.assertAllClose(ans, onp.arange(1, 4), check_dtypes=False)

  def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
    # https://github.com/google/jax/issues/705
    def h(a, b):
      return np.sum(a) + np.sum(b)

    X = onp.random.randn(10, 4)
    U = onp.random.randn(10, 2)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
        "so\n"
        "arg 0 has an axis to be mapped of size 10\n"
        "arg 1 has an axis to be mapped of size 2"):
      api.vmap(h, in_axes=(0, 1))(X, U)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        r"arg 0 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        r"arg 1 has shape \(10, 2\) and axis 1 is to be mapped" "\n"
        r"arg 2 has shape \(10, 4\) and axis 0 is to be mapped" "\n"
        "so\n"
        "args 0, 2 have axes to be mapped of size 10\n"
        "arg 1 has an axis to be mapped of size 2"):
      api.vmap(lambda x, y, z: None, in_axes=(0, 1, 0))(X, U, X)

    with self.assertRaisesRegex(
        ValueError,
        "vmap got inconsistent sizes for array axes to be mapped:\n"
        "the tree of axis sizes is:\n"
        r"\(10, \[2, 2\]\)"):
      api.vmap(h, in_axes=(0, 1))(X, [U, U])

    with self.assertRaisesRegex(
        ValueError, "vmap got arg 0 of rank 0 but axis to be mapped 0"):
      # The mapped inputs cannot be scalars
      api.vmap(lambda x: x)(1.)

    with self.assertRaisesRegex(
        ValueError, re.escape("vmap got arg 0 of rank 1 but axis to be mapped [1. 2.]")):
      api.vmap(lambda x: x, in_axes=(np.array([1., 2.]),))(np.array([1., 2.]))

    with self.assertRaisesRegex(
        ValueError, "vmap must have at least one non-None in_axes"):
      # If the output is mapped, there must be a non-None in_axes
      api.vmap(lambda x: x, in_axes=None)(np.array([1., 2.]))

    with self.assertRaisesRegex(
        ValueError, "vmap got arg 0 of rank 1 but axis to be mapped 1"):
      api.vmap(lambda x: x, in_axes=1)(np.array([1., 2.]))

    # Error is: TypeError: only integer scalar arrays can be converted to a scalar index
    with self.assertRaisesRegex(
        ValueError, "axes specification must be a tree prefix of the corresponding value"):
      api.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(np.array([1., 2.]))

    with self.assertRaisesRegex(
        ValueError, "vmap has mapped output but out_axes is None"):
      # If the output is mapped, then there must be some out_axes specified
      api.vmap(lambda x: x, out_axes=None)(np.array([1., 2.]))

  def test_vmap_structured_in_axes(self):

    A, B, C, D = 2, 3, 4, 5
    K = 6  # batch size
    x = onp.ones((K, A, B))  # batch axis in different locations
    y = onp.ones((B, K, C))
    z = onp.ones((C, D, K))

    def foo(tree_arg):
      x, (y, z) = tree_arg
      return np.dot(x, np.dot(y, z))

    tree = (x, (y, z))
    vfoo = api.vmap(foo, in_axes=((0, (1, 2)),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    Point = collections.namedtuple("Point", ["x", "y"])
    tree = (x, Point(y, z))
    vfoo = api.vmap(foo, in_axes=((0, Point(1, 2)),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    def foo(tree_arg):
      x, dct = tree_arg
      y, z = dct['a'], dct['b']
      return np.dot(x, np.dot(y, z))

    tree = (x, {'a': y, 'b': z})
    vfoo = api.vmap(foo, in_axes=((0, {'a': 1, 'b': 2}),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

    tree = (x, collections.OrderedDict([('a', y), ('b', z)]))
    vfoo = api.vmap(
        foo, in_axes=((0, collections.OrderedDict([('a', 1), ('b', 2)])),))
    self.assertEqual(vfoo(tree).shape, (6, 2, 5))

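  # The pattern above generalizes to any pytree container JAX knows about:
  # in_axes just has to mirror (a prefix of) the argument structure. A minimal
  # sketch, with hypothetical shapes and the same np/onp aliases as this file:
  #
  #   batched = api.vmap(lambda t: t['w'] + t['b'], in_axes=({'w': 0, 'b': None},))
  #   batched({'w': onp.ones((7, 3)), 'b': onp.ones(3)})  # result has shape (7, 3)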
  def test_jit_reference_dropping(self):
    x = onp.ones(10)
    f = (lambda x: lambda: x)(x)  # reference to x in f's closure
    g = jit(f)
    x = weakref.ref(x)  # no more strong ref to x in this scope
    assert x() is not None  # x is still around
    f()  # f runs
    g()  # g runs
    g()  # g runs a second time
    del f  # delete the raw callable
    assert x() is not None  # x is still around
    g()  # g still runs
    del g  # no more references to x
    assert x() is None  # x is gone

  def test_jit_global_cache(self):
    def f(x):
      assert python_should_be_executing
      return x

    python_should_be_executing = True
    api.jit(f)(2)
    python_should_be_executing = False
    api.jit(f)(3)

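  # Note: the caching check above relies on jit's global compilation cache
  # being keyed on the underlying Python callable (among other things), so
  # wrapping the same f a second time still reuses the compiled executable
  # and the Python body is not re-traced.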
  def test_jit_shallow_copy(self):
    def f(x):
      return copy.copy(x)
    api.jit(f)(1)

  def test_jit_deep_copy(self):
    def f(x):
      return copy.deepcopy(x)
    api.jit(f)(1)

  def test_pmap_global_cache(self):
    def f(x):
      assert python_should_be_executing
      return x

    x = onp.ones(1)

    python_should_be_executing = True
    api.pmap(f)(x)
    python_should_be_executing = False
    api.pmap(f)(x)

    python_should_be_executing = True
    api.pmap(f, 'i')(x)
    python_should_be_executing = False
    api.pmap(f, 'i')(x)

  def test_device_array_repr(self):
    rep = repr(np.ones(()) + 1.)
    self.assertStartsWith(rep, 'DeviceArray')

  def test_grad_without_enough_args_error_message(self):
    # https://github.com/google/jax/issues/1696
    def f(x, y): return x + y
    df = api.grad(f, argnums=0)
    self.assertRaisesRegex(
        TypeError,
        "differentiating with respect to argnums=0 requires at least 1 "
        "positional arguments to be passed by the caller, but got only 0 "
        "positional arguments.",
        lambda: partial(df, x=0.)(y=1.))

  def test_grad_of_jit_compilation_caching(self):
    if not hasattr(self, "assertLogs"):
      raise unittest.SkipTest("test requires assertLogs (python 3)")

    lax.add(1, 2)  # make sure some initial warnings are already printed

    sin = api.jit(np.sin)

    prev_level = logging.get_verbosity()
    try:
      logging.set_verbosity('DEBUG')
      with self.assertLogs(level=logging.DEBUG) as l:
        ans1 = api.grad(sin)(2.)
        ans2 = api.grad(sin)(3.)
    finally:
      logging.set_verbosity(prev_level)
    self.assertLen(l.output, 2)

    self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
    self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)

  def test_remat_basic(self):
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x)), 3.

    def f(x):
      x, _ = g(x)
      return x

    ans = f(2.)
    expected = onp.sin(onp.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans, f_lin = api.linearize(f, 2.)
    expected = onp.sin(onp.sin(2.))
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = f_lin(3.)
    expected = onp.cos(onp.sin(2.)) * onp.cos(2.) * 3.
    self.assertAllClose(ans, expected, check_dtypes=False)

    sin_calls = []
    cos_calls = []
    sin_impl = lax.sin_p.impl
    cos_impl = lax.cos_p.impl
    try:
      lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x))
      lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x))
      f_lin(3.)
    finally:
      lax.sin_p.def_impl(sin_impl)
      lax.cos_p.def_impl(cos_impl)
    self.assertEqual(len(sin_calls), 1)
    self.assertEqual(len(cos_calls), 2)

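  # The call counts above are the point of the remat check: applying the
  # linearized function recomputes sin(x) once (to rebuild the residual needed
  # for cos(sin(x))) and evaluates cos twice, instead of reusing values saved
  # from the forward pass.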
def test_remat_freevars(self):
|
|
|
|
def f1(x):
|
|
|
|
y = 2 * np.sin(x)
|
|
|
|
z = np.cos(x) * np.sin(y)
|
|
|
|
return z
|
|
|
|
|
|
|
|
def f2(x):
|
|
|
|
y = 2 * np.sin(x)
|
|
|
|
z = api.remat(lambda x: np.cos(x) * np.sin(y))(x)
|
|
|
|
return z
|
|
|
|
|
|
|
|
ans, f_lin = api.linearize(f2, 2.)
|
|
|
|
expected, f_lin_expected = api.linearize(f1, 2.)
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-11-27 14:28:13 -08:00
|
|
|
ans = f_lin(3.)
|
|
|
|
expected = f_lin_expected(3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
def test_remat_grad_python_control_flow(self):
|
|
|
|
@partial(api.remat, concrete=True)
|
|
|
|
def g(x):
|
|
|
|
if x > 0:
|
|
|
|
return lax.sin(x), 3.
|
|
|
|
else:
|
|
|
|
return lax.cos(x), 4.
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
x, _ = g(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
ans = f(2.)
|
|
|
|
expected = onp.sin(2.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f)(2.)
|
|
|
|
expected = onp.cos(2.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_jit(self):
|
|
|
|
@api.remat
|
|
|
|
def g(x):
|
|
|
|
return lax.sin(lax.sin(x))
|
|
|
|
|
|
|
|
def f_(x):
|
|
|
|
return g(x)
|
|
|
|
f = api.jit(f_)
|
|
|
|
|
|
|
|
ans = f(2.)
|
|
|
|
expected = onp.sin(onp.sin(2.))
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f)(2.)
|
|
|
|
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jit(api.grad(f_))(2.)
|
|
|
|
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_vmap(self):
|
|
|
|
@api.remat
|
|
|
|
def g(x):
|
|
|
|
return lax.sin(lax.sin(x))
|
|
|
|
|
|
|
|
x = onp.arange(3.)
|
|
|
|
|
|
|
|
ans = api.vmap(g)(x)
|
|
|
|
expected = onp.sin(onp.sin(x))
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jacfwd(g)(x)
|
|
|
|
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jacrev(g)(x)
|
|
|
|
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_higher_order_autodiff(self):
|
|
|
|
def f(x):
|
|
|
|
return lax.cos(lax.sin(x))
|
|
|
|
g = api.remat(f)
|
|
|
|
|
|
|
|
ans = api.grad(api.grad(g))(3.)
|
|
|
|
expected = api.grad(api.grad(f))(3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_scan(self):
|
|
|
|
to_scan = lambda c, x: (np.sin(c), None)
|
|
|
|
|
|
|
|
def f_noremat(x):
|
|
|
|
y, _ = lax.scan(to_scan, x, onp.arange(3.))
|
|
|
|
return y
|
|
|
|
|
|
|
|
def f_yesremat(x):
|
|
|
|
y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
|
|
|
|
return y
|
|
|
|
|
|
|
|
ans = f_yesremat(4.)
|
|
|
|
expected = f_noremat(4.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(f_yesremat)(4.)
|
|
|
|
expected = api.grad(f_noremat)(4.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
2019-11-28 09:00:55 +01:00
|
|
|
scan_eqn, = jaxpr.jaxpr.eqns
|
2019-11-27 15:25:49 -08:00
|
|
|
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
2019-11-22 10:53:11 -08:00
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
2019-11-28 09:00:55 +01:00
|
|
|
scan_eqn, = jaxpr.jaxpr.eqns
|
2019-11-22 10:53:11 -08:00
|
|
|
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
|
|
|
|
|
|
|
def test_remat_no_redundant_flops(self):
|
|
|
|
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def g(x):
|
|
|
|
return f(2., x)
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def f(x, y):
|
|
|
|
return np.sin(x) * y
|
|
|
|
|
|
|
|
# We swap out sin_p's impl rule to count how many times it's invoked
|
|
|
|
called = []
|
|
|
|
sin_impl = lax.sin_p.impl
|
|
|
|
try:
|
|
|
|
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
|
|
|
|
api.grad(g)(3.)
|
|
|
|
finally:
|
|
|
|
lax.sin_p.def_impl(sin_impl)
|
|
|
|
num_calls = len(called)
|
|
|
|
self.assertEqual(num_calls, 1)
|
|
|
|
|
|
|
|
def test_remat_binomial_checkpointing(self):
|
|
|
|
def binom_checkpoint(funs):
|
|
|
|
if len(funs) == 1:
|
|
|
|
return funs[0]
|
|
|
|
else:
|
|
|
|
f1 = binom_checkpoint(funs[:len(funs)//2])
|
|
|
|
f2 = binom_checkpoint(funs[len(funs)//2:])
|
|
|
|
return api.remat(lambda x: f1(f2(x)))
|
|
|
|
|
|
|
|
f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
|
|
|
|
f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x))))
|
|
|
|
x = 4.
|
|
|
|
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
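    # Binomial checkpointing recursively remats each half of the composition,
    # which keeps the memory needed for residuals roughly logarithmic in the
    # number of composed functions at the cost of some recomputation; here we
    # only check that values and gradients match the un-checkpointed version.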
|
|
|
|
|
2019-12-23 11:49:01 -08:00
|
|
|
def test_remat_symbolic_zeros(self):
|
|
|
|
# code from https://github.com/google/jax/issues/1907
|
|
|
|
test_remat = True
|
|
|
|
test_scan = True
|
|
|
|
|
|
|
|
key = jax.random.PRNGKey(0)
|
|
|
|
key, split = jax.random.split(key)
|
|
|
|
n = 5
|
|
|
|
|
|
|
|
def func(D0):
|
|
|
|
def shift(R, dR, **unused_kwargs):
|
|
|
|
return R + dR
|
|
|
|
|
|
|
|
def apply_fn(R):
|
|
|
|
return D0 * R
|
|
|
|
|
|
|
|
Rinit = jax.random.uniform(split, (n,3), minval=0.0, maxval=5.0,
|
|
|
|
dtype=np.float32)
|
|
|
|
|
|
|
|
def move(R,i):
|
|
|
|
F = apply_fn(R)
|
|
|
|
return shift(R, 0.001 * F), np.array([0.])
|
|
|
|
|
|
|
|
move = api.remat(move)
|
|
|
|
R, temp = lax.scan(move, Rinit, np.arange(2))
|
|
|
|
return R[0, 0]
|
|
|
|
|
|
|
|
api.grad(func)(5.0) # doesn't crash
|
|
|
|
|
2020-01-31 23:47:30 -08:00
|
|
|
def test_remat_jit2(self):
|
|
|
|
@api.jit
|
|
|
|
def f(x):
|
|
|
|
y = 2 * x
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def g():
|
|
|
|
return y
|
|
|
|
|
|
|
|
return g()
|
|
|
|
|
|
|
|
self.assertAllClose(f(3), 6, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_remat_nontrivial_env(self):
|
|
|
|
# simplified from https://github.com/google/jax/issues/2030
|
|
|
|
|
|
|
|
@api.remat
|
|
|
|
def foo(state, dt=0.5, c=1):
|
|
|
|
u, u_t = state
|
|
|
|
u_tt = c**2 * u
|
|
|
|
u_t = u_t + u_tt * dt
|
|
|
|
return (u, u_t)
|
|
|
|
|
|
|
|
@partial(api.jit, static_argnums=(1,))
|
|
|
|
def _multi_step(state, count, dt, c):
|
|
|
|
f = lambda s, _: (foo(s, dt, c), _)
|
|
|
|
return lax.scan(f, state, None, count)
|
|
|
|
|
|
|
|
def multi_step(state, count, dt=1/np.sqrt(2), c=1):
|
|
|
|
return _multi_step(state, count, dt, c)
|
|
|
|
|
|
|
|
def loss(u0, target, steps, dt=1/np.sqrt(2), c=1):
|
|
|
|
init = (u0, np.zeros_like(u0))
|
|
|
|
(uf, _), _ = multi_step(init, steps, dt, c)
|
|
|
|
return ((uf - target) ** 2).mean()
|
|
|
|
|
|
|
|
target = np.zeros((128, 128))
|
|
|
|
u0 = np.ones_like(target)
|
|
|
|
loss(u0, target, 10) # doesn't crash
|
|
|
|
|
2020-02-11 15:56:53 -08:00
|
|
|
def test_remat_jit3(self):
|
|
|
|
# https://github.com/google/jax/issues/2180
|
|
|
|
def f(w, x):
|
|
|
|
a = np.dot(x, w)
|
|
|
|
b = np.einsum("btd,bTd->btT", a, a)
|
|
|
|
c = np.einsum("btT,btd->btd", b, a)
|
|
|
|
return np.sum(c)
|
|
|
|
|
|
|
|
w = np.ones([1, 1])
|
|
|
|
x = np.ones([1, 1, 1])
|
|
|
|
f = api.remat(f)
|
|
|
|
api.grad(f)(w, x) # doesn't crash
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def mul(a, b):
|
|
|
|
return a * b
|
|
|
|
|
|
|
|
def f(w, x):
|
|
|
|
a = mul(w, x)
|
|
|
|
b = mul(a, a)
|
|
|
|
return b
|
|
|
|
|
|
|
|
w = 1.
|
|
|
|
x = 1.
|
|
|
|
f = api.remat(f)
|
|
|
|
api.grad(f)(w, x) # doesn't crash
|
|
|
|
|
|
|
|
def test_remat_scan2(self):
|
|
|
|
# https://github.com/google/jax/issues/1963
|
|
|
|
|
|
|
|
def scan_bug(x0):
|
|
|
|
f = lambda x, _: (x + 1, None)
|
|
|
|
def scanned_f(x, _):
|
|
|
|
return lax.scan(f, x, xs=None, length=1)[0], None
|
|
|
|
x, _ = jax.remat(scanned_f)(x0, None)
|
|
|
|
return x
|
|
|
|
|
|
|
|
jax.grad(scan_bug)(1.0) # doesn't crash
|
|
|
|
|
2020-04-24 18:19:24 -07:00
|
|
|
def test_remat_jit_static_argnum(self):
|
|
|
|
# https://github.com/google/jax/issues/2833
|
|
|
|
def f(a_bool, y):
|
|
|
|
if a_bool:
|
|
|
|
return y + 1
|
|
|
|
else:
|
|
|
|
return y
|
|
|
|
|
|
|
|
api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1) # no crash
|
|
|
|
|
2019-12-04 19:34:21 -08:00
|
|
|
def test_trivial_computations(self):
|
|
|
|
x = np.array([1, 2, 3])
|
|
|
|
y = api.jit(lambda x: x)(x)
|
|
|
|
self.assertIs(x, y)
|
|
|
|
|
|
|
|
z1, z2 = api.jit(lambda x: (x, x))(x)
|
|
|
|
self.assertIs(z1, z2)
|
|
|
|
|
|
|
|
x1, x2 = np.array([1, 2]), np.array([2, 3])
|
|
|
|
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
|
|
|
|
self.assertIs(z1, x2)
|
|
|
|
self.assertIs(z3, x1)
|
|
|
|
self.assertEqual(z2, 1)
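    # These identity checks rely on jit recognizing trivial computations, i.e.
    # outputs that simply forward an input or are constants, and returning the
    # original buffers rather than staging and compiling an XLA computation.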
|
|
|
|
|
|
|
|
def test_nested_jit_hoisting(self):
|
|
|
|
@api.jit
|
|
|
|
def f(x, y):
|
|
|
|
z = 2 * x
|
|
|
|
return y + z, 3
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def g(x):
|
|
|
|
return f(2, x)
|
|
|
|
|
|
|
|
jaxpr_subcomp = xla.jaxpr_subcomp
|
|
|
|
|
|
|
|
jaxprs = []
|
|
|
|
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
|
|
|
|
jaxprs.append(jaxpr)
|
|
|
|
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
|
|
|
|
|
|
|
|
try:
|
|
|
|
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
|
|
|
|
ans = g(3)
|
|
|
|
finally:
|
|
|
|
xla.jaxpr_subcomp = jaxpr_subcomp
|
|
|
|
|
|
|
|
self.assertEqual(ans, (7, 3))
|
|
|
|
self.assertLen(jaxprs, 2)
|
|
|
|
outer_jaxpr, inner_jaxpr = jaxprs
|
|
|
|
|
|
|
|
self.assertLen(outer_jaxpr.eqns, 1)
|
|
|
|
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
|
2020-02-05 15:38:25 +01:00
|
|
|
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
|
|
|
|
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
|
|
|
|
self.assertLen(inner_jaxpr.eqns, 2)
|
|
|
|
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
|
|
|
|
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
|
|
|
|
|
2019-12-17 17:49:06 -08:00
|
|
|
def test_primitive_compilation_cache(self):
|
2019-12-19 11:19:58 -08:00
|
|
|
with jtu.count_primitive_compiles() as count:
|
2019-12-17 17:49:06 -08:00
|
|
|
lax.add(1, 2)
|
|
|
|
lax.add(2, 3)
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
|
2020-01-18 22:12:07 -08:00
|
|
|
def test_arange_jit(self):
|
|
|
|
# see https://github.com/google/jax/issues/553
|
|
|
|
def fun(x):
|
|
|
|
r = np.arange(x.shape[0])[x]
|
|
|
|
return r
|
|
|
|
|
|
|
|
jit(fun)(np.array([0, 1, 2], dtype=np.int32)) # doesn't crash
|
|
|
|
|
2020-02-15 06:35:49 +01:00
|
|
|
def helper_save_tracer(self, x):
|
|
|
|
self._saved_tracer = x
|
|
|
|
return x
|
|
|
|
|
|
|
|
def test_escaped_tracers_diffent_top_level_traces(self):
|
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
2020-02-15 06:35:49 +01:00
|
|
|
re.compile(
|
|
|
|
"Encountered an unexpected tracer.*Different traces at same level",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(lambda x: self._saved_tracer)(0.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_cant_lift_sublevels(self):
|
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
2020-02-15 06:35:49 +01:00
|
|
|
re.compile(
|
|
|
|
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(lambda x: x)(self._saved_tracer)
|
|
|
|
|
|
|
|
def test_escaped_tracers_tracer_from_higher_level(self):
|
|
|
|
api.grad(self.helper_save_tracer)(0.)
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
2020-02-15 06:35:49 +01:00
|
|
|
re.compile(
|
|
|
|
"Encountered an unexpected tracer.*Tracer from a higher level",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.grad(lambda x: x)(self._saved_tracer)
|
|
|
|
|
|
|
|
def test_escaped_tracers_incompatible_sublevel(self):
|
|
|
|
def func1(x):
|
|
|
|
api.jit(self.helper_save_tracer)(0.)
|
|
|
|
# Use the tracer
|
|
|
|
return x + self._saved_tracer
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
2020-02-15 06:35:49 +01:00
|
|
|
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(func1)(2.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_cant_lift(self):
|
|
|
|
def func1(x):
|
|
|
|
api.grad(self.helper_save_tracer)(0.)
|
|
|
|
return x + self._saved_tracer
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile("Encountered an unexpected tracer.*Can't lift",
|
|
|
|
re.DOTALL)):
|
2020-02-15 06:35:49 +01:00
|
|
|
api.grad(func1)(2.)
|
|
|
|
|
|
|
|
def test_escaped_tracers_not_among_input_tracers(self):
|
|
|
|
def func1(x):
|
|
|
|
api.grad(self.helper_save_tracer)(x)
|
|
|
|
# Use the tracer
|
|
|
|
return x + self._saved_tracer
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
2020-01-15 15:00:38 -08:00
|
|
|
core.UnexpectedTracerError,
|
|
|
|
re.compile(
|
2020-02-15 06:35:49 +01:00
|
|
|
"Encountered an unexpected tracer.*Tracer not among input tracers",
|
|
|
|
re.DOTALL)):
|
|
|
|
api.jit(func1)(2.)
|
|
|
|
|
2019-12-04 19:34:21 -08:00
|
|
|
|
|
|
|
class JaxprTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def test_scalar_literals(self):
|
|
|
|
jaxpr = api.make_jaxpr(lambda x: x + 2)(42)
|
|
|
|
self.assertLen(jaxpr.jaxpr.constvars, 0)
|
|
|
|
|
|
|
|
def test_const(self):
|
|
|
|
def fun(x):
|
|
|
|
return (x, 1., np.zeros(1))
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(fun)(0.)
|
2020-02-10 11:40:05 +01:00
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda b ; a.
|
|
|
|
let
|
|
|
|
in (a, 1.0, b) }
|
|
|
|
""", str(jaxpr))
|
2019-12-04 19:34:21 -08:00
|
|
|
|
|
|
|
def test_cond(self):
|
|
|
|
def f(x):
|
|
|
|
return lax.cond(x >= 0.,
|
|
|
|
x + 1.,
|
|
|
|
lambda xt: xt + x,
|
|
|
|
x + 2.,
|
|
|
|
lambda xf: xf - x)
|
|
|
|
jaxpr = api.make_jaxpr(f)(3.)
|
2020-02-10 11:40:05 +01:00
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda ; a.
|
|
|
|
let b = ge a 0.0
|
|
|
|
c = add a 1.0
|
|
|
|
d = add a 2.0
|
|
|
|
e = cond[ false_jaxpr={ lambda ; b a.
|
|
|
|
let c = sub a b
|
2020-03-19 11:28:35 -07:00
|
|
|
in (c,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
linear=(False, False, False, False)
|
|
|
|
true_jaxpr={ lambda ; b a.
|
|
|
|
let c = add a b
|
2020-03-19 11:28:35 -07:00
|
|
|
in (c,) } ] b a c a d
|
|
|
|
in (e,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def testExamplesJaxprDoc(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
"""Tests examples included in the Understanding jaxprs doc (docs/jaxpr.rst)."""
|
2020-02-10 11:40:05 +01:00
|
|
|
from jax import numpy as jnp
|
|
|
|
def func1(first, second):
|
|
|
|
temp = first + jnp.sin(second) * 3.
|
|
|
|
return jnp.sum(temp)
|
|
|
|
|
|
|
|
jaxpr = jax.make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda ; a b.
|
|
|
|
let c = sin b
|
|
|
|
d = mul c 3.0
|
|
|
|
e = add a d
|
|
|
|
f = reduce_sum[ axes=(0,) ] e
|
2020-03-19 11:28:35 -07:00
|
|
|
in (f,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func5(first, second):
|
|
|
|
temp = first + np.sin(second) * 3. - jnp.ones(8)
|
|
|
|
return temp
|
|
|
|
|
|
|
|
def func6(first):
|
|
|
|
return func5(first, jnp.ones(8))
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func6)(jnp.ones(8))
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda b d ; a.
|
|
|
|
let c = add a b
|
|
|
|
e = sub c d
|
2020-03-19 11:28:35 -07:00
|
|
|
in (e,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func7(arg):
|
|
|
|
return lax.cond(arg >= 0.,
|
|
|
|
arg,
|
|
|
|
lambda xtrue: xtrue + 3.,
|
|
|
|
arg,
|
|
|
|
lambda xfalse: xfalse - 3.)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func7)(5.)
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda ; a.
|
|
|
|
let b = ge a 0.0
|
|
|
|
c = cond[ false_jaxpr={ lambda ; a.
|
|
|
|
let b = sub a 3.0
|
2020-03-19 11:28:35 -07:00
|
|
|
in (b,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
linear=(False, False)
|
|
|
|
true_jaxpr={ lambda ; a.
|
|
|
|
let b = add a 3.0
|
2020-03-19 11:28:35 -07:00
|
|
|
in (b,) } ] b a a
|
|
|
|
in (c,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func8(arg1, arg2): # arg2 is a pair
|
|
|
|
return lax.cond(arg1 >= 0.,
|
|
|
|
arg2,
|
|
|
|
lambda xtrue: xtrue[0],
|
|
|
|
arg2,
|
|
|
|
lambda xfalse: jnp.ones(1) + xfalse[1])
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func8)(5., (jnp.zeros(1), 2.))
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda e ; a b c.
|
|
|
|
let d = ge a 0.0
|
|
|
|
f = cond[ false_jaxpr={ lambda ; c a b.
|
|
|
|
let d = add c b
|
2020-03-19 11:28:35 -07:00
|
|
|
in (d,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
linear=(False, False, False, False, False)
|
|
|
|
true_jaxpr={ lambda ; a b.
|
|
|
|
let
|
2020-03-19 11:28:35 -07:00
|
|
|
in (a,) } ] d b c e b c
|
|
|
|
in (f,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func10(arg, n):
|
|
|
|
ones = jnp.ones(arg.shape) # A constant
|
|
|
|
return lax.fori_loop(0, n,
|
|
|
|
lambda i, carry: carry + ones * 3. + arg,
|
|
|
|
arg + ones)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func10)(onp.ones(16), 5)
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda c d ; a b.
|
|
|
|
let e = add a d
|
|
|
|
f g h = while[ body_jaxpr={ lambda ; e g a b c.
|
|
|
|
let d = add a 1
|
|
|
|
f = add c e
|
|
|
|
h = add f g
|
|
|
|
in (d, b, h) }
|
|
|
|
body_nconsts=2
|
|
|
|
cond_jaxpr={ lambda ; a b c.
|
|
|
|
let d = lt a b
|
2020-03-19 11:28:35 -07:00
|
|
|
in (d,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
cond_nconsts=0 ] c a 0 b e
|
2020-03-19 11:28:35 -07:00
|
|
|
in (h,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func11(arr, extra):
|
|
|
|
ones = jnp.ones(arr.shape) # A constant
|
|
|
|
|
|
|
|
def body(carry, aelems):
|
|
|
|
# carry: running dot-product of the two arrays
|
|
|
|
# aelems: a pair with corresponding elements from the two arrays
|
|
|
|
ae1, ae2 = aelems
|
|
|
|
return (carry + ae1 * ae2 + extra, carry)
|
|
|
|
|
|
|
|
return lax.scan(body, 0., (arr, ones))
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func11)(onp.ones(16), 5.)
|
2020-04-07 18:21:04 -07:00
|
|
|
# TODO(#2640): update docs/jaxpr.rst to reflect new jaxpr
|
2020-02-10 11:40:05 +01:00
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda c ; a b.
|
|
|
|
let d e = scan[ forward=True
|
2020-04-07 18:21:04 -07:00
|
|
|
jaxpr={ lambda ; f a b c.
|
|
|
|
let d = mul b c
|
|
|
|
e = add a d
|
|
|
|
g = add e f
|
|
|
|
in (g, a) }
|
2020-02-10 11:40:05 +01:00
|
|
|
length=16
|
2020-04-07 18:21:04 -07:00
|
|
|
linear=(False, False, False, False)
|
2020-02-10 11:40:05 +01:00
|
|
|
num_carry=1
|
2020-04-07 18:21:04 -07:00
|
|
|
num_consts=1 ] b 0.0 a c
|
2020-02-10 11:40:05 +01:00
|
|
|
in (d, e) }
|
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func12(arg):
|
|
|
|
@api.jit
|
|
|
|
def inner(x):
|
|
|
|
return x + arg * jnp.ones(1) # Include a constant in the inner function
|
|
|
|
|
|
|
|
return arg + inner(arg - 2.)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func12)(1.)
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda b ; a.
|
|
|
|
let c = sub a 2.0
|
|
|
|
d = xla_call[ backend=None
|
|
|
|
call_jaxpr={ lambda ; c b a.
|
|
|
|
let d = mul b c
|
|
|
|
e = add a d
|
2020-03-19 11:28:35 -07:00
|
|
|
in (e,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
device=None
|
|
|
|
name=inner ] b a c
|
|
|
|
e = add a d
|
2020-03-19 11:28:35 -07:00
|
|
|
in (e,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
|
|
|
|
|
|
|
def func13(arr, extra):
|
|
|
|
def inner(x):
|
|
|
|
# use a free variable "extra" and a constant jnp.ones(1)
|
|
|
|
return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
|
|
|
|
|
|
|
|
return api.pmap(inner, axis_name='rows')(arr)
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(func13)(jnp.ones((1, 3)), 5.)
|
|
|
|
self.assertMultiLineStrippedEqual("""
|
|
|
|
{ lambda c ; a b.
|
|
|
|
let d = xla_pmap[ axis_name=rows
|
|
|
|
axis_size=1
|
|
|
|
backend=None
|
|
|
|
call_jaxpr={ lambda ; d b a.
|
2019-12-04 19:34:21 -08:00
|
|
|
let c = add a b
|
2020-02-10 11:40:05 +01:00
|
|
|
e = add c d
|
|
|
|
f = psum[ axis_name=rows ] a
|
|
|
|
g = div e f
|
2020-03-19 11:28:35 -07:00
|
|
|
in (g,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
devices=None
|
|
|
|
global_axis_size=None
|
|
|
|
mapped_invars=(True, False, True)
|
|
|
|
name=inner ] c b a
|
2020-03-19 11:28:35 -07:00
|
|
|
in (d,) }
|
2020-02-10 11:40:05 +01:00
|
|
|
""", str(jaxpr))
|
2019-11-26 07:56:48 -08:00
|
|
|
|
2020-04-23 18:07:51 -07:00
|
|
|
def test_make_jaxpr_static_argnums(self):
|
|
|
|
def f(x, y):
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
jaxpr = api.make_jaxpr(f, static_argnums=(1,))(2, 3)
|
|
|
|
self.assertIn('3', str(jaxpr))
|
|
|
|
|
|
|
|
|
|
|
|
class LazyTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
def count_compiles(self):
|
|
|
|
|
|
|
|
make_computation_builder = xb.make_computation_builder
|
|
|
|
count = [0]
|
|
|
|
|
|
|
|
def make_computation_builder_and_count(*args, **kwargs):
|
|
|
|
count[0] += 1
|
|
|
|
return make_computation_builder(*args, **kwargs)
|
|
|
|
|
|
|
|
xb.make_computation_builder = make_computation_builder_and_count
|
|
|
|
try:
|
|
|
|
yield count
|
|
|
|
finally:
|
|
|
|
xb.make_computation_builder = make_computation_builder
|
|
|
|
|
|
|
|
@jtu.skip_on_devices("tpu")
|
|
|
|
def test_lazy_jit_closed_over_values(self):
|
|
|
|
if not core.skip_checks:
|
2020-01-18 08:26:23 -05:00
|
|
|
raise unittest.SkipTest("oom test skipped when core.skip_checks is False")
|
|
|
|
|
|
|
|
y = np.arange(int(1e12)) # will likely oom if materialized
|
|
|
|
ans = jit(lambda x: (x + y)[1])(1)
|
|
|
|
self.assertEqual(ans, 2)
|
|
|
|
|
|
|
|
def test_jit_forces_arguments(self):
|
|
|
|
|
|
|
|
@api.jit
|
|
|
|
def f(x):
|
|
|
|
assert python_should_be_executing
|
|
|
|
return np.sum(x)
|
|
|
|
|
|
|
|
x = np.arange(10, dtype=np.int32)
|
|
|
|
assert xla.is_device_constant(x) # lazy iota
|
|
|
|
|
|
|
|
python_should_be_executing = True
|
|
|
|
_ = f(x)
|
|
|
|
|
|
|
|
python_should_be_executing = False # should not recompile
|
|
|
|
x = onp.arange(10, dtype=onp.int32)
|
|
|
|
_ = f(x)
|
|
|
|
|
|
|
|
@parameterized.parameters(jtu.cases_from_list(range(10000)))
|
|
|
|
def test_random_lazy_program(self, seed):
|
|
|
|
|
|
|
|
def random_array(rng):
|
|
|
|
kind = rng.choice(['arr', 'iota', 'eye', 'tri'])
|
|
|
|
if kind == 'arr':
|
|
|
|
dtype = [onp.float32, onp.int32][rng.choice(2)]
|
|
|
|
dim = rng.randint(4)
|
|
|
|
shape = rng.randint(4, size=dim)
|
|
|
|
onp_x = onp.asarray(rng.randn(*shape), dtype=dtype)
|
|
|
|
jax_x = np.array(onp_x, dtype=dtype)
|
|
|
|
elif kind == 'iota':
|
|
|
|
dtype = [onp.float32, onp.int32][rng.choice(2)]
|
|
|
|
size = rng.randint(5)
|
|
|
|
onp_x = onp.arange(size, dtype=dtype)
|
|
|
|
jax_x = lax.iota(dtype, size)
|
|
|
|
elif kind == 'eye':
|
|
|
|
dtype = [onp.float32, onp.int32][rng.choice(2)]
|
|
|
|
N = rng.randint(2, 5)
|
|
|
|
M = None if rng.rand() < 0.5 else rng.randint(2, 5)
|
|
|
|
k = rng.choice([-1, 0, 1])
|
|
|
|
onp_x = onp.eye(N, M, k, dtype=dtype)
|
|
|
|
jax_x = np.eye(N, M, k, dtype=dtype)
|
|
|
|
elif kind == 'tri':
|
|
|
|
dtype = [onp.float32, onp.int32][rng.choice(2)]
|
|
|
|
N = rng.randint(2, 5)
|
|
|
|
M = None if rng.rand() < 0.5 else rng.randint(2, 5)
|
|
|
|
k = rng.choice([-1, 0, 1])
|
|
|
|
onp_x = onp.tri(N, M, k, dtype=dtype)
|
|
|
|
jax_x = np.tri(N, M, k, dtype=dtype)
|
|
|
|
else:
|
|
|
|
assert False
|
|
|
|
assert type(onp_x) is onp.ndarray and type(jax_x) is xla.DeviceArray
|
|
|
|
return onp_x, jax_x
|
|
|
|
|
|
|
|
def random_op(rng, shape):
|
|
|
|
kind = rng.choice(['transpose', 'broadcast', 'reshape'])
|
|
|
|
if kind == 'transpose':
|
|
|
|
perm = tuple(rng.permutation(len(shape)))
|
|
|
|
return Op(partial(onp.transpose, axes=perm),
|
|
|
|
partial(lax.transpose, permutation=perm))
|
|
|
|
elif kind == 'broadcast':
|
|
|
|
n = rng.randint(1, 3)
|
|
|
|
new_sizes = rng.randint(1, 4, size=n)
|
|
|
|
new_ndim = n + len(shape)
|
|
|
|
bcast_dims = tuple(sorted(rng.permutation(new_ndim)[:len(shape)]))
|
|
|
|
shape_iter = iter(shape)
|
|
|
|
new_sizes = iter(rng.randint(1, 4, size=n))
|
|
|
|
new_shape = [next(shape_iter) if i in bcast_dims else next(new_sizes)
|
|
|
|
for i in range(new_ndim)]
|
|
|
|
return Op(partial(lax_reference.broadcast_in_dim, shape=new_shape,
|
|
|
|
broadcast_dimensions=bcast_dims),
|
|
|
|
partial(lax.broadcast_in_dim, shape=new_shape,
|
|
|
|
broadcast_dimensions=bcast_dims))
|
|
|
|
elif kind == 'reshape':
|
|
|
|
new_shape = list(shape)
|
|
|
|
for _ in range(rng.randint(1, 3)):
|
|
|
|
loc = len(new_shape) and rng.randint(len(new_shape))
|
|
|
|
new_shape.insert(loc, 1)
|
|
|
|
new_shape = tuple(new_shape)
|
|
|
|
return Op(partial(onp.reshape, newshape=new_shape),
|
|
|
|
partial(lax.reshape, new_sizes=new_shape))
|
|
|
|
else:
|
|
|
|
assert False
|
|
|
|
Op = collections.namedtuple('Op', ['onp_fn', 'jax_fn'])
|
|
|
|
|
|
|
|
rng = onp.random.RandomState(seed)
|
|
|
|
onp_x, jax_x = _, orig_x = random_array(rng)
|
|
|
|
ops = []
|
|
|
|
with jtu.count_primitive_compiles() as count:
|
|
|
|
for _ in range(rng.randint(5)):
|
|
|
|
op = random_op(rng, onp.shape(onp_x))
|
|
|
|
onp_x = op.onp_fn(onp_x)
|
|
|
|
jax_x = op.jax_fn(jax_x)
|
|
|
|
ops.append(op)
|
|
|
|
self.assertEqual(count[0], 0)
|
|
|
|
|
|
|
|
kind = rng.choice(['closure', 'npy_value', 'force', 'add'])
|
|
|
|
if kind == 'closure':
|
|
|
|
result = api.jit(lambda x: x + jax_x)(0)
|
|
|
|
self.assertAllClose(onp_x, result, check_dtypes=False)
|
|
|
|
elif kind == 'npy_value':
|
|
|
|
self.assertAllClose(onp_x, jax_x, check_dtypes=False)
|
|
|
|
elif kind == 'force':
|
|
|
|
result = xla._force(jax_x)
|
|
|
|
self.assertAllClose(onp_x, result, check_dtypes=False)
|
|
|
|
elif kind == 'add':
|
|
|
|
result = jax_x + onp.zeros(jax_x.shape, dtype=jax_x.dtype)
|
|
|
|
self.assertAllClose(onp_x, result, check_dtypes=False)
|
|
|
|
else:
|
|
|
|
assert False
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def apply_ops(x):
|
|
|
|
for op in ops:
|
|
|
|
x = op.jax_fn(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
jit_result = apply_ops(orig_x)
|
|
|
|
self.assertAllClose(jit_result, onp_x, check_dtypes=False)
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def apply_ops_closure():
|
|
|
|
x = orig_x
|
|
|
|
for op in ops:
|
|
|
|
x = op.jax_fn(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
jit_result = apply_ops_closure()
|
|
|
|
self.assertAllClose(jit_result, onp_x, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_constant_forcing_computations_cached(self):
|
|
|
|
# from https://github.com/google/jax/issues/1909
|
|
|
|
xla._lazy_force_computation.cache_clear() # clear force compile cache
|
|
|
|
big_lazy_x = np.ones((api.device_count(), 100))
|
|
|
|
f = api.pmap(lambda x: 2 * x)
|
|
|
|
_ = f(big_lazy_x)
|
|
|
|
|
|
|
|
with self.count_compiles() as count:
|
|
|
|
_ = f(big_lazy_x)
|
|
|
|
self.assertEqual(count[0], 0)
|
|
|
|
|
|
|
|
def test_zeros_ones_compilation(self):
|
|
|
|
w = np.ones(3) + np.ones(3) # ensure + has a cache entry
|
|
|
|
w.block_until_ready()
|
|
|
|
|
|
|
|
xla._lazy_force_computation.cache_clear() # clear force compile cache
|
|
|
|
|
|
|
|
with self.count_compiles() as count:
|
|
|
|
x = np.ones(3) + np.zeros(3)
|
|
|
|
y = np.ones(3) + np.ones(3)
|
|
|
|
|
|
|
|
self.assertEqual(count[0], 1)
|
|
|
|
self.assertAllClose(x, onp.ones(3), check_dtypes=False)
|
|
|
|
self.assertAllClose(y, onp.ones(3) + onp.ones(3), check_dtypes=False)
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
class CustomJVPTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def test_basic(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return np.sin(x)
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * np.cos(x) * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
x = 3.
|
|
|
|
self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
|
|
|
|
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
|
|
|
(np.sin(x), 2 * np.cos(x)),
|
|
|
|
check_dtypes=True)
|
|
|
|
self.assertAllClose(api.grad(f)(x), 2 * np.cos(x), check_dtypes=True)
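    # The custom rule deliberately returns 2 * cos(x) * g, which is not the
    # true derivative of sin; the assertions confirm that jvp and grad use the
    # user-supplied rule rather than differentiating through np.sin.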
|
|
|
|
|
|
|
|
def test_invariance(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return np.cos(2 * x) / 2.
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return (f(x), 3 * g)
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
def f2(x):
|
|
|
|
y, _ = api.jvp(f, (x,), (x,))
|
|
|
|
return y
|
|
|
|
def f3(x):
|
|
|
|
y, _ = api.jvp(f2, (x,), (x,))
|
|
|
|
return y
|
|
|
|
x = 1.
|
|
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
|
|
api.jvp(f2, (x,), (x,)),
|
|
|
|
check_dtypes=False)
|
|
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
|
|
api.jvp(f3, (x,), (x,)),
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_python_control_flow(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
if x > 0:
|
|
|
|
return np.sin(x)
|
|
|
|
else:
|
|
|
|
return np.cos(x)
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
if x > 0:
|
|
|
|
return f(x), 2 * g
|
|
|
|
else:
|
|
|
|
return f(x), 3 * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
x = 2.
|
|
|
|
self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
|
|
|
|
self.assertAllClose(f(-x), np.cos(-x), check_dtypes=True)
|
|
|
|
self.assertAllClose(api.jvp(f, (x,), (1.,)),
|
|
|
|
(np.sin(x), 2.),
|
|
|
|
check_dtypes=False)
|
|
|
|
self.assertAllClose(api.jvp(f, (-x,), (1.,)),
|
|
|
|
(np.cos(-x), 3.),
|
|
|
|
check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False)
|
|
|
|
self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False)
|
|
|
|
|
|
|
|
def test_vmap(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
assert np.ndim(x) == 0
|
|
|
|
return np.sin(x)
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
assert np.ndim(x) == np.ndim(g) == 0
|
|
|
|
return f(x), 2 * np.cos(x) * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
x = np.arange(3.)
|
|
|
|
xx = np.arange(6.).reshape(2, 3)
|
|
|
|
|
|
|
|
# vmap of f
|
|
|
|
self.assertAllClose(api.vmap(f)(x), np.sin(x), check_dtypes=True)
|
|
|
|
self.assertAllClose(api.vmap(api.vmap(f))(xx), np.sin(xx), check_dtypes=True)
|
|
|
|
|
|
|
|
# vmap of jvp of f
|
|
|
|
self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x),
|
|
|
|
(np.sin(x), 2 * np.cos(x) * x),
|
|
|
|
check_dtypes=True)
|
|
|
|
self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx),
|
|
|
|
(np.sin(xx), 2 * np.cos(xx) * xx),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
|
|
|
# jvp of vmap of f
|
|
|
|
self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)),
|
|
|
|
(np.sin(x), 2 * np.cos(x) * x),
|
|
|
|
check_dtypes=True)
|
|
|
|
self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)),
|
|
|
|
(np.sin(xx), 2 * np.cos(xx) * xx),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
|
|
|
# vmap of jvp of vmap of f
|
|
|
|
self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx),
|
|
|
|
(np.sin(xx), 2 * np.cos(xx) * xx),
|
|
|
|
check_dtypes=True)
|
|
|
|
|
|
|
|
def test_jit(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return np.sin(x)
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * np.cos(x) * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
x = 3.
|
|
|
|
|
|
|
|
# jit
|
|
|
|
self.assertAllClose(api.jit(f)(x), np.sin(x), check_dtypes=True)
|
|
|
|
self.assertAllClose(api.jit(api.jit(f))(x), np.sin(x), check_dtypes=True)
|
|
|
|
|
|
|
|
# jit of jvp
|
|
|
|
self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x),
|
|
|
|
(np.sin(x), 2 * np.cos(x) * x),
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
# jvp of jit
|
|
|
|
self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)),
|
|
|
|
(np.sin(x), 2 * np.cos(x) * x),
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_pytrees(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return {'b': np.sin(x['a'])}
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), {'b': 2 * np.cos(x['a']) * g['a']}
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
x = {'a': 3.}
|
|
|
|
self.assertAllClose(f(x)['b'], np.sin(x['a']), check_dtypes=True)
|
|
|
|
self.assertAllClose(api.jvp(f, (x,), (x,)),
|
|
|
|
({'b': np.sin(x['a'])},
|
|
|
|
{'b': 2 * np.cos(x['a']) * x['a']}),
|
|
|
|
check_dtypes=False)
|
|
|
|
|
|
|
|
def test_kwargs(self):
|
|
|
|
# from https://github.com/google/jax/issues/1938
|
|
|
|
@api.custom_jvp
|
|
|
|
def my_fun(x, y, c=1.):
|
|
|
|
return c * (x + y)
|
|
|
|
def my_jvp(primals, tangents):
|
|
|
|
x, y, c = primals
|
|
|
|
t_x, t_y, t_c = tangents
|
|
|
|
return my_fun(x, y, c), t_c
|
|
|
|
my_fun.defjvp(my_jvp)
|
|
|
|
f = lambda x, y: np.square(my_fun(x, y, c=2.)).sum()
|
|
|
|
f(10., 5.) # doesn't crash
|
|
|
|
api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash
|
|
|
|
|
|
|
|
def test_initial_style(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return 3 * x
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
ans = api.grad(foo)(3.)
|
|
|
|
expected = 2.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(api.grad(foo))(3.)
|
|
|
|
expected = 0.
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_initial_style_vmap(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
assert np.ndim(x) == 0
|
|
|
|
return 3 * x
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
g, = tangents
|
|
|
|
return f(x), 2 * g
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
|
|
|
|
return out
|
|
|
|
|
|
|
|
ans = api.vmap(foo)(np.ones(3))
|
|
|
|
expected = 3. * np.ones(3)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.grad(lambda x: api.vmap(foo)(x).sum())(np.ones(3))
|
|
|
|
expected = 2. * np.ones(3)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_closed_over_tracers_error_message(self):
|
2020-03-28 14:15:46 -07:00
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj)
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def f(x):
|
|
|
|
@api.custom_jvp
|
|
|
|
def g(y):
|
|
|
|
return x + y
|
|
|
|
def g_jvp(primals, tangents):
|
|
|
|
(y,), (t,) = primals, tangents
|
|
|
|
return g(x), 2 * y
|
|
|
|
g.defjvp(g_jvp)
|
|
|
|
return g(1.)
|
|
|
|
|
|
|
|
self.assertRaises(
|
|
|
|
core.UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,)))
|
|
|
|
self.assertRaises(
|
|
|
|
core.UnexpectedTracerError, lambda: api.grad(f)(3.))
|
|
|
|
|
|
|
|
def test_nondiff_arg(self):
|
|
|
|
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
|
|
|
def app(f, x):
|
|
|
|
return f(x)
|
|
|
|
def app_jvp(f, primals, tangents):
|
|
|
|
(x,), (t,) = primals, tangents
|
|
|
|
return app(f, x), 3 * t
|
|
|
|
app.defjvp(app_jvp)
|
|
|
|
|
|
|
|
ans = app(lambda x: 2 * x, 1)
|
|
|
|
expected = 2
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,))
|
|
|
|
expected = (2., 3.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_nondiff_arg_tracer(self):
|
|
|
|
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
|
|
|
def f(x, y):
|
|
|
|
return x * y
|
|
|
|
def f_jvp(x, primals, tangents):
|
|
|
|
(y,), (t_y,) = primals, tangents
|
|
|
|
return f(x, y), 5 * t_y
|
|
|
|
f.defjvp(f_jvp)
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def g(x, y):
|
|
|
|
return f(x, y)
|
|
|
|
|
|
|
|
ans = api.jvp(lambda y: g(2., y), (3.,), (1.,))
|
|
|
|
expected = (6., 5.)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
def test_vmap_axes(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
|
|
|
def test_pmap(self):
|
|
|
|
raise unittest.SkipTest("TODO") # TODO(mattjj): write test
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_missing_jvp_rule_error_message(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
@api.custom_jvp
|
|
|
|
def foo(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: foo(2))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: api.jvp(foo, (2.,), (1.,)))
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
AttributeError,
|
|
|
|
r"No JVP defined for custom_jvp function foo using defjvp.",
|
|
|
|
lambda: api.grad(foo)(2.))
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_jvp_rule_inconsistent_pytree_structures_error_message(self):
|
2020-01-15 15:00:38 -08:00
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return (x**2,)
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return f(x), [2 * x * t, x]
|
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce primal and tangent outputs "
|
|
|
|
"with equal container (pytree) structures, but got "
|
2020-03-24 20:43:33 -07:00
|
|
|
"{} and {} respectively.".format(
|
2020-01-15 15:00:38 -08:00
|
|
|
tree_util.tree_structure((1,)),
|
|
|
|
tree_util.tree_structure([1, 2]))
|
|
|
|
),
|
|
|
|
lambda: api.jvp(f, (2.,), (1.,)))
|
|
|
|
|
2020-03-24 20:43:33 -07:00
|
|
|
def test_primal_tangent_aval_disagreement_error_message(self):
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return f(x), np.reshape(t, (1,))
|
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce primal and tangent outputs "
|
|
|
|
"with equal shapes and dtypes, but got float32[] and float32[1] "
|
|
|
|
"respectively."),
|
|
|
|
lambda: api.jvp(f, (np.float32(2.),), (np.float32(1.),)))
|
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def test_jvp_rule_doesnt_return_pair_error_message(self):
|
|
|
|
# https://github.com/google/jax/issues/2516
|
|
|
|
|
|
|
|
@api.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return x ** 2
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def foo_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return t
|
|
|
|
|
|
|
|
f(2.) # doesn't crash
|
|
|
|
self.assertRaisesRegex(
|
|
|
|
TypeError,
|
|
|
|
re.escape(
|
|
|
|
"Custom JVP rule must produce a pair (list or tuple of length two) "
|
|
|
|
"representing primal and tangent outputs, got 1.0"),
|
|
|
|
lambda: api.jvp(f, (2.,), (1.,)))
|
|
|
|
|
2020-03-28 13:52:40 -07:00
|
|
|
def test_multiple_rule_invocations(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def expit(x):
|
|
|
|
return 1 / (1 + lax.exp(-x))
|
|
|
|
|
|
|
|
@expit.defjvp
|
|
|
|
def _expit_jvp(primals, tangents):
|
|
|
|
(x,), (t,) = primals, tangents
|
|
|
|
ans = expit(x)
|
|
|
|
t_out = t * ans * (1 - ans)
|
|
|
|
return ans, t_out
|
|
|
|
|
|
|
|
def scanned_fun(c, _):
|
|
|
|
return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None
|
|
|
|
|
|
|
|
def foo(x):
|
|
|
|
c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
|
|
|
|
return c[-1]
|
|
|
|
|
|
|
|
# just make sure these don't crash
|
|
|
|
foo(3.)
|
|
|
|
grad(foo)(3.)
|
|
|
|
grad(lambda x: jax.vmap(foo)(x).sum())(np.arange(3.))
|
|
|
|
|
|
|
|
def test_hard_stuff(self):
|
|
|
|
arr = np.ones((5, 2, 2))
|
|
|
|
api.jit(jax.vmap(np.linalg.det))(arr) # doesn't crash
|
|
|
|
|
|
|
|
def test_hard_stuff2(self):
|
|
|
|
@jax.custom_jvp
|
|
|
|
def f(x):
|
|
|
|
return lax.tie_in(x, onp.zeros(x.shape, x.dtype))
|
|
|
|
|
|
|
|
@f.defjvp
|
|
|
|
def f_jvp(primals, tangents):
|
|
|
|
x, = primals
|
|
|
|
t, = tangents
|
|
|
|
return f(x), t
|
|
|
|
|
|
|
|
# don't crash
|
|
|
|
jax.jit(jax.vmap(f))(np.arange(3.))
|
|
|
|
jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
|
|
|
|
jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
|
|
|
|
jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
|
|
|
|
jax.jvp(jax.vmap(f), (np.arange(3.),), (np.ones(3),))
|
|
|
|
|
|
|
|
  def test_hard_stuff3(self):
    @jax.custom_jvp
    def relu(x):
      return np.maximum(x, 0)

    @relu.defjvp
    def _relu_jvp(primals, tangents):
      x, = primals
      t, = tangents
      return relu(x), lax.select(x > 0, t, lax.full_like(t, 0))

    def scanned_fun(c, _):
      return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None

    def f(x):
      c, _ = lax.scan(scanned_fun, [x, 0., 0., 0., 0.], None, length=10)
      return c[-1]

    # don't crash
    jax.jit(jax.vmap(f))(np.arange(3.))
    jax.jit(jax.vmap(jax.grad(f)))(np.arange(3.))
    jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(np.arange(3.))
    jax.grad(lambda x: jax.vmap(f)(x).sum())(np.arange(3.))
    jax.jvp(jax.jit(jax.vmap(f)), (np.arange(3.),), (np.ones(3),))

  def test_eval_shape(self):
    @jax.custom_jvp
    def expit(x):
      return 1 / (1 + lax.exp(-x))

    @expit.defjvp
    def _expit_jvp(primals, tangents):
      (x,), (t,) = primals, tangents
      ans = expit(x)
      t_out = t * ans * (1 - ans)
      return ans, t_out

    # don't crash
    api.eval_shape(expit, np.ones((2, 3)))
    api.eval_shape(api.grad(lambda x: expit(x).sum()), np.ones((2, 3)))

  def test_jaxpr_zeros(self):
    # from https://github.com/google/jax/issues/2657
    @api.custom_jvp
    def f(A, b):
      return A @ b

    def f_jvp(primals, tangents):
      A, b = primals
      dA, db = tangents
      z = f(A, b)
      dz = A @ db + dA @ b
      return z, dz

    f.defjvp(f_jvp)

    def experiment(theta):
      def step(q, _):
        z = f(np.eye(3), np.ones(3) * theta)
        q += z[0]
        return q, q

      q = 0.
      q, _ = lax.scan(step, q, None, 4)
      return q

    grad(experiment)(1.)  # doesn't crash


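# The tests below cover the jax.custom_vjp API, where a forward rule returns
# (output, residuals) and a backward rule maps residuals and an output
# cotangent to a tuple of input cotangents.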
class CustomVJPTest(jtu.JaxTestCase):

  def test_basic(self):
    @api.custom_vjp
    def f(x):
      return np.sin(x)
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.
    self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
    self.assertAllClose(api.grad(f)(x), 2 * np.cos(x), check_dtypes=True)
    self.assertAllClose(api.value_and_grad(f)(x),
                        (np.sin(x), 2 * np.cos(x)),
                        check_dtypes=True)

  def test_invariance(self):
    @api.custom_vjp
    def f(x):
      return np.cos(2 * x) / 2.
    def f_fwd(x):
      return (f(x), x)
    def f_rev(x, g):
      return (g * 3,)
    f.defvjp(f_fwd, f_rev)

    def f2(x):
      y, _ = api.value_and_grad(f)(x)
      return y
    def f3(x):
      y, _ = api.value_and_grad(f2)(x)
      return y

    x = 1.
    self.assertAllClose(f(x), f2(x), check_dtypes=False)
    self.assertAllClose(f(x), f3(x), check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f2)(x),
                        check_dtypes=False)
    self.assertAllClose(api.grad(f)(x), api.grad(f3)(x),
                        check_dtypes=False)

  def test_python_control_flow(self):
    @api.custom_vjp
    def f(x):
      if x > 0:
        return np.sin(x)
      else:
        return np.cos(x)
    def f_fwd(x):
      if x > 0:
        return f(x), x
      else:
        return f(x), x
    def f_rev(x, g):
      if x > 0:
        return (2 * g,)
      else:
        return (3 * g,)
    f.defvjp(f_fwd, f_rev)

    x = 2.
    self.assertAllClose(f(x), np.sin(x), check_dtypes=True)
    self.assertAllClose(f(-x), np.cos(-x), check_dtypes=True)
    self.assertAllClose(api.value_and_grad(f)(x), (np.sin(x), 2.),
                        check_dtypes=False)
    self.assertAllClose(api.value_and_grad(f)(-x), (np.cos(-x), 3.),
                        check_dtypes=False)

  def test_vmap(self):
    @api.custom_vjp
    def f(x):
      assert np.ndim(x) == 0
      return np.sin(x)
    def f_fwd(x):
      assert np.ndim(x) == 0
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = np.arange(3.)
    xx = np.arange(6.).reshape(2, 3)

    # vmap of f
    self.assertAllClose(api.vmap(f)(x), np.sin(x), check_dtypes=True)
    self.assertAllClose(api.vmap(api.vmap(f))(xx), np.sin(xx), check_dtypes=True)

    # vmap of grad of f
    self.assertAllClose(api.vmap(api.grad(f))(x), 2 * np.cos(x),
                        check_dtypes=True)
    self.assertAllClose(api.vmap(api.value_and_grad(f))(x),
                        (np.sin(x), 2 * np.cos(x)),
                        check_dtypes=True)
    self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * np.cos(xx),
                        check_dtypes=True)
    self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx),
                        (np.sin(xx), 2 * np.cos(xx)),
                        check_dtypes=True)

    # grad of vmap of f
    self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x),
                        2 * np.cos(x),
                        check_dtypes=True)
    self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx),
                        2 * np.cos(xx),
                        check_dtypes=True)

    # vmap of grad of vmap of f
    self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx),
                        2 * np.cos(xx),
                        check_dtypes=True)

  def test_jit(self):
    @api.custom_vjp
    def f(x):
      return np.sin(x)
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    x = 3.

    # jit
    self.assertAllClose(api.jit(f)(x), np.sin(x), check_dtypes=True)
    self.assertAllClose(api.jit(api.jit(f))(x), np.sin(x), check_dtypes=True)

    # jit of grad
    self.assertAllClose(api.jit(api.grad(f))(x), 2 * np.cos(x),
                        check_dtypes=False)

    # grad of jit
    self.assertAllClose(api.grad(api.jit(f))(x), 2 * np.cos(x),
                        check_dtypes=False)

  def test_pytrees(self):
    @api.custom_vjp
    def f(x):
      return {'b': np.sin(x['a'])}
    def f_fwd(x):
      return f(x), {'r': np.cos(x['a'])}
    def f_bwd(res, g):
      cos_x = res['r']
      return ({'a': 2 * cos_x * g['b']},)
    f.defvjp(f_fwd, f_bwd)

    x = {'a': 3.}
    self.assertAllClose(f(x)['b'], np.sin(x['a']), check_dtypes=True)
    self.assertAllClose(api.grad(lambda x: f(x)['b'])(x),
                        {'a': 2 * np.cos(x['a'])},
                        check_dtypes=True)

  def test_jvp_error(self):
    @api.custom_vjp
    def f(x):
      return np.sin(x)
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(f, (3.,), (1.,)))
    self.assertRaisesRegex(
        TypeError,
        r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.",
        lambda: api.jvp(api.vmap(f), (np.arange(3.),), (np.ones(3),)))

  def test_kwargs(self):
    # from https://github.com/google/jax/issues/1938
    @api.custom_vjp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
                  lambda _, g: (g, g, g))
    f = lambda x, y: np.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.grad(f)(10., 5.)  # doesn't crash

  def test_initial_style(self):
    @api.custom_vjp
    def f(x):
      return np.sin(x)
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.grad(foo)(3.)
    expected = 2. * np.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(api.grad(foo))(3.)
    expected = -2. * np.sin(3.)
    self.assertAllClose(ans, expected, check_dtypes=True)

  def test_initial_style_vmap(self):
    @api.custom_vjp
    def f(x):
      assert np.ndim(x) == 0
      return 3 * x
    def f_fwd(x):
      return f(x), np.cos(x)
    def f_rev(cos_x, g):
      return (2 * cos_x * g,)
    f.defvjp(f_fwd, f_rev)

    def foo(x):
      out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1)
      return out

    ans = api.vmap(foo)(np.arange(3.))
    expected = 3. * np.arange(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(lambda x: api.vmap(foo)(x).sum())(np.arange(3.))
    expected = 2. * np.cos(np.arange(3.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg(self):
    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def app(f, x):
      return f(x)
    def app_fwd(f, x):
      return app(f, x), np.cos(x)
    def app_rev(f, cos_x, g):
      return (cos_x * g,)
    app.defvjp(app_fwd, app_rev)

    ans = app(lambda x: 2 * x, 1)
    expected = 2
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.)
    expected = (2., np.cos(1.))
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_nondiff_arg_tracer(self):
    @partial(api.custom_vjp, nondiff_argnums=(0,))
    def f(x, y):
      return x * y
    def f_fwd(x, y):
      return f(x, y), np.cos(y)
    def f_rev(x, cos_y, g):
      return (cos_y * g,)
    f.defvjp(f_fwd, f_rev)

    @jit
    def g(x, y):
      return f(x, y)

    ans = g(2, 3.)
    expected = 6.
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(g, 1)(2., 3.)
    expected = np.cos(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_vmap_axes(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_pmap(self):
    raise unittest.SkipTest("TODO")  # TODO(mattjj): write test

  def test_missing_vjp_rule_error(self):
    @api.custom_vjp
    def foo(x):
      return x ** 2

    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: foo(2))
    self.assertRaisesRegex(
        AttributeError,
        r"No VJP defined for custom_vjp function foo using defvjp.",
        lambda: api.grad(foo)(2.))

  def test_vjp_rule_inconsistent_pytree_structures_error(self):
    @api.custom_vjp
    def f(x):
      return x

    def foo_fwd(x):
      return x, None

    def foo_bwd(_, g):
      return g

    f.defvjp(foo_fwd, foo_bwd)

    f(2)  # doesn't crash
    self.assertRaisesRegex(
        TypeError,
        re.escape(
            "Custom VJP rule must produce an output with the same container "
            "(pytree) structure as the args tuple of the primal function, "
            "and in particular must produce a tuple of length equal to the "
            "number of arguments to the primal function, but got VJP output "
            "structure {} for primal input structure {}.".format(
                tree_util.tree_structure(1),
                tree_util.tree_structure((1,)))
        ),
        lambda: api.grad(f)(2.))

  def test_issue2511(self):
    arr = np.ones((5, 2, 2))
    foo = lambda x: api.vmap(np.linalg.det, (0,))(x)
    api.jit(foo)(arr)  # doesn't crash

  def test_lowering_out_of_traces(self):
    # https://github.com/google/jax/issues/2578

    class F(collections.namedtuple("F", ["a"])):
      def __call__(self, x):
        return jax.nn.relu(self.a) * x

    @jax.jit
    def g(f, x):
      return f(x)

    jax.grad(g, argnums=(1,))(F(2.0), 0.)  # doesn't crash

  def test_nondiff_argnums_stop_gradient(self):
    # https://github.com/google/jax/issues/2784
    @partial(api.custom_vjp, nondiff_argnums=(0, 1))
    def _clip_gradient(lo, hi, x):
      return x  # identity function

    def clip_gradient_fwd(lo, hi, x):
      # return x, None
      return x, (hi,)

    def clip_gradient_bwd(lo, hi, _, g):
      return (np.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -1
      hi = x + 1  # causes things to break
      return _clip_gradient(lo, hi, x)

    jax.grad(clip_gradient)(1.)  # doesn't crash


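# The tests below exercise the older, deprecated custom_transforms-based APIs
# (defvjp_all, defvjp, custom_gradient).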
class DeprecatedCustomTransformsTest(jtu.JaxTestCase):

  def test_defvjp_all(self):
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)

  def test_defvjp_all_const(self):
    foo_p = Primitive('foo')
    def foo(x): return foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 9., check_dtypes=False)
    self.assertAllClose(grad_ans, 12., check_dtypes=True)

  def test_defvjp_all_higher_order_revmode(self):
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
    ans = api.grad(api.grad(foo))(3.)
    self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)

  def test_defvjp_all_multiple_arguments(self):
    # also tests passing in symbolic zero tangents b/c we differentiate wrt only
    # the first argument in one case

    foo_p = Primitive('foo')
    def foo(x, y): return foo_p.bind(x, y)

    def vjpfun(x, y):
      out = x**2 + y**3
      vjp = lambda g: (g + x + y, g * x * 9.)
      return out, vjp

    ad.defvjp_all(foo_p, vjpfun)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    ans = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)

  def test_defvjp_all_custom_transforms(self):
    @api.custom_transforms
    def foo(x):
      return np.sin(x)

    api.defvjp_all(foo, lambda x: (np.sin(x), lambda g: (g * x,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, onp.sin(3.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3., check_dtypes=False)

  # TODO(mattjj): add defvjp_all test with pytree arguments

  def test_defvjp(self):
    @api.custom_transforms
    def foo(x, y):
      return np.sin(x * y)

    api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 0., check_dtypes=False)

    ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans_0, 0., check_dtypes=False)
    self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)

  def test_defvjp_higher_order(self):
    @api.custom_transforms
    def foo(x):
      return np.sin(2. * x)

    api.defvjp(foo, lambda g, _, x: g * np.cos(x))
    ans = api.grad(api.grad(foo))(2.)
    expected = api.grad(api.grad(np.sin))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_defvjp_use_ans(self):
    @api.custom_transforms
    def foo(x, y):
      return np.sin(x * y)

    api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + np.cos(ans))
    val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
    self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
                        check_dtypes=False)

  # TODO
  # def test_defjvp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y

  #     api.defjvp(bar, lambda y_dot, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
  #       "Detected differentiation with respect to closed-over values with "
  #       "custom JVP rule, which isn't supported.")

  # TODO
  # def test_defvjp_closure_error(self):
  #   def foo(x):
  #     @api.custom_transforms
  #     def bar(y):
  #       return x * y

  #     api.defvjp(bar, lambda g, ans, y: x * y)
  #     return bar(x)
  #   jtu.check_raises(
  #       lambda: grad(foo)(1.,), ValueError,
  #       "Detected differentiation w.r.t. variables from outside "
  #       "the scope of <jax.custom_transforms function bar>, but defvjp and "
  #       "defvjp_all only support differentiation w.r.t. positional arguments.")

  def test_custom_transforms_eval_with_pytrees(self):
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = f((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})

  def test_custom_transforms_jit_with_pytrees(self):
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = jit(f)((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})

  def test_custom_transforms_jit_with_pytrees_consts(self):
    # The purpose of this test is to exercise the custom_transforms default
    # translation rule in how it deals with constants that are too large to be
    # treated as literals (at the time of writing).
    z = onp.arange(10.)

    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': z * b}

    ans = jit(f)((1, 2))
    self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)

  def test_custom_transforms_jvp_with_pytrees(self):
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
    self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})

  def test_custom_transforms_vmap_with_pytrees(self):
    raise unittest.SkipTest("Test deprecated custom_transforms")
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = api.vmap(f)((onp.arange(3), onp.ones((3, 2))))
    expected = {'hi': 2 * onp.arange(3), 'bye': 2 * onp.ones((3, 2))}
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_transforms_jvp_with_closure(self):
    def f(x):
      @api.custom_transforms
      def g(y):
        return x * y
      return g(x)

    ans = api.grad(f)(1.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_custom_gradient(self):
    @api.custom_gradient
    def f(x):
      return x ** 2, lambda g: (g * x,)

    self.assertAllClose(f(3.), 9., check_dtypes=False)
    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)

  def test_custom_vjp_zeros(self):
    @api.custom_transforms
    def f(x, y):
      return 2 * x, 3 * y

    def f_vjp(x, y):
      return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])

    api.defvjp_all(f, f_vjp)
    api.grad(lambda x, y: f(x, y)[0])(1., 2.)  # doesn't crash

  def test_custom_transforms_vjp_nones(self):
    # issue raised by jsnoek@ and jumper@
    @jax.custom_transforms
    def solve(a, b):
      return np.dot(np.linalg.inv(a), b)
    # print(solve(a, b))

    def solve_vjp(a, b):
      x = solve(a, b)
      def vjp(x_tangent):
        dx = np.dot(solve(a, x_tangent), x.T)
        out = (dx, b * 0.)
        return out
      return x, vjp
    jax.defvjp_all(solve, solve_vjp)
    gf = grad(lambda a, b: np.sum(solve(a, b)))

    n = 3
    a_in = np.linspace(0, 1, n)[:, None]
    a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
    real_x = onp.random.RandomState(0).randn(n)
    b = np.dot(a + np.eye(a.shape[0]), real_x)
    print(gf(a, b))  # doesn't crash


if __name__ == '__main__':
  absltest.main()