mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix jacfwd and jacrev for heterogeneous pytrees
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input pytree elements have heterogeneous dtypes, e.g., real and complex elements:
* Changed the dtypes of the pytree elements of the Jacobian produced by `jacfwd` to be those of the input tangent basis.
* Changed the dtypes of the pytree elements of the Jacobian produced by `jacrev` to be those of the output tangent basis.
* Changed the dtypes of the pytree elements of the primals and tangents produced by `jacfwd` and `jacrev` to be the same as the corresponding elements in the input.

Changed the behavior of the flags to `jacfwd` and `jacrev`:
* Changed the `allow_int` flag to allow only integer and Boolean dtypes. Previously, this flag allowed all other types.
This commit is contained in:
parent
075d83527f
commit
832cf214e3
@ -1024,12 +1024,16 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
|
||||
f"but got {aval.dtype.name}.")
|
||||
elif not allow_int and not (dtypes.issubdtype(aval.dtype, np.floating) or
|
||||
dtypes.issubdtype(aval.dtype, np.complexfloating)):
|
||||
raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype that "
|
||||
"is a sub-dtype of np.floating or np.complexfloating), "
|
||||
f"but got {aval.dtype.name}. If you want to use integer-valued "
|
||||
"inputs, use vjp or set allow_int to True.")
|
||||
if (dtypes.issubdtype(aval.dtype, np.integer) or
|
||||
dtypes.issubdtype(aval.dtype, np.bool_)):
|
||||
if not allow_int:
|
||||
raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
|
||||
f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
|
||||
"If you want to use Boolean- or integer-valued inputs, use vjp "
|
||||
"or set allow_int to True.")
|
||||
elif not dtypes.issubdtype(aval.dtype, np.inexact):
|
||||
raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
|
||||
f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
|
||||
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
|
||||
|
||||
def _check_output_dtype_revderiv(name, holomorphic, x):
|
||||
@ -1038,12 +1042,17 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
|
||||
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
|
||||
f"but got {aval.dtype.name}.")
|
||||
elif not dtypes.issubdtype(aval.dtype, np.floating):
|
||||
elif dtypes.issubdtype(aval.dtype, np.complexfloating):
|
||||
raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
|
||||
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
|
||||
"For holomorphic differentiation, pass holomorphic=True. "
|
||||
"For differentiation of non-holomorphic functions involving complex "
|
||||
"outputs, or function with integer outputs, use jax.vjp directly.")
|
||||
"outputs, use jax.vjp directly.")
|
||||
elif not dtypes.issubdtype(aval.dtype, np.floating):
|
||||
raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
|
||||
f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
|
||||
"For differentiation of functions with integer outputs, use "
|
||||
"jax.vjp directly.")
|
||||
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
|
||||
|
||||
|
||||
@ -1087,24 +1096,23 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
|
||||
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
|
||||
return tree_map(partial(_jacfwd_unravel, example_args), y, jac)
|
||||
|
||||
return jacfun
|
||||
|
||||
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
  """Validate the dtype of one input leaf passed to ``jacfwd``.

  Args:
    holomorphic: Whether the function being differentiated was declared
      holomorphic.  If True the input must have a complex dtype; otherwise it
      must have a real floating-point dtype.
    x: The input leaf to check.  It is first passed through ``_check_arg``.

  Raises:
    TypeError: If the dtype of ``x`` does not satisfy the requirement selected
      by ``holomorphic``.
  """
  _check_arg(x)
  aval = core.get_aval(x)
  if holomorphic:
    # A single sub-dtype test is sufficient here: a complex dtype is never a
    # sub-dtype of np.floating, so no extra exclusion clause is needed.
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
                      f"dtype, but got {aval.dtype.name}.")
  elif not dtypes.issubdtype(aval.dtype, np.floating):
    raise TypeError("jacfwd requires real-valued inputs (input dtype that is "
                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving "
                    "complex inputs or integer inputs, use jax.jvp directly.")
|
||||
|
||||
def _check_output_dtype_jacfwd(holomorphic, x):
|
||||
aval = core.get_aval(x)
|
||||
@ -1113,7 +1121,6 @@ def _check_output_dtype_jacfwd(holomorphic, x):
|
||||
raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
|
||||
f"but got {aval.dtype.name}.")
|
||||
|
||||
|
||||
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
holomorphic: bool = False, allow_int: bool = False) -> Callable:
|
||||
"""Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.
|
||||
@ -1157,8 +1164,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
jac = vmap(pullback)(_std_basis(y))
|
||||
jac = jac[0] if isinstance(argnums, int) else jac
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
jac = tree_map(partial(_unravel_array_into_pytree, y, 0), jac)
|
||||
return tree_transpose(tree_structure(example_args), tree_structure(y), jac)
|
||||
jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
|
||||
return tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
|
||||
|
||||
return jacfun
|
||||
jacobian = jacrev
|
||||
@ -1235,14 +1242,40 @@ def _std_basis(pytree):
|
||||
ndim = sum(map(np.size, leaves))
|
||||
dtype = dtypes.result_type(*leaves)
|
||||
flat_basis = jax.numpy.eye(ndim, dtype=dtype)
|
||||
return _unravel_array_into_pytree(pytree, 1, flat_basis)
|
||||
return _unravel_array_into_pytree(pytree, 1, None, flat_basis)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, arr):
|
||||
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
  """Unravel the last axis of ``arr`` into the structure of ``input_pytree``.

  Every resulting component is cast to the dtype of ``output_pytree_leaf``.
  """
  target_dtype = _dtype(output_pytree_leaf)
  return _unravel_array_into_pytree(input_pytree, -1, target_dtype, arr)
|
||||
|
||||
def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
  """Unravel the leading axis of ``arr`` into the structure of ``output_pytree``.

  Every resulting component is cast to the dtype of ``input_pytree_leaf``.
  """
  target_dtype = _dtype(input_pytree_leaf)
  return _unravel_array_into_pytree(output_pytree, 0, target_dtype, arr)
|
||||
|
||||
def _possible_downcast(x, dtype):
|
||||
if (dtypes.issubdtype(x.dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(dtype, np.complexfloating)):
|
||||
x = x.real
|
||||
return x.astype(dtype)
|
||||
|
||||
def _unravel_array_into_pytree(pytree, axis, cast_to_type, arr):
  """Unravel an array into a PyTree with a given structure.

  Args:
    pytree: The pytree that provides the structure.
    axis: The parameter axis is either -1, 0, or 1.  It controls the
      resulting shapes.
    cast_to_type: Cast the components to the given dtype, or else use the
      pytree leaf type if cast_to_type is None.
    arr: The array to be unraveled.

  Returns:
    A pytree with the tree structure of ``pytree`` whose leaves are the pieces
    of ``arr``, each reshaped so that the ``axis`` dimension is replaced by the
    shape of the corresponding leaf of ``pytree``.
  """
  leaves, treedef = tree_flatten(pytree)
  axis = axis % arr.ndim  # normalise a negative axis to a non-negative index
  shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
  # Build the split points from an explicit list so the result does not depend
  # on whether `map` in this module is the builtin (a lazy iterator) or a
  # list-returning replacement.
  split_points = np.cumsum([np.size(l) for l in leaves[:-1]])
  # NOTE(review): _split is presumably cutting `arr` at `split_points` along
  # `axis`, yielding one part per leaf -- confirm against its definition.
  parts = _split(arr, split_points, axis)
  reshaped_parts = [
      _possible_downcast(np.reshape(x, shape),
                         _dtype(leaf) if cast_to_type is None else cast_to_type)
      for x, shape, leaf in zip(parts, shapes, leaves)]
  return tree_unflatten(treedef, reshaped_parts)
|
||||
|
||||
def _split(x, indices, axis):
|
||||
|
@ -188,6 +188,9 @@ def build_tree(treedef, xs):
|
||||
return treedef.from_iterable_tree(xs)
|
||||
|
||||
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
|
||||
"""Transform a tree having tree structure (outer, inner) into one having structure
|
||||
(inner, outer).
|
||||
"""
|
||||
flat, treedef = tree_flatten(pytree_to_transpose)
|
||||
inner_size = inner_treedef.num_leaves
|
||||
outer_size = outer_treedef.num_leaves
|
||||
|
@ -1340,6 +1340,60 @@ class APITest(jtu.JaxTestCase):
|
||||
expected = grad(f)(zs)
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def test_heterogeneous_jacfwd(self):
  # See https://github.com/google/jax/issues/7157
  # See https://github.com/google/jax/issues/7780
  # jacfwd over a (float16, float32) tuple: each row of the expected Jacobian
  # carries the dtype of the corresponding output.
  half = np.array([2.0], dtype=np.float16)
  single = np.array([3.0], dtype=np.float32)
  primals = (half, single)

  def f(tup):
    jtu._check_dtypes_match(tup, primals)
    x, y = tup
    return x, y, x + y

  def row(dtype, d_dx, d_dy):
    return (np.array(d_dx, dtype=dtype), np.array(d_dy, dtype=dtype))

  desired = (row(np.float16, 1., 0.),
             row(np.float32, 0., 1.),
             row(np.float32, 1., 1.))
  actual = jacfwd(f)(primals)
  jtu._check_dtypes_match(actual, desired)
  jtu.check_eq(actual, desired)
|
||||
|
||||
def test_heterogeneous_jacrev(self):
  # See https://github.com/google/jax/issues/7157
  # See https://github.com/google/jax/issues/7780
  # jacrev over a (float16, float32) tuple: each column of the expected
  # Jacobian carries the dtype of the corresponding input.
  half = np.array([2.0], dtype=np.float16)
  single = np.array([3.0], dtype=np.float32)
  primals = (half, single)

  def f(tup):
    jtu._check_dtypes_match(tup, primals)
    x, y = tup
    return x, y, x + y

  def row(d_dx, d_dy):
    return (np.array(d_dx, dtype=np.float16), np.array(d_dy, dtype=np.float32))

  desired = (row(1., 0.), row(0., 1.), row(1., 1.))
  actual = jacrev(f)(primals)
  jtu._check_dtypes_match(actual, desired)
  jtu.check_eq(actual, desired)
|
||||
|
||||
def test_heterogeneous_grad(self):
  # See https://github.com/google/jax/issues/7157
  # grad over a (complex, real) tuple: each gradient component is expected to
  # keep the dtype of its input.
  z = np.array(1.0+1j)
  r = np.array(2.0)
  primals = (z, r)

  def f(tup):
    jtu._check_dtypes_match(tup, primals)
    x, y = tup
    return jnp.square(jnp.abs(x)) + y

  desired = (np.array(2 - 2j), np.array(1.))
  actual = grad(f)(primals)
  jtu._check_dtypes_match(actual, desired)
  jtu.check_eq(actual, desired)
|
||||
|
||||
def test_complex_input_jacfwd_raises_error(self):
  # A complex input without holomorphic=True must be rejected by jacfwd.
  with self.assertRaises(TypeError):
    jacfwd(lambda x: jnp.sin(x))(1 + 2j)
|
||||
|
||||
@ -1561,7 +1615,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
(r"grad requires real- or complex-valued inputs \(input dtype that is a "
|
||||
r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
|
||||
r"sub-dtype of np.inexact\), but got int.*."),
|
||||
lambda: dfn(3))
|
||||
|
||||
@unittest.skipIf(numpy_version == (1, 21, 0),
|
||||
|
Loading…
x
Reference in New Issue
Block a user