make jacrev work w/ complex inputs, update errors (#610)

* make jacrev work w/ complex inputs, update errors

* fix up complex handling in jacfwd and jacrev
This commit is contained in:
Matthew Johnson 2019-04-13 13:22:45 -07:00 committed by GitHub
parent d7f623ca9d
commit d7096a42c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 73 deletions

View File

@ -187,7 +187,7 @@ def xla_computation(fun, static_argnums=()):
return computation_maker
def grad(fun, argnums=0, has_aux=False):
def grad(fun, argnums=0, has_aux=False, holomorphic=False):
"""Creates a function which evaluates the gradient of `fun`.
Args:
@ -198,8 +198,10 @@ def grad(fun, argnums=0, has_aux=False):
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
holomorphic. Default False.
Returns:
A function with the same arguments as `fun`, that evaluates the gradient of
@ -216,7 +218,8 @@ def grad(fun, argnums=0, has_aux=False):
array(0.961043, dtype=float32)
"""
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux)
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
holomorphic=holomorphic)
docstr = ("Gradient of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
@ -234,7 +237,7 @@ def grad(fun, argnums=0, has_aux=False):
return grad_f
def value_and_grad(fun, argnums=0, has_aux=False):
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
"""Creates a function which evaluates both `fun` and the gradient of `fun`.
Args:
@ -247,6 +250,8 @@ def value_and_grad(fun, argnums=0, has_aux=False):
has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
holomorphic. Default False.
Returns:
A function with the same arguments as `fun` that evaluates both `fun` and
@ -271,8 +276,14 @@ def value_and_grad(fun, argnums=0, has_aux=False):
ans, vjp_py = vjp(f_partial, *dyn_args)
else:
ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
_check_scalar_real(ans)
g = vjp_py(onp.ones((), onp.result_type(ans)))
_check_scalar(ans)
dtype = onp.result_type(ans)
if not (holomorphic or onp.issubdtype(dtype, onp.floating)):
msg = ("Gradient only defined for real-output functions (with dtype that "
"is a subdtype of np.floating), but got dtype {}. For holomorphic "
"differentiation, pass holomorphic=True.")
raise TypeError(msg.format(dtype))
g = vjp_py(onp.ones((), dtype=dtype))
g = g[0] if isinstance(argnums, int) else g
if not has_aux:
return ans, g
@ -281,14 +292,26 @@ def value_and_grad(fun, argnums=0, has_aux=False):
return value_and_grad_f
def _check_scalar(x):
  """Raise TypeError unless `x` has a scalar (shape `()`) ShapedArray aval.

  A TypeError raised by `core.get_aval` is re-raised with the same
  scalar-output message, so callers see one uniform error either way.
  """
  describe = "Gradient only defined for scalar-output functions. Output was: {}".format
  try:
    aval = core.get_aval(x)
  except TypeError:
    raise TypeError(describe(x))
  # Reject anything that is not a ShapedArray, or whose shape is not empty.
  if not isinstance(aval, ShapedArray) or aval.shape != ():
    raise TypeError(describe(x))
def jacfwd(fun, argnums=0):
def jacfwd(fun, argnums=0, holomorphic=False):
"""Jacobian of `fun` evaluated column-by-column using forward-mode AD.
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default `0`).
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
holomorphic. Default False.
Returns:
A function with the same arguments as `fun`, that evaluates the Jacobian of
@ -307,21 +330,32 @@ def jacfwd(fun, argnums=0):
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = _argnums_partial(f, argnums, args)
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
pushfwd = partial(jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
tree_map(_check_real, y)
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
return jacfun
def jacrev(fun, argnums=0):
def _check_real_input_jacfwd(x):
  """Raise TypeError if the aval dtype of `x` is not a sub-dtype of floating.

  Applied by `jacfwd` to its differentiated inputs when `holomorphic` is
  False.
  """
  dtype = core.get_aval(x).dtype
  if onp.issubdtype(dtype, onp.floating):
    return
  template = ("jacfwd only defined for functions with input dtypes that are "
              "sub-dtypes of `np.floating` (i.e. that model real values), but got "
              "{}. For holomorphic differentiation, pass holomorphic=True.")
  raise TypeError(template.format(dtype.name))
def jacrev(fun, argnums=0, holomorphic=False):
"""Jacobian of `fun` evaluated row-by-row using reverse-mode AD.
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default `0`).
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
holomorphic. Default False.
Returns:
A function with the same arguments as `fun`, that evaluates the Jacobian of
@ -339,8 +373,8 @@ def jacrev(fun, argnums=0):
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = _argnums_partial(f, argnums, args)
tree_map(_check_real, dyn_args)
y, pullback = vjp(f_partial, *dyn_args)
holomorphic or tree_map(_check_real_output_jacrev, y)
jac = vmap(pullback)(_std_basis(y))
jac = jac[0] if isinstance(argnums, int) else jac
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
@ -350,13 +384,24 @@ def jacrev(fun, argnums=0):
return jacfun
jacobian = jacrev
def hessian(fun, argnums=0):
def _check_real_output_jacrev(x):
  """Raise TypeError if the aval dtype of `x` is not a sub-dtype of floating.

  Applied by `jacrev` to the function's outputs when `holomorphic` is False.
  """
  dtype = core.get_aval(x).dtype
  if onp.issubdtype(dtype, onp.floating):
    return
  template = ("jacrev only defined for functions with output dtypes that are "
              "sub-dtypes of `np.floating` (i.e. that model real values), but got "
              "{}. For holomorphic differentiation, pass holomorphic=True.")
  raise TypeError(template.format(dtype.name))
def hessian(fun, argnums=0, holomorphic=False):
"""Hessian of `fun`.
Args:
fun: Function whose Hessian is to be computed.
argnums: Optional, integer or tuple of integers. Specifies which positional
argument(s) to differentiate with respect to (default `0`).
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
holomorphic. Default False.
Returns:
A function with the same arguments as `fun`, that evaluates the Hessian of
@ -367,29 +412,22 @@ def hessian(fun, argnums=0):
array([[ 6., -2.],
[ -2., -480.]], dtype=float32)
"""
return jacfwd(jacrev(fun, argnums=argnums), argnums=argnums)
return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)
def _std_basis(pytree):
leaves, _ = tree_flatten(pytree)
ndim = sum(map(onp.size, leaves))
# TODO(mattjj): use a symbolic identity matrix here
dtype = onp.result_type(*leaves)
if not onp.issubdtype(dtype, onp.floating):
msg = ("Jacobian only defined for functions with floating input and output "
"dtypes (i.e. dtypes that model real numbers), got {}.")
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
flat_basis = onp.eye(ndim, dtype=dtype)
return _unravel_array_into_pytree(pytree, 1, flat_basis)
def _unravel_array_into_pytree(pytree, axis, arr):
leaves, treedef = tree_flatten(pytree)
axis = axis % arr.ndim
dtypes = map(_dtype, leaves)
shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
reshaped_parts = [onp.reshape(part.astype(dtype), shape)
for part, dtype, shape in zip(parts, dtypes, shapes)]
reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)]
return tree_unflatten(treedef, reshaped_parts)
def _split(x, indices, axis):
@ -787,29 +825,6 @@ def _check_args(args):
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
.format(arg, type(arg)))
def _check_scalar_real(x):
  """Raise TypeError unless `x` is a scalar whose dtype is a floating sub-dtype.

  Two checks run in order: the aval of `x` must be a ShapedArray with shape
  `()`, and its dtype must be a sub-dtype of `onp.floating`. A TypeError from
  `core.get_aval` is reported with the scalar-output message.
  """
  describe = "Gradient only defined for scalar-output functions. Output was: {}".format
  try:
    aval = core.get_aval(x)
  except TypeError:
    raise TypeError(describe(x))

  is_scalar = isinstance(aval, ShapedArray) and aval.shape == ()
  if not is_scalar:
    raise TypeError(describe(x))

  if onp.issubdtype(aval.dtype, onp.floating):
    return
  dtype_msg = ("Gradient only defined for functions with output dtypes that are "
               "sub-dtypes of `np.floating` (i.e. that model real scalars), but "
               "got {}. For holomorphic differentiation, apply `np.real` at the "
               "end of the function.")
  raise TypeError(dtype_msg.format(aval.dtype.name))
def _check_real(x):
  """Raise TypeError if the aval dtype of `x` is not a sub-dtype of floating."""
  dtype = core.get_aval(x).dtype
  if onp.issubdtype(dtype, onp.floating):
    return
  template = ("Jacobian only defined for functions with floating input and output "
              "dtypes (i.e. dtypes that model real numbers), got {}.")
  raise TypeError(template.format(dtype.name))
def custom_transforms(fun):
name = getattr(fun, '__name__', '<unnamed user primitive>')

View File

@ -429,45 +429,42 @@ class APITest(jtu.JaxTestCase):
def test_complex_grad_raises_error(self):
  # A complex-valued output without holomorphic=True must raise TypeError.
  with self.assertRaises(TypeError):
    grad(lambda x: np.sin(x))(1 + 2j)
  # Taking the real part first makes the output real-valued.
  grad(lambda x: np.real(np.sin(x)))(1 + 2j)  # doesn't crash
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
# def test_complex_jacfwd(self):
# # code based on https://github.com/google/jax/issues/603
# zs = 0.5j * onp.arange(5) + onp.arange(5)
def test_holomorphic_grad(self):
  # With holomorphic=True, grad accepts a complex input and complex output.
  holomorphic_grad = grad(lambda x: np.sin(x), holomorphic=True)
  out = holomorphic_grad(1 + 2j)
  expected = 2.0327230070196656 - 3.0518977991518j
  self.assertAllClose(out, expected, check_dtypes=False)
# def f(z):
# return np.cos(np.linalg.norm(2 * z))
def test_nonholomorphic_grad(self):
zs = 0.5j * onp.arange(5) + onp.arange(5)
# ans = jacfwd(f)(zs)
# expected = grad(f)(zs)
# self.assertAllClose(ans, expected, check_dtypes=True)
def f(z):
return np.sum(np.cos(np.abs(z)))
def test_complex_jacfwd_raises_error(self):
ans = grad(f)(zs)
expected = onp.array([ 0. +0.j,
-0.80430663+0.40215331j,
-0.70368982+0.35184491j,
0.1886467 -0.09432335j,
0.86873727-0.43436864j])
self.assertAllClose(ans, expected, check_dtypes=False)
def test_complex_output_jacrev_raises_error(self):
  # jacrev without holomorphic=True must reject a complex-valued output.
  with self.assertRaises(TypeError):
    jacrev(lambda x: np.sin(x))(1 + 2j)
def test_nonholomorphic_jacrev(self):
# code based on https://github.com/google/jax/issues/603
zs = 0.5j * onp.arange(5) + onp.arange(5)
def f(z):
return np.cos(np.linalg.norm(2 * z))
self.assertRaises(TypeError, lambda: jacfwd(f)(zs))
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
# def test_complex_jacrev(self):
# # code based on https://github.com/google/jax/issues/603
# zs = 0.5j * onp.arange(5) + onp.arange(5)
ans = jacrev(f)(zs)
expected = grad(f)(zs)
self.assertAllClose(ans, expected, check_dtypes=True)
# def f(z):
# return np.cos(np.linalg.norm(2 * z))
# ans = jacrev(f)(zs)
# expected = grad(f)(zs)
# self.assertAllClose(ans, expected, check_dtypes=True)
def test_complex_jacrev_raises_error(self):
# code based on https://github.com/google/jax/issues/603
zs = 0.5j * onp.arange(5) + onp.arange(5)
def f(z):
return np.cos(np.linalg.norm(2 * z))
self.assertRaises(TypeError, lambda: jacrev(f)(zs))
def test_complex_input_jacfwd_raises_error(self):
  # jacfwd without holomorphic=True must reject a complex-valued input.
  with self.assertRaises(TypeError):
    jacfwd(lambda x: np.sin(x))(1 + 2j)
if __name__ == '__main__':