mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add error checks so that #603 isn't silent fail
This commit is contained in:
parent
1e6b033cd3
commit
18671fa027
43
jax/api.py
43
jax/api.py
@ -271,7 +271,7 @@ def value_and_grad(fun, argnums=0, has_aux=False):
|
||||
ans, vjp_py = vjp(f_partial, *dyn_args)
|
||||
else:
|
||||
ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
|
||||
_check_scalar(ans)
|
||||
_check_scalar_real(ans)
|
||||
g = vjp_py(onp.ones((), onp.result_type(ans)))
|
||||
g = g[0] if isinstance(argnums, int) else g
|
||||
if not has_aux:
|
||||
@ -309,6 +309,7 @@ def jacfwd(fun, argnums=0):
|
||||
f_partial, dyn_args = _argnums_partial(f, argnums, args)
|
||||
pushfwd = partial(jvp, f_partial, dyn_args)
|
||||
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
|
||||
_check_real(y)
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
|
||||
|
||||
@ -338,6 +339,7 @@ def jacrev(fun, argnums=0):
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = _argnums_partial(f, argnums, args)
|
||||
_check_real(dyn_args)
|
||||
y, pullback = vjp(f_partial, *dyn_args)
|
||||
jac = vmap(pullback)(_std_basis(y))
|
||||
jac = jac[0] if isinstance(argnums, int) else jac
|
||||
@ -372,7 +374,13 @@ def _std_basis(pytree):
|
||||
leaves, _ = tree_flatten(pytree)
|
||||
ndim = sum(map(onp.size, leaves))
|
||||
# TODO(mattjj): use a symbolic identity matrix here
|
||||
return _unravel_array_into_pytree(pytree, 1, onp.eye(ndim))
|
||||
dtype = onp.result_type(*leaves)
|
||||
if not onp.issubdtype(dtype, onp.floating):
|
||||
msg = ("Jacobian only defined for functions with floating input and output "
|
||||
"dtypes (i.e. dtypes that model real numbers), got {}.")
|
||||
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
|
||||
flat_basis = onp.eye(ndim, dtype=dtype)
|
||||
return _unravel_array_into_pytree(pytree, 1, flat_basis)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, arr):
|
||||
leaves, treedef = tree_flatten(pytree)
|
||||
@ -779,14 +787,28 @@ def _check_args(args):
|
||||
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
|
||||
.format(arg, type(arg)))
|
||||
|
||||
def _check_scalar(x):
|
||||
def _check_scalar_real(x):
|
||||
msg = "Gradient only defined for scalar-output functions. Output was: {}".format
|
||||
try:
|
||||
aval = core.get_aval(x)
|
||||
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
|
||||
raise TypeError(msg(x))
|
||||
except TypeError:
|
||||
raise TypeError(msg(x))
|
||||
else:
|
||||
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
|
||||
raise TypeError(msg(x))
|
||||
if not onp.issubdtype(aval.dtype, onp.floating):
|
||||
msg2 = ("Gradient only defined for functions with output dtypes that are "
|
||||
"sub-dtypes of `np.floating` (i.e. that model real scalars), but "
|
||||
"got {}. For holomorphic differentiation, apply `np.real` at the "
|
||||
"end of the function.")
|
||||
raise TypeError(msg2.format(aval.dtype.name))
|
||||
|
||||
def _check_real(x):
|
||||
aval = core.get_aval(x)
|
||||
if not onp.issubdtype(aval.dtype, onp.floating):
|
||||
msg = ("Jacobian only defined for functions with floating input and output "
|
||||
"dtypes (i.e. dtypes that model real numbers), got {}.")
|
||||
raise TypeError(msg.format(aval.dtype.name))
|
||||
|
||||
|
||||
def custom_transforms(fun):
|
||||
@ -814,9 +836,14 @@ def _elementwise_std_basis(pytree):
|
||||
arity = len(leaves)
|
||||
dims = map(onp.size, leaves)
|
||||
# TODO(mattjj): use symbolic constants
|
||||
basis_array = onp.stack(
|
||||
[onp.concatenate([onp.ones(dims[j]) if i == j else onp.zeros(dims[j])
|
||||
for j in range(arity)]) for i in range(arity)])
|
||||
dtype = onp.result_type(*leaves)
|
||||
if not onp.issubdtype(dtype, onp.floating):
|
||||
msg = ("Jacobian only defined for functions with floating input and output "
|
||||
"dtypes (i.e. dtypes that model real numbers), got {}.")
|
||||
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
|
||||
basis_array = onp.stack([onp.concatenate(
|
||||
[onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
|
||||
for j in range(arity)]) for i in range(arity)])
|
||||
return _unravel_array_into_pytree(pytree, 1, basis_array)
|
||||
|
||||
def jarrett(fun):
|
||||
|
@ -427,6 +427,48 @@ class APITest(jtu.JaxTestCase):
|
||||
jaxpr2 = api.make_jaxpr(f2_vjp)(y)
|
||||
assert len(jaxpr2.constvars) == 2
|
||||
|
||||
def test_complex_grad_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: grad(lambda x: np.sin(x))(1 + 2j))
|
||||
grad(lambda x: np.real(np.sin(x)))(1 + 2j) # doesn't crash
|
||||
|
||||
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
|
||||
# def test_complex_jacfwd(self):
|
||||
# # code based on https://github.com/google/jax/issues/603
|
||||
# zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
|
||||
# def f(z):
|
||||
# return np.cos(np.linalg.norm(2 * z))
|
||||
|
||||
# ans = jacfwd(f)(zs)
|
||||
# expected = grad(f)(zs)
|
||||
# self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_complex_jacfwd_raises_error(self):
|
||||
# code based on https://github.com/google/jax/issues/603
|
||||
zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
def f(z):
|
||||
return np.cos(np.linalg.norm(2 * z))
|
||||
self.assertRaises(TypeError, lambda: jacfwd(f)(zs))
|
||||
|
||||
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
|
||||
# def test_complex_jacrev(self):
|
||||
# # code based on https://github.com/google/jax/issues/603
|
||||
# zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
|
||||
# def f(z):
|
||||
# return np.cos(np.linalg.norm(2 * z))
|
||||
|
||||
# ans = jacrev(f)(zs)
|
||||
# expected = grad(f)(zs)
|
||||
# self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_complex_jacrev_raises_error(self):
|
||||
# code based on https://github.com/google/jax/issues/603
|
||||
zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
def f(z):
|
||||
return np.cos(np.linalg.norm(2 * z))
|
||||
self.assertRaises(TypeError, lambda: jacrev(f)(zs))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user