mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error message when vjp is called with cotangent of wrong shape.
Previously the error was an internal assertion error.
This commit is contained in:
parent
c590c9ea4a
commit
5520fcb59f
@ -1888,19 +1888,24 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
|
||||
|
||||
return apply_flat_fun(fun, io_tree, *py_args)
|
||||
|
||||
def _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args):
|
||||
def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes,
|
||||
io_tree, fun, py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
args, in_tree = tree_flatten(py_args)
|
||||
if in_tree != in_tree_expected:
|
||||
raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
|
||||
f"primal output {in_tree_expected}.")
|
||||
for arg, ct_dtype in safe_zip(args, cotangent_dtypes):
|
||||
for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes):
|
||||
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(_dtype(arg))
|
||||
if expected_tangent_dtype != ct_dtype:
|
||||
raise TypeError(
|
||||
f"Type of cotangent input to vjp pullback function ({ct_dtype}) is not "
|
||||
f"the expected tangent type ({expected_tangent_dtype}) of corresponding primal output "
|
||||
f"with dtype {_dtype(arg)}.")
|
||||
if np.shape(arg) != cotangent_shapes:
|
||||
raise ValueError(
|
||||
f"Shape of cotangent input to vjp pullback function ({np.shape(arg)}) is not "
|
||||
f"the expected shape ({ct_shape}) of corresponding primal output.")
|
||||
ans = fun(*args)
|
||||
return tree_unflatten(out_tree, ans)
|
||||
|
||||
@ -2003,10 +2008,11 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
|
||||
out_tree, aux_tree = out_aux_trees()
|
||||
out_primal_py = tree_unflatten(out_tree, out_primal)
|
||||
ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal]
|
||||
ct_shapes = [np.shape(x) for x in out_primal]
|
||||
# Ensure that vjp_py is a PyTree so that we can pass it from the forward to the
|
||||
# backward pass in a custom VJP.
|
||||
vjp_py = Partial(partial(_vjp_pullback_wrapper,
|
||||
ct_dtypes,
|
||||
ct_dtypes, ct_shapes,
|
||||
(out_tree, in_tree)),
|
||||
out_vjp)
|
||||
if not has_aux:
|
||||
|
@ -1148,6 +1148,17 @@ class APITest(jtu.JaxTestCase):
|
||||
"Type of cotangent input to vjp pullback.*is not the expected tangent type",
|
||||
lambda: pullback((np.float16(42))))
|
||||
|
||||
def test_vjp_bad_cotangent_shape(self):
|
||||
x = np.ones((2, 5), dtype=np.float32)
|
||||
y = np.ones((5, 3), dtype=np.float32)
|
||||
def f_jax(x, y):
|
||||
return jnp.matmul(x, y)
|
||||
res, pullback = jax.vjp(f_jax, x, y)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Shape of cotangent input to vjp pullback function .* is not the expected shape"):
|
||||
pullback(np.ones((2, 4), dtype=np.float32))
|
||||
|
||||
def test_jvp_jit_cached(self):
|
||||
"""Bug in caching in presence of JVP and JIT."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user