mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
make jacrev work w/ complex inputs, update errors (#610)
* make jacrev work w/ complex inputs, update errors
* fix up complex handling in jacfwd and jacrev
This commit is contained in:
parent
d7f623ca9d
commit
d7096a42c5
103
jax/api.py
103
jax/api.py
@ -187,7 +187,7 @@ def xla_computation(fun, static_argnums=()):
|
||||
|
||||
return computation_maker
|
||||
|
||||
def grad(fun, argnums=0, has_aux=False):
|
||||
def grad(fun, argnums=0, has_aux=False, holomorphic=False):
|
||||
"""Creates a function which evaluates the gradient of `fun`.
|
||||
|
||||
Args:
|
||||
@ -198,8 +198,10 @@ def grad(fun, argnums=0, has_aux=False):
|
||||
argnums: Optional, integer or tuple of integers. Specifies which positional
|
||||
argument(s) to differentiate with respect to (default 0).
|
||||
has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
|
||||
holomorphic. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as `fun`, that evaluates the gradient of
|
||||
@ -216,7 +218,8 @@ def grad(fun, argnums=0, has_aux=False):
|
||||
array(0.961043, dtype=float32)
|
||||
|
||||
"""
|
||||
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux)
|
||||
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
|
||||
holomorphic=holomorphic)
|
||||
|
||||
docstr = ("Gradient of {fun} with respect to positional argument(s) "
|
||||
"{argnums}. Takes the same arguments as {fun} but returns the "
|
||||
@ -234,7 +237,7 @@ def grad(fun, argnums=0, has_aux=False):
|
||||
|
||||
return grad_f
|
||||
|
||||
def value_and_grad(fun, argnums=0, has_aux=False):
|
||||
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
|
||||
"""Creates a function which evaluates both `fun` and the gradient of `fun`.
|
||||
|
||||
Args:
|
||||
@ -247,6 +250,8 @@ def value_and_grad(fun, argnums=0, has_aux=False):
|
||||
has_aux: Optional, bool. Indicates whether `fun` returns a pair where the
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
|
||||
holomorphic. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as `fun` that evaluates both `fun` and
|
||||
@ -271,8 +276,14 @@ def value_and_grad(fun, argnums=0, has_aux=False):
|
||||
ans, vjp_py = vjp(f_partial, *dyn_args)
|
||||
else:
|
||||
ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
|
||||
_check_scalar_real(ans)
|
||||
g = vjp_py(onp.ones((), onp.result_type(ans)))
|
||||
_check_scalar(ans)
|
||||
dtype = onp.result_type(ans)
|
||||
if not (holomorphic or onp.issubdtype(dtype, onp.floating)):
|
||||
msg = ("Gradient only defined for real-output functions (with dtype that "
|
||||
"is a subdtype of np.floating), but got dtype {}. For holomorphic "
|
||||
"differentiation, pass holomorphic=True.")
|
||||
raise TypeError(msg.format(dtype))
|
||||
g = vjp_py(onp.ones((), dtype=dtype))
|
||||
g = g[0] if isinstance(argnums, int) else g
|
||||
if not has_aux:
|
||||
return ans, g
|
||||
@ -281,14 +292,26 @@ def value_and_grad(fun, argnums=0, has_aux=False):
|
||||
|
||||
return value_and_grad_f
|
||||
|
||||
def _check_scalar(x):
  """Validate that `x` is a scalar output, as required for taking a gradient.

  The check passes only when `core.get_aval(x)` succeeds and yields a
  `ShapedArray` whose shape is `()`. The dtype of `x` is not inspected here.

  Args:
    x: The value returned by the function being differentiated.

  Raises:
    TypeError: If `core.get_aval(x)` raises `TypeError`, or if the resulting
      abstract value is not a `ShapedArray` with shape `()`.
  """
  def _error_text(value):
    return ("Gradient only defined for scalar-output functions. "
            "Output was: {}").format(value)

  try:
    abstract_value = core.get_aval(x)
  except TypeError:
    # A value with no abstract value is reported with the same message as a
    # non-scalar output.
    raise TypeError(_error_text(x))

  if isinstance(abstract_value, ShapedArray) and abstract_value.shape == ():
    return
  raise TypeError(_error_text(x))
|
||||
|
||||
def jacfwd(fun, argnums=0):
|
||||
|
||||
def jacfwd(fun, argnums=0, holomorphic=False):
|
||||
"""Jacobian of `fun` evaluated column-by-column using forward-mode AD.
|
||||
|
||||
Args:
|
||||
fun: Function whose Jacobian is to be computed.
|
||||
argnums: Optional, integer or tuple of integers. Specifies which positional
|
||||
argument(s) to differentiate with respect to (default `0`).
|
||||
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
|
||||
holomorphic. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as `fun`, that evaluates the Jacobian of
|
||||
@ -307,21 +330,32 @@ def jacfwd(fun, argnums=0):
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = _argnums_partial(f, argnums, args)
|
||||
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
|
||||
pushfwd = partial(jvp, f_partial, dyn_args)
|
||||
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
|
||||
tree_map(_check_real, y)
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
|
||||
|
||||
return jacfun
|
||||
|
||||
def jacrev(fun, argnums=0):
|
||||
def _check_real_input_jacfwd(x):
  """Reject a `jacfwd` input whose dtype is not a sub-dtype of `onp.floating`.

  Args:
    x: A single input value; its dtype is read from `core.get_aval(x)`.

  Raises:
    TypeError: If the dtype of `x` is not a sub-dtype of `onp.floating`. The
      message names the offending dtype and points at `holomorphic=True`.
  """
  input_dtype = core.get_aval(x).dtype
  if onp.issubdtype(input_dtype, onp.floating):
    return

  template = ("jacfwd only defined for functions with input dtypes that are "
              "sub-dtypes of `np.floating` (i.e. that model real values), but got "
              "{}. For holomorphic differentiation, pass holomorphic=True.")
  raise TypeError(template.format(input_dtype.name))
|
||||
|
||||
|
||||
def jacrev(fun, argnums=0, holomorphic=False):
|
||||
"""Jacobian of `fun` evaluated row-by-row using reverse-mode AD.
|
||||
|
||||
Args:
|
||||
fun: Function whose Jacobian is to be computed.
|
||||
argnums: Optional, integer or tuple of integers. Specifies which positional
|
||||
argument(s) to differentiate with respect to (default `0`).
|
||||
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
|
||||
holomorphic. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as `fun`, that evaluates the Jacobian of
|
||||
@ -339,8 +373,8 @@ def jacrev(fun, argnums=0):
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = _argnums_partial(f, argnums, args)
|
||||
tree_map(_check_real, dyn_args)
|
||||
y, pullback = vjp(f_partial, *dyn_args)
|
||||
holomorphic or tree_map(_check_real_output_jacrev, y)
|
||||
jac = vmap(pullback)(_std_basis(y))
|
||||
jac = jac[0] if isinstance(argnums, int) else jac
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
@ -350,13 +384,24 @@ def jacrev(fun, argnums=0):
|
||||
return jacfun
|
||||
jacobian = jacrev
|
||||
|
||||
def hessian(fun, argnums=0):
|
||||
def _check_real_output_jacrev(x):
  """Reject a `jacrev` output whose dtype is not a sub-dtype of `onp.floating`.

  Args:
    x: A single output value; its dtype is read from `core.get_aval(x)`.

  Raises:
    TypeError: If the dtype of `x` is not a sub-dtype of `onp.floating`. The
      message names the offending dtype and points at `holomorphic=True`.
  """
  output_dtype = core.get_aval(x).dtype
  models_real_values = onp.issubdtype(output_dtype, onp.floating)
  if not models_real_values:
    template = ("jacrev only defined for functions with output dtypes that are "
                "sub-dtypes of `np.floating` (i.e. that model real values), but got "
                "{}. For holomorphic differentiation, pass holomorphic=True.")
    raise TypeError(template.format(output_dtype.name))
|
||||
|
||||
|
||||
def hessian(fun, argnums=0, holomorphic=False):
|
||||
"""Hessian of `fun`.
|
||||
|
||||
Args:
|
||||
fun: Function whose Hessian is to be computed.
|
||||
argnums: Optional, integer or tuple of integers. Specifies which positional
|
||||
argument(s) to differentiate with respect to (default `0`).
|
||||
holomorphic: Optional, bool. Indicates whether `fun` is promised to be
|
||||
holomorphic. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as `fun`, that evaluates the Hessian of
|
||||
@ -367,29 +412,22 @@ def hessian(fun, argnums=0):
|
||||
array([[ 6., -2.],
|
||||
[ -2., -480.]], dtype=float32)
|
||||
"""
|
||||
|
||||
return jacfwd(jacrev(fun, argnums=argnums), argnums=argnums)
|
||||
return jacfwd(jacrev(fun, argnums, holomorphic), argnums, holomorphic)
|
||||
|
||||
def _std_basis(pytree):
|
||||
leaves, _ = tree_flatten(pytree)
|
||||
ndim = sum(map(onp.size, leaves))
|
||||
# TODO(mattjj): use a symbolic identity matrix here
|
||||
dtype = onp.result_type(*leaves)
|
||||
if not onp.issubdtype(dtype, onp.floating):
|
||||
msg = ("Jacobian only defined for functions with floating input and output "
|
||||
"dtypes (i.e. dtypes that model real numbers), got {}.")
|
||||
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
|
||||
flat_basis = onp.eye(ndim, dtype=dtype)
|
||||
return _unravel_array_into_pytree(pytree, 1, flat_basis)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, arr):
|
||||
leaves, treedef = tree_flatten(pytree)
|
||||
axis = axis % arr.ndim
|
||||
dtypes = map(_dtype, leaves)
|
||||
shapes = [arr.shape[:axis] + onp.shape(l) + arr.shape[axis+1:] for l in leaves]
|
||||
parts = _split(arr, onp.cumsum(map(onp.size, leaves[:-1])), axis)
|
||||
reshaped_parts = [onp.reshape(part.astype(dtype), shape)
|
||||
for part, dtype, shape in zip(parts, dtypes, shapes)]
|
||||
reshaped_parts = [onp.reshape(x, shape) for x, shape in zip(parts, shapes)]
|
||||
return tree_unflatten(treedef, reshaped_parts)
|
||||
|
||||
def _split(x, indices, axis):
|
||||
@ -787,29 +825,6 @@ def _check_args(args):
|
||||
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
|
||||
.format(arg, type(arg)))
|
||||
|
||||
def _check_scalar_real(x):
|
||||
msg = "Gradient only defined for scalar-output functions. Output was: {}".format
|
||||
try:
|
||||
aval = core.get_aval(x)
|
||||
except TypeError:
|
||||
raise TypeError(msg(x))
|
||||
else:
|
||||
if not (isinstance(aval, ShapedArray) and aval.shape == ()):
|
||||
raise TypeError(msg(x))
|
||||
if not onp.issubdtype(aval.dtype, onp.floating):
|
||||
msg2 = ("Gradient only defined for functions with output dtypes that are "
|
||||
"sub-dtypes of `np.floating` (i.e. that model real scalars), but "
|
||||
"got {}. For holomorphic differentiation, apply `np.real` at the "
|
||||
"end of the function.")
|
||||
raise TypeError(msg2.format(aval.dtype.name))
|
||||
|
||||
def _check_real(x):
|
||||
aval = core.get_aval(x)
|
||||
if not onp.issubdtype(aval.dtype, onp.floating):
|
||||
msg = ("Jacobian only defined for functions with floating input and output "
|
||||
"dtypes (i.e. dtypes that model real numbers), got {}.")
|
||||
raise TypeError(msg.format(aval.dtype.name))
|
||||
|
||||
|
||||
def custom_transforms(fun):
|
||||
name = getattr(fun, '__name__', '<unnamed user primitive>')
|
||||
|
@ -429,45 +429,42 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_complex_grad_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: grad(lambda x: np.sin(x))(1 + 2j))
|
||||
grad(lambda x: np.real(np.sin(x)))(1 + 2j) # doesn't crash
|
||||
|
||||
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
|
||||
# def test_complex_jacfwd(self):
|
||||
# # code based on https://github.com/google/jax/issues/603
|
||||
# zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
def test_holomorphic_grad(self):
|
||||
out = grad(lambda x: np.sin(x), holomorphic=True)(1 + 2j)
|
||||
expected = 2.0327230070196656 - 3.0518977991518j
|
||||
self.assertAllClose(out, expected, check_dtypes=False)
|
||||
|
||||
# def f(z):
|
||||
# return np.cos(np.linalg.norm(2 * z))
|
||||
def test_nonholomorphic_grad(self):
|
||||
zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
|
||||
# ans = jacfwd(f)(zs)
|
||||
# expected = grad(f)(zs)
|
||||
# self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
def f(z):
|
||||
return np.sum(np.cos(np.abs(z)))
|
||||
|
||||
def test_complex_jacfwd_raises_error(self):
|
||||
ans = grad(f)(zs)
|
||||
expected = onp.array([ 0. +0.j,
|
||||
-0.80430663+0.40215331j,
|
||||
-0.70368982+0.35184491j,
|
||||
0.1886467 -0.09432335j,
|
||||
0.86873727-0.43436864j])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_complex_output_jacrev_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: jacrev(lambda x: np.sin(x))(1 + 2j))
|
||||
|
||||
def test_nonholomorphic_jacrev(self):
|
||||
# code based on https://github.com/google/jax/issues/603
|
||||
zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
|
||||
def f(z):
|
||||
return np.cos(np.linalg.norm(2 * z))
|
||||
self.assertRaises(TypeError, lambda: jacfwd(f)(zs))
|
||||
|
||||
# TODO(mattjj, dougalm): make this work if we can, and delete subsequent test
|
||||
# def test_complex_jacrev(self):
|
||||
# # code based on https://github.com/google/jax/issues/603
|
||||
# zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
ans = jacrev(f)(zs)
|
||||
expected = grad(f)(zs)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
# def f(z):
|
||||
# return np.cos(np.linalg.norm(2 * z))
|
||||
|
||||
# ans = jacrev(f)(zs)
|
||||
# expected = grad(f)(zs)
|
||||
# self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_complex_jacrev_raises_error(self):
|
||||
# code based on https://github.com/google/jax/issues/603
|
||||
zs = 0.5j * onp.arange(5) + onp.arange(5)
|
||||
def f(z):
|
||||
return np.cos(np.linalg.norm(2 * z))
|
||||
self.assertRaises(TypeError, lambda: jacrev(f)(zs))
|
||||
def test_complex_input_jacfwd_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user