add error checks so that #603 isn't silent fail

This commit is contained in:
Matthew Johnson 2019-04-12 12:01:19 -07:00
parent 1e6b033cd3
commit 18671fa027
2 changed files with 77 additions and 8 deletions

View File

@ -271,7 +271,7 @@ def value_and_grad(fun, argnums=0, has_aux=False):
ans, vjp_py = vjp(f_partial, *dyn_args)
else:
ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
_check_scalar(ans)
_check_scalar_real(ans)
g = vjp_py(onp.ones((), onp.result_type(ans)))
g = g[0] if isinstance(argnums, int) else g
if not has_aux:
@ -309,6 +309,7 @@ def jacfwd(fun, argnums=0):
f_partial, dyn_args = _argnums_partial(f, argnums, args)
pushfwd = partial(jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
_check_real(y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
@ -338,6 +339,7 @@ def jacrev(fun, argnums=0):
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = _argnums_partial(f, argnums, args)
_check_real(dyn_args)
y, pullback = vjp(f_partial, *dyn_args)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
@ -372,7 +374,13 @@ def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(map(onp.size, leaves))
# TODO(mattjj): use a symbolic identity matrix here
return _unravel_array_into_pytree(pytree, 1, onp.eye(ndim))
dtype = onp.result_type(*leaves)
if not onp.issubdtype(dtype, onp.floating):
msg = ("Jacobian only defined for functions with floating input and output "
"dtypes (i.e. dtypes that model real numbers), got {}.")
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
flat_basis = onp.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, flat_basis)
def _unravel_array_into_pytree(pytree, axis, arr):
leaves, treedef = tree_flatten(pytree)
@ -779,14 +787,28 @@ def _check_args(args):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))
def _check_scalar(x):
def _check_scalar_real(x):
msg = "Gradient only defined for scalar-output functions. Output was: {}".format
try:
aval = core.get_aval(x)
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
raise TypeError(msg(x))
except TypeError:
raise TypeError(msg(x))
else:
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
raise TypeError(msg(x))
if not onp.issubdtype(aval.dtype, onp.floating):
msg2 = ("Gradient only defined for functions with output dtypes that are "
"sub-dtypes of `np.floating` (i.e. that model real scalars), but "
"got {}. For holomorphic differentiation, apply `np.real` at the "
"end of the function.")
raise TypeError(msg2.format(aval.dtype.name))
def _check_real(x):
aval = core.get_aval(x)
if not onp.issubdtype(aval.dtype, onp.floating):
msg = ("Jacobian only defined for functions with floating input and output "
"dtypes (i.e. dtypes that model real numbers), got {}.")
raise TypeError(msg.format(aval.dtype.name))
def custom_transforms(fun):
@ -814,9 +836,14 @@ def _elementwise_std_basis(pytree):
arity = len(leaves)
dims = map(onp.size, leaves)
# TODO(mattjj): use symbolic constants
basis_array = onp.stack(
[onp.concatenate([onp.ones(dims[j]) if i == j else onp.zeros(dims[j])
for j in range(arity)]) for i in range(arity)])
dtype = onp.result_type(*leaves)
if not onp.issubdtype(dtype, onp.floating):
msg = ("Jacobian only defined for functions with floating input and output "
"dtypes (i.e. dtypes that model real numbers), got {}.")
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
basis_array = onp.stack([onp.concatenate(
[onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
for j in range(arity)]) for i in range(arity)])
return _unravel_array_into_pytree(pytree, 1, basis_array)
def jarrett(fun):

View File

@ -427,6 +427,48 @@ class APITest(jtu.JaxTestCase):
jaxpr2 = api.make_jaxpr(f2_vjp)(y)
assert len(jaxpr2.constvars) == 2
def test_complex_grad_raises_error(self):
self.assertRaises(TypeError, lambda: grad(lambda x: np.sin(x))(1 + 2j))
grad(lambda x: np.real(np.sin(x)))(1 + 2j) # doesn't crash
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
# def test_complex_jacfwd(self):
# # code based on https://github.com/google/jax/issues/603
# zs = 0.5j * onp.arange(5) + onp.arange(5)
# def f(z):
# return np.cos(np.linalg.norm(2 * z))
# ans = jacfwd(f)(zs)
# expected = grad(f)(zs)
# self.assertAllClose(ans, expected, check_dtypes=True)
def test_complex_jacfwd_raises_error(self):
# code based on https://github.com/google/jax/issues/603
zs = 0.5j * onp.arange(5) + onp.arange(5)
def f(z):
return np.cos(np.linalg.norm(2 * z))
self.assertRaises(TypeError, lambda: jacfwd(f)(zs))
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
# def test_complex_jacrev(self):
# # code based on https://github.com/google/jax/issues/603
# zs = 0.5j * onp.arange(5) + onp.arange(5)
# def f(z):
# return np.cos(np.linalg.norm(2 * z))
# ans = jacrev(f)(zs)
# expected = grad(f)(zs)
# self.assertAllClose(ans, expected, check_dtypes=True)
def test_complex_jacrev_raises_error(self):
# code based on https://github.com/google/jax/issues/603
zs = 0.5j * onp.arange(5) + onp.arange(5)
def f(z):
return np.cos(np.linalg.norm(2 * z))
self.assertRaises(TypeError, lambda: jacrev(f)(zs))
if __name__ == '__main__':
absltest.main()