mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
8fe26190de
commit
850f1afd95
104
jax/api.py
104
jax/api.py
@ -150,7 +150,7 @@ def jit(fun: Callable, static_argnums: Union[int, Iterable[int]] = (),
|
||||
else:
|
||||
dyn_args = args
|
||||
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
_check_args(args_flat)
|
||||
for arg in args_flat: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
|
||||
name=flat_fun.__name__)
|
||||
@ -370,7 +370,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
|
||||
holomorphic. Default False.
|
||||
holomorphic. If True, inputs and outputs must be complex. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as ``fun``, that evaluates the gradient
|
||||
@ -424,7 +424,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
first element is considered the output of the mathematical function to be
|
||||
differentiated and the second element is auxiliary data. Default False.
|
||||
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
|
||||
holomorphic. Default False.
|
||||
holomorphic. If True, inputs and outputs must be complex. Default False.
|
||||
|
||||
Returns:
|
||||
A function with the same arguments as ``fun`` that evaluates both ``fun``
|
||||
@ -454,17 +454,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
tree_map(partial(_check_input_dtype_grad, holomorphic), dyn_args)
|
||||
if not has_aux:
|
||||
ans, vjp_py = _vjp(f_partial, *dyn_args)
|
||||
else:
|
||||
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
|
||||
_check_scalar(ans)
|
||||
dtype = dtypes.result_type(ans)
|
||||
if not (holomorphic or dtypes.issubdtype(dtype, onp.floating)):
|
||||
msg = ("Gradient only defined for real-output functions (with dtype that "
|
||||
"is a subdtype of np.floating), but got dtype {}. For holomorphic "
|
||||
"differentiation, pass holomorphic=True.")
|
||||
raise TypeError(msg.format(dtype))
|
||||
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
|
||||
g = vjp_py(onp.ones((), dtype=dtype))
|
||||
g = g[0] if isinstance(argnums, int) else g
|
||||
if not has_aux:
|
||||
@ -487,6 +484,38 @@ def _check_scalar(x):
|
||||
else:
|
||||
raise TypeError(msg("had abstract value {}".format(aval)))
|
||||
|
||||
def _check_input_dtype_revderiv(name, holomorphic, x):
  """Validate the dtype of one input leaf for a reverse-mode derivative API.

  Args:
    name: API name interpolated at the start of the error message
      (e.g. "grad" or "jacrev").
    holomorphic: if True, only complex-floating dtypes are accepted; if False,
      both floating and complex-floating dtypes are accepted.
    x: the input value; it is first validated with ``_check_arg`` and its
      dtype is then read from ``core.get_aval(x)``.

  Raises:
    TypeError: if the dtype of ``x`` is not acceptable for the given mode.
  """
  _check_arg(x)
  dtype = core.get_aval(x).dtype
  is_complex = dtypes.issubdtype(dtype, onp.complexfloating)

  if holomorphic:
    # Holomorphic differentiation accepts complex inputs only.
    if is_complex:
      return
    raise TypeError(
        f"{name} with holomorphic=True requires inputs with complex dtype, "
        f"but got {dtype.name}.")

  if is_complex or dtypes.issubdtype(dtype, onp.floating):
    return
  raise TypeError(
      f"{name} requires real- or complex-valued inputs (input dtype that "
      "is a sub-dtype of np.floating or np.complexfloating), "
      f"but got {dtype.name}. ")


# Input checker specialised for `grad`; remaining arguments are (holomorphic, x).
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
|
||||
|
||||
def _check_output_dtype_revderiv(name, holomorphic, x):
  """Validate the dtype of one output leaf for a reverse-mode derivative API.

  Args:
    name: API name interpolated at the start of the error message
      (e.g. "grad" or "jacrev").
    holomorphic: if True, the output must have a complex-floating dtype;
      if False, it must have a floating (real) dtype.
    x: the output value; its dtype is read from ``core.get_aval(x)``.

  Raises:
    TypeError: if the dtype of ``x`` is not acceptable for the given mode.
  """
  dtype = core.get_aval(x).dtype

  if holomorphic:
    if dtypes.issubdtype(dtype, onp.complexfloating):
      return
    raise TypeError(
        f"{name} with holomorphic=True requires outputs with complex dtype, "
        f"but got {dtype.name}.")

  if dtypes.issubdtype(dtype, onp.floating):
    return
  raise TypeError(
      f"{name} requires real-valued outputs (output dtype that is "
      f"a sub-dtype of np.floating), but got {dtype.name}. "
      "For holomorphic differentiation, pass holomorphic=True. "
      "For differentiation of non-holomorphic functions involving complex "
      "outputs, use jax.vjp directly.")


# Output checker specialised for `grad`; remaining arguments are (holomorphic, x).
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
|
||||
|
||||
|
||||
def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
holomorphic: bool = False) -> Callable:
|
||||
@ -521,21 +550,39 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
holomorphic or tree_map(_check_real_input_jacfwd, dyn_args)
|
||||
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
|
||||
pushfwd = partial(_jvp, f_partial, dyn_args)
|
||||
y, jac = vmap(pushfwd, out_axes=(None, batching.last))(_std_basis(dyn_args))
|
||||
tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
return tree_map(partial(_unravel_array_into_pytree, example_args, -1), jac)
|
||||
|
||||
return jacfun
|
||||
|
||||
def _check_input_dtype_jacfwd(holomorphic, x):
  """Validate the dtype of one input leaf passed to ``jacfwd``.

  The real-only check must not run ahead of the ``holomorphic`` branch:
  doing so rejects complex inputs even when ``holomorphic=True``, which is
  exactly the case that branch exists to accept.

  Args:
    holomorphic: if True, the input must have a complex-floating dtype;
      if False, it must have a floating (real) dtype.
    x: the input value; it is first validated with ``_check_arg`` and its
      dtype is then read from ``core.get_aval(x)``.

  Raises:
    TypeError: if the dtype of ``x`` is not acceptable for the given mode.
  """
  _check_arg(x)
  aval = core.get_aval(x)
  if holomorphic:
    # A complex-floating dtype is never a sub-dtype of np.floating, so the
    # single complexfloating test is sufficient here (same condition as the
    # reverse-mode input check).
    if not dtypes.issubdtype(aval.dtype, onp.complexfloating):
      msg = ("jacfwd with holomorphic=True requires inputs with complex dtype, "
             f"but got {aval.dtype.name}.")
      raise TypeError(msg)
  elif not dtypes.issubdtype(aval.dtype, onp.floating):
    msg = ("jacfwd requires real-valued inputs (input dtype that is "
           f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
           "For holomorphic differentiation, pass holomorphic=True. "
           "For differentiation of non-holomorphic functions involving complex "
           "inputs, use jax.jvp directly.")
    raise TypeError(msg)


# Compatibility alias for call sites that still use the older one-argument,
# real-only checker name.
_check_real_input_jacfwd = partial(_check_input_dtype_jacfwd, False)
|
||||
|
||||
def _check_output_dtype_jacfwd(holomorphic, x):
  """Validate the dtype of one output leaf produced under ``jacfwd``.

  Only the holomorphic case is constrained: the output must then have a
  complex-floating dtype. With ``holomorphic`` False no dtype is rejected.

  Args:
    holomorphic: whether holomorphic differentiation was requested.
    x: the output value; its dtype is read from ``core.get_aval(x)``.

  Raises:
    TypeError: if ``holomorphic`` is True and the dtype is not complex.
  """
  dtype = core.get_aval(x).dtype
  if not holomorphic:
    return
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return
  raise TypeError(
      "jacfwd with holomorphic=True requires outputs with complex dtype, "
      f"but got {dtype.name}.")
|
||||
|
||||
|
||||
def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
@ -571,8 +618,9 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
tree_map(partial(_check_input_dtype_jacrev, holomorphic), dyn_args)
|
||||
y, pullback = _vjp(f_partial, *dyn_args)
|
||||
holomorphic or tree_map(_check_real_output_jacrev, y)
|
||||
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
|
||||
jac = vmap(pullback)(_std_basis(y))
|
||||
jac = jac[0] if isinstance(argnums, int) else jac
|
||||
example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
|
||||
@ -582,13 +630,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
return jacfun
|
||||
jacobian = jacrev
|
||||
|
||||
def _check_real_output_jacrev(x):
|
||||
aval = core.get_aval(x)
|
||||
if not dtypes.issubdtype(aval.dtype, onp.floating):
|
||||
msg = ("jacrev only defined for functions with output dtypes that are "
|
||||
"sub-dtypes of `np.floating` (i.e. that model real values), but "
|
||||
"got {}. For holomorphic differentiation, pass holomorphic=True.")
|
||||
raise TypeError(msg.format(aval.dtype.name))
|
||||
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
|
||||
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
|
||||
|
||||
|
||||
def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
@ -1070,7 +1113,7 @@ def pmap(fun: Callable, axis_name: Optional[AxisName] = None, *, in_axes=0,
|
||||
assert all(axis in (0, None) for axis in in_axes_flat), \
|
||||
"pmap currently only supports mapping over the leading axis"
|
||||
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap")
|
||||
_check_args(args)
|
||||
for arg in args: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
out = pxla.xla_pmap(
|
||||
flat_fun,
|
||||
@ -1114,7 +1157,7 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, *,
|
||||
"soft_pmap currently only supports mapping over the leading axis"
|
||||
mapped_invars = tuple(axis is not None for axis in in_axes_flat)
|
||||
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
|
||||
_check_args(args_flat)
|
||||
for arg in args_flat: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
|
||||
chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count(backend))
|
||||
@ -1489,7 +1532,7 @@ def _vjp(fun: lu.WrappedFun, *primals, **kwargs):
|
||||
has_aux = kwargs.pop('has_aux', False)
|
||||
assert not kwargs
|
||||
primals_flat, in_tree = tree_flatten(primals)
|
||||
_check_args(primals_flat)
|
||||
for arg in primals_flat: _check_arg(arg)
|
||||
tree_map(_check_inexact_input_vjp, primals)
|
||||
if not has_aux:
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
|
||||
@ -1618,11 +1661,10 @@ def device_get(x):
|
||||
return tree_map(_device_get, x)
|
||||
|
||||
|
||||
def _check_args(args):
  """Validate every element of ``args`` as an acceptable JAX argument.

  Delegates to ``_check_arg`` so that the acceptance rule and the error
  message are defined in exactly one place.

  Args:
    args: an iterable of (already flattened) argument values.

  Raises:
    TypeError: on the first element rejected by ``_check_arg``.
  """
  for arg in args:
    _check_arg(arg)
|
||||
def _check_arg(arg):
  """Raise TypeError unless ``arg`` is a Tracer or passes ``_valid_jaxtype``.

  Args:
    arg: a single (flattened) argument value.

  Raises:
    TypeError: if ``arg`` is neither a ``core.Tracer`` instance nor accepted
      by ``_valid_jaxtype``.
  """
  # The isinstance test comes first so tracers never reach _valid_jaxtype.
  if isinstance(arg, core.Tracer) or _valid_jaxtype(arg):
    return
  raise TypeError("Argument '{}' of type {} is not a valid JAX type"
                  .format(arg, type(arg)))
|
||||
|
||||
def _valid_jaxtype(arg):
|
||||
try:
|
||||
|
@ -193,14 +193,14 @@ def id_tap(func: Callable, arg, *,
|
||||
if func not in (_end_consumer, _unknown_testing_consumer):
|
||||
api._check_callable(func)
|
||||
flat_args, arg_treedef = pytree.flatten(arg)
|
||||
api._check_args(flat_args)
|
||||
for arg in flat_args: api._check_arg(arg)
|
||||
params = dict(kwargs) # we pass a copy of params to the primitive
|
||||
# See definition of id_tap_p for what parameters it takes
|
||||
params["func"] = func
|
||||
params["arg_treedef"] = arg_treedef
|
||||
if result is not None:
|
||||
flat_results, result_treedef = pytree.flatten(result)
|
||||
api._check_args(flat_results)
|
||||
for result in flat_results: api._check_arg(result)
|
||||
all_args = flat_args + flat_results
|
||||
params["nr_untapped"] = len(flat_results)
|
||||
else:
|
||||
|
@ -801,8 +801,49 @@ class APITest(jtu.JaxTestCase):
|
||||
dfn = grad(lambda x: x ** 2)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Primal inputs to reverse-mode differentiation must be of float or "
|
||||
"complex type, got type int..", lambda: dfn(3))
|
||||
(r"grad requires real- or complex-valued inputs \(input dtype that is a "
|
||||
r"sub-dtype of np.floating or np.complexfloating\), but got int.*."),
|
||||
lambda: dfn(3))
|
||||
|
||||
def test_grad_complex_result_errors(self):
  """grad of a function with complex output must raise TypeError when called."""
  expected = (r"grad requires real-valued outputs \(output dtype that is a "
              r"sub-dtype of np.floating\), but got complex.*")
  dfn = grad(lambda x: x ** 2 + 1j)
  with self.assertRaisesRegex(TypeError, expected):
    dfn(3.)
|
||||
|
||||
def test_holomorphic_grad_of_float_errors(self):
  """grad(holomorphic=True) called with a float input must raise TypeError."""
  expected = (r"grad with holomorphic=True requires inputs with complex dtype, "
              r"but got float.*")
  dfn = grad(lambda x: x ** 2, holomorphic=True)
  with self.assertRaisesRegex(TypeError, expected):
    dfn(3.)
|
||||
|
||||
def test_holomorphic_jacrev_of_float_errors(self):
  """jacrev(holomorphic=True) called with a float input must raise TypeError."""
  expected = (r"jacrev with holomorphic=True requires inputs with complex dtype, "
              r"but got float.*")
  dfn = jacrev(lambda x: x ** 2, holomorphic=True)
  with self.assertRaisesRegex(TypeError, expected):
    dfn(3.)
|
||||
|
||||
def test_holomorphic_jacfwd_of_float_errors(self):
  """jacfwd(holomorphic=True) called with a float input must raise TypeError."""
  expected = (r"jacfwd with holomorphic=True requires inputs with complex dtype, "
              r"but got float.*")
  dfn = jacfwd(lambda x: x ** 2, holomorphic=True)
  with self.assertRaisesRegex(TypeError, expected):
    dfn(3.)
|
||||
|
||||
def test_jacfwd_of_complex_errors(self):
  """jacfwd without holomorphic=True must reject a complex input."""
  expected = (r"jacfwd requires real-valued inputs \(input dtype that is a "
              r"sub-dtype of np.floating\), but got complex.*")
  dfn = jacfwd(lambda x: x ** 2)
  with self.assertRaisesRegex(TypeError, expected):
    dfn(3. + 1j)
|
||||
|
||||
def test_xla_computation(self):
|
||||
# these tests basically check the examples in the xla_computation docstring
|
||||
|
Loading…
x
Reference in New Issue
Block a user