improve errors for complex derivs, fixes #3121 (#3149)

Matthew Johnson 2020-05-19 15:17:03 -07:00 committed by GitHub
parent 8fe26190de
commit 850f1afd95
3 changed files with 118 additions and 35 deletions


@@ -150,7 +150,7 @@ def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
_check_args(args_flat)
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
name=flat_fun.__name__)
@@ -370,7 +370,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. Default False.
holomorphic. If True, inputs and outputs must be complex. Default False.
Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
@@ -424,7 +424,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. Default False.
holomorphic. If True, inputs and outputs must be complex. Default False.
Returns:
A function with the same arguments as ``fun`` that evaluates both ``fun``
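The clarified docstring makes the ``holomorphic`` promise concrete: with ``holomorphic=True`` both the inputs and the output of ``fun`` must have a complex dtype, which the new checks further down in this diff enforce. A minimal usage sketch, assuming only the public jax.grad and jax.numpy APIs:

import jax
import jax.numpy as jnp

def f(z):
  return jnp.sum(z ** 2)  # holomorphic in z; complex output for complex input

z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
dz = jax.grad(f, holomorphic=True)(z)  # complex input and output: passes the new checks

# Passing a float input (or returning a real output) with holomorphic=True
# raises a TypeError; see the dtype checks added below.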
@@ -454,17 +454,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
tree_map(partial(_check_input_dtype_grad, holomorphic), dyn_args)
if not has_aux:
ans, vjp_py = _vjp(f_partial, *dyn_args)
else:
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
_check_scalar(ans)
dtype = dtypes.result_type(ans)
if not (holomorphic or dtypes.issubdtype(dtype, onp.floating)):
msg = ("Gradient only defined for real-output functions (with dtype that "
"is a subdtype of np.floating), but got dtype {}. For holomorphic "
"differentiation, pass holomorphic=True.")
raise TypeError(msg.format(dtype))
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
g = vjp_py(onp.ones((), dtype=dtype))
g = g[0] if isinstance(argnums, int) else g
if not has_aux:
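A short sketch of the error path this hunk introduces, mirroring test_grad_complex_result_errors added below: a real-input function with a complex output is rejected unless holomorphic=True is passed.

import jax

df = jax.grad(lambda x: x ** 2 + 1j)
try:
  df(3.0)
except TypeError as e:
  # "grad requires real-valued outputs (output dtype that is a sub-dtype of
  #  np.floating), but got complex... For holomorphic differentiation, pass
  #  holomorphic=True. ..."
  print(e)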
@@ -487,6 +484,38 @@ def _check_scalar(x):
else:
raise TypeError(msg("had abstract value {}".format(aval)))
def _check_input_dtype_revderiv(name, holomorphic, x):
_check_arg(x)
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
msg = (f"{name} with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not (dtypes.issubdtype(aval.dtype, onp.floating) or
dtypes.issubdtype(aval.dtype, onp.complexfloating)):
msg = (f"{name} requires real- or complex-valued inputs (input dtype that "
"is a sub-dtype of np.floating or np.complexfloating), "
f"but got {aval.dtype.name}. ")
raise TypeError(msg)
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
msg = (f"{name} with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not dtypes.issubdtype(aval.dtype, onp.floating):
msg = (f"{name} requires real-valued outputs (output dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
"For differentiation of non-holomorphic functions involving complex "
"outputs, use jax.vjp directly.")
raise TypeError(msg)
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
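A sketch of how the input-side checks above surface to users, matching the updated and new tests at the bottom of this diff: integer inputs, and float inputs under holomorphic=True, both raise a TypeError.

import jax

try:
  jax.grad(lambda x: x ** 2)(3)  # integer input
except TypeError as e:
  print(e)  # "grad requires real- or complex-valued inputs ... but got int..."

try:
  jax.grad(lambda x: x ** 2, holomorphic=True)(3.0)  # float input
except TypeError as e:
  print(e)  # "grad with holomorphic=True requires inputs with complex dtype ..."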
def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
holomorphic: bool = False) -> Callable:
@@ -521,21 +550,39 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
pushfwd = partial(_jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, batching.last))(_std_basis(dyn_args))
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
return jacfun
def _check_real_input_jacfwd(x):
def _check_input_dtype_jacfwd(holomorphic, x):
_check_arg(x)
aval = core.get_aval(x)
if not dtypes.issubdtype(aval.dtype, onp.floating):
msg = ("jacfwd only defined for functions with input dtypes that are "
"sub-dtypes of `np.floating` (i.e. that model real values), but "
"got {}. For holomorphic differentiation, pass holomorphic=True.")
raise TypeError(msg.format(aval.dtype.name))
if holomorphic:
if not (dtypes.issubdtype(aval.dtype, onp.complexfloating) and
not dtypes.issubdtype(aval.dtype, onp.floating)):
msg = ("jacfwd with holomorphic=True requires inputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
elif not dtypes.issubdtype(aval.dtype, onp.floating):
msg = ("jacfwd requires real-valued inputs (input dtype that is "
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
"For holomorphic differentiation, pass holomorphic=True. "
"For differentiation of non-holomorphic functions involving complex "
"inputs, use jax.jvp directly.")
raise TypeError(msg)
def _check_output_dtype_jacfwd(holomorphic, x):
aval = core.get_aval(x)
if holomorphic:
if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
msg = ("jacfwd with holomorphic=True requires outputs with complex dtype, "
f"but got {aval.dtype.name}.")
raise TypeError(msg)
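The corresponding jacfwd behavior, as exercised by the new tests below: complex inputs are rejected without holomorphic=True, and holomorphic=True in turn requires complex inputs. A small sketch:

import jax
import jax.numpy as jnp

try:
  jax.jacfwd(lambda x: x ** 2)(3.0 + 1.0j)  # complex input, no holomorphic flag
except TypeError as e:
  print(e)  # "jacfwd requires real-valued inputs ... use jax.jvp directly."

try:
  jax.jacfwd(lambda x: x ** 2, holomorphic=True)(3.0)  # float input
except TypeError as e:
  print(e)  # "jacfwd with holomorphic=True requires inputs with complex dtype ..."

# Complex inputs and outputs with holomorphic=True pass both dtype checks.
jac = jax.jacfwd(lambda z: jnp.sin(z), holomorphic=True)(jnp.array([1.0 + 1.0j]))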
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
@@ -571,8 +618,9 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
tree_map(partial(_check_input_dtype_jacrev, holomorphic), dyn_args)
y, pullback = _vjp(f_partial, *dyn_args)
holomorphic or tree_map(_check_real_output_jacrev, y)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
@@ -582,13 +630,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
return jacfun
jacobian = jacrev
def _check_real_output_jacrev(x):
aval = core.get_aval(x)
if not dtypes.issubdtype(aval.dtype, onp.floating):
msg = ("jacrev only defined for functions with output dtypes that are "
"sub-dtypes of `np.floating` (i.e. that model real values), but "
"got {}. For holomorphic differentiation, pass holomorphic=True.")
raise TypeError(msg.format(aval.dtype.name))
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
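jacrev now reuses the reverse-mode dtype checks shared with grad, so its messages follow the same pattern; a sketch matching test_holomorphic_jacrev_of_float_errors below:

import jax

try:
  jax.jacrev(lambda x: x ** 2, holomorphic=True)(3.0)  # float input
except TypeError as e:
  print(e)  # "jacrev with holomorphic=True requires inputs with complex dtype ..."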
def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
@@ -1070,7 +1113,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
assert all(axis in (0, None) for axis in in_axes_flat), \
"pmap currently only supports mapping over the leading axis"
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
_check_args(args)
for arg in args: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
out = pxla.xla_pmap(
flat_fun,
@@ -1114,7 +1157,7 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
"soft_pmap currently only supports mapping over the leading axis"
mapped_invars = tuple(axis is not None for axis in in_axes_flat)
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
_check_args(args_flat)
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
@@ -1489,7 +1532,7 @@ def _vjp(fun: lu.WrappedFun, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
assert not kwargs
primals_flat, in_tree = tree_flatten(primals)
_check_args(primals_flat)
for arg in primals_flat: _check_arg(arg)
tree_map(_check_inexact_input_vjp, primals)
if not has_aux:
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
@@ -1618,11 +1661,10 @@ def device_get(x):
return tree_map(_device_get, x)
def _check_args(args):
for arg in args:
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))
def _check_arg(arg):
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))
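With the per-argument helper, the invalid-jaxtype error can be raised (and reused by the dtype checks above) one argument at a time. A small sketch of the message, assuming a jitted identity function called with a plain Python string:

import jax

@jax.jit
def identity(x):
  return x

try:
  identity("not an array")
except TypeError as e:
  print(e)  # "Argument 'not an array' of type <class 'str'> is not a valid JAX type"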
def _valid_jaxtype(arg):
try:


@@ -193,14 +193,14 @@ def id_tap(func: Callable, arg, *,
if func not in (_end_consumer, _unknown_testing_consumer):
api._check_callable(func)
flat_args, arg_treedef = pytree.flatten(arg)
api._check_args(flat_args)
for arg in flat_args: api._check_arg(arg)
params = dict(kwargs) # we pass a copy of params to the primitive
# See definition of id_tap_p for what parameters it takes
params["func"] = func
params["arg_treedef"] = arg_treedef
if result is not None:
flat_results, result_treedef = pytree.flatten(result)
api._check_args(flat_results)
for result in flat_results: api._check_arg(result)
all_args = flat_args + flat_results
params["nr_untapped"] = len(flat_results)
else:


@@ -801,8 +801,49 @@ class APITest(jtu.JaxTestCase):
dfn = grad(lambda x: x ** 2)
self.assertRaisesRegex(
TypeError,
"Primal inputs to reverse-mode differentiation must be of float or "
"complex type, got type int..", lambda: dfn(3))
(r"grad requires real- or complex-valued inputs \(input dtype that is a "
r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
lambda: dfn(3))
def test_grad_complex_result_errors(self):
dfn = grad(lambda x: x ** 2 + 1j)
self.assertRaisesRegex(
TypeError,
(r"grad requires real-valued outputs \(output dtype that is a "
r"sub-dtype of np.floating\), but got complex.*"),
lambda: dfn(3.))
def test_holomorphic_grad_of_float_errors(self):
dfn = grad(lambda x: x ** 2, holomorphic=True)
self.assertRaisesRegex(
TypeError,
(r"grad with holomorphic=True requires inputs with complex dtype, "
r"but got float.*"),
lambda: dfn(3.))
def test_holomorphic_jacrev_of_float_errors(self):
dfn = jacrev(lambda x: x ** 2, holomorphic=True)
self.assertRaisesRegex(
TypeError,
(r"jacrev with holomorphic=True requires inputs with complex dtype, "
r"but got float.*"),
lambda: dfn(3.))
def test_holomorphic_jacfwd_of_float_errors(self):
dfn = jacfwd(lambda x: x ** 2, holomorphic=True)
self.assertRaisesRegex(
TypeError,
(r"jacfwd with holomorphic=True requires inputs with complex dtype, "
r"but got float.*"),
lambda: dfn(3.))
def test_jacfwd_of_complex_errors(self):
dfn = jacfwd(lambda x: x ** 2)
self.assertRaisesRegex(
TypeError,
(r"jacfwd requires real-valued inputs \(input dtype that is a "
r"sub-dtype of np.floating\), but got complex.*"),
lambda: dfn(3. + 1j))
def test_xla_computation(self):
# these tests basically check the examples in the xla_computation docstring