rocm_jax/tests/api_test.py
2019-01-22 15:34:09 -08:00

327 lines
10 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import numpy as onp
from absl.testing import absltest
from jax import test_util as jtu
import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
from jax import api
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
from jax.interpreters.ad import defjvp
from jax.interpreters.xla import DeviceArray
from jax.abstract_arrays import concretization_err_msg
from jax.config import config
config.parse_flags_with_absl()
class APITest(jtu.JaxTestCase):
def test_grad_argnums(self):
def f(x, y, z, flag=False):
assert flag
return 1.0 * x + 2.0 * y + 3.0 * z
assert grad(f)(1.0, 1.0, 1.0, flag=True) == 1.0
assert grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == 2.0
assert grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (3.0, 1.0)
def test_value_and_grad_argnums(self):
def f(x, y, z, flag=False):
assert flag
return 1.0 * x + 2.0 * y + 3.0 * z
y = f(1.0, 1.0, 1.0, flag=True)
assert api.value_and_grad(f)(1.0, 1.0, 1.0, flag=True) == (y, 1.0)
assert api.value_and_grad(f, argnums=1)(1.0, 1.0, 1.0, flag=True) == (y, 2.0)
assert api.value_and_grad(f, argnums=(2, 0))(1.0, 1.0, 1.0, flag=True) == (y, (3.0, 1.0))
def test_jit_static_args(self):
side = []
def f(x, y, z, flag=False, flag2=False):
assert flag
side.append(None)
return 100*x + 10*y + z
f1 = jit(f)
assert f1(1, 2, 3, flag=True) == 123
assert len(side) == 1
assert f1(2, 1, 3, flag=True) == 213
assert len(side) == 1
assert f1(2, 1, 3, flag=True, flag2=True) == 213
assert len(side) == 2
side[:] = []
f2 = jit(f, static_argnums=[0,2])
assert f2(1, 2, 3, flag=True) == 123
assert len(side) == 1
assert f2(1, 3, 3, flag=True) == 133
assert len(side) == 1
assert f2(2, 2, 3, flag=True) == 223
assert len(side) == 2
assert f2(2, 4, 3, flag=True) == 243
assert len(side) == 2
assert f2(2, 4, 3, flag=True, flag2=True) == 243
assert len(side) == 3
assert f2(2, 5, 3, flag=True, flag2=True) == 253
assert len(side) == 3
def test_grad_of_jit(self):
side = []
@jit
def f(x):
side.append(None)
return x * x
assert grad(f)(1.0) == 2.0
assert len(side) == 1
assert grad(f)(2.0) == 4.0
assert len(side) == 1
def test_jit_of_grad(self):
side = []
@jit
def f(x):
side.append(None)
return x * x
g = jit(grad(f))
assert g(1.0) == 2.0
assert len(side) == 1
assert g(2.0) == 4.0
assert len(side) == 1
def test_bad_input(self):
def f(x):
return x
jtu.check_raises_regexp(lambda: grad(f)("foo"), TypeError,
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
jtu.check_raises_regexp(lambda: jit(f)("foo"), TypeError,
"Argument 'foo' of type <.*'str'> is not a valid JAX type")
# TODO(dougalm): enable when we remove 'None' from pytree nodes
# def test_bad_output(self):
# def f(x):
# pass
# grad(f)(onp.zeros(3))
# jit(f)(onp.zeros(3))
# assert False
def test_grad_tuple_output(self):
jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError,
"Gradient only defined for scalar-output functions. ")
def test_grad_unit_output(self):
jtu.check_raises(lambda: grad(lambda x: ())(onp.zeros(3)), TypeError,
"Gradient only defined for scalar-output functions. ")
def test_grad_nonscalar_output(self):
jtu.check_raises(lambda: grad(lambda x: x)(onp.zeros(3)), TypeError,
"Gradient only defined for scalar-output functions. ")
def test_unwrapped_numpy(self):
def f(x):
return onp.exp(x)
jtu.check_raises(lambda: grad(f)(onp.zeros(3)), Exception,
"Tracer can't be used with raw numpy functions. "
"You might have\n import numpy as np\ninstead of\n"
" import jax.numpy as np")
def test_binop_mismatch(self):
def f(x, y):
return x + y
jtu.check_raises(lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
ValueError,
"Incompatible shapes for broadcasting: ((3,), (4,))")
def test_dot_mismatch(self):
def f(x, y):
return np.dot(x, y)
jtu.check_raises(lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
TypeError,
"Incompatible shapes for dot: got (3,) and (4,).")
def test_switch_value_jit(self):
def f(x):
y = x > 0
if y:
return x
else:
return -x
assert grad(f)(1.0) == 1.0
assert grad(f)(-1.0) == -1.0
jtu.check_raises(lambda: jit(f)(1), TypeError, concretization_err_msg(bool))
def test_range_err(self):
def f(x, n):
for i in range(n):
x = x + i
return x
assert jit(f, static_argnums=(1,))(0, 5) == 10
jtu.check_raises_regexp(
lambda: jit(f)(0, 5), TypeError,
"('JaxprTracer' object cannot be interpreted as an integer"
"|Abstract value passed to .*)")
def test_casts(self):
for castfun in [float, complex, hex, oct] + list(six.integer_types):
f = lambda x: castfun(x)
jtu.check_raises_regexp(
lambda: jit(f)(0), TypeError,
"('JaxprTracer' object cannot be interpreted as an integer"
"|Abstract value passed to .*)")
def test_unimplemented_interpreter_rules(self):
foo_p = Primitive('foo')
def foo(x):
return foo_p.bind(x)
jtu.check_raises(lambda: foo(1.0), NotImplementedError,
"Evaluation rule for 'foo' not implemented")
jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
"Abstract evaluation for 'foo' not implemented")
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
"Forward-mode differentiation rule for 'foo' not implemented")
def_abstract_eval(foo_p, lambda x: x)
jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
"XLA translation rule for 'foo' not implemented")
foo_p.def_impl(lambda x: x)
defjvp(foo_p, lambda g, x: foo(g))
jtu.check_raises(lambda: grad(foo)(1.0), NotImplementedError,
"Reverse-mode differentiation rule for 'foo' not implemented")
def test_device_put_and_get(self):
x = onp.arange(12.).reshape((3, 4)).astype("float32")
dx = device_put(x)
assert isinstance(dx, DeviceArray)
x2 = device_get(dx)
assert isinstance(x2, onp.ndarray)
assert onp.all(x == x2)
y = [x, (2 * x, 3 * x)]
dy = device_put(y)
y2 = device_get(dy)
assert isinstance(y2, list)
assert isinstance(y2[0], onp.ndarray)
assert onp.all(y2[0] == x)
assert isinstance(y2[1], tuple)
assert isinstance(y2[1][0], onp.ndarray)
assert onp.all(y2[1][0] == 2 * x)
assert isinstance(y2[1][1], onp.ndarray)
assert onp.all(y2[1][1] == 3 * x)
@jtu.skip_on_devices("tpu")
def test_jacobian(self):
R = onp.random.RandomState(0).randn
A = R(4, 3)
x = R(3)
f = lambda x: np.dot(A, x)
assert onp.allclose(jacfwd(f)(x), A)
assert onp.allclose(jacrev(f)(x), A)
f = lambda x: np.tanh(np.dot(A, x))
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))
@jtu.skip_on_devices("tpu")
def test_hessian(self):
R = onp.random.RandomState(0).randn
A = R(4, 4)
x = R(4)
f = lambda x: np.dot(x, np.dot(A, x))
assert onp.allclose(hessian(f)(x), A + A.T)
def test_std_basis(self):
basis = api._std_basis(np.zeros(3))
assert getattr(basis, "shape", None) == (3, 3)
assert onp.allclose(basis, onp.eye(3))
basis = api._std_basis(np.zeros((3, 3)))
assert getattr(basis, "shape", None) == (9, 3, 3)
assert onp.allclose(basis, onp.eye(9).reshape(9, 3, 3))
basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
assert isinstance(basis, list) and len(basis) == 2
assert getattr(basis[0], "shape", None) == (16,)
assert isinstance(basis[1], tuple) and len(basis[1]) == 2
assert getattr(basis[1][0], "shape", None) == (16, 3)
assert getattr(basis[1][1], "shape", None) == (16, 3, 4)
@jtu.skip_on_devices("tpu")
def test_jacobian_on_pytrees(self):
for jacfun in [jacfwd, jacrev]:
ans = jacfun(lambda x, y: (x, y))(0., 1.)
expected = (1., 0.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = jacfun(lambda x, y: (x, y), 1)(0., 1.)
expected = (0., 1.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = jacfun(lambda x, y: (x, y), (0, 1))(0., 1.)
expected = ((1., 0.),
(0., 1.),)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = jacfun(lambda x: x[:2])((1., 2., 3.))
expected = ((1., 0., 0.),
(0., 1., 0.))
self.assertAllClose(ans, expected, check_dtypes=False)
R = onp.random.RandomState(0).randn
x = R(2)
y = R(3)
ans = jacfun(lambda x, y: {'x': x, 'xy': np.outer(x, y)})(x, y)
expected = {'x': onp.eye(2),
'xy': onp.kron(onp.eye(2), y[:, None]).reshape(2, 3, 2)}
self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def test_hessian_on_pytrees(self):
ans = hessian(lambda x: np.array(x)**2)((1., 2.))
expected = ((onp.array([2., 0.]), onp.array([0., 0.])),
(onp.array([0., 0.]), onp.array([0., 2.])))
self.assertAllClose(ans, expected, check_dtypes=False)
if __name__ == '__main__':
absltest.main()