Improve error message when vjp is called with cotangent of wrong shape.

Previously the error was an internal assertion error.
This commit is contained in:
George Necula 2021-07-10 19:08:15 +03:00
parent c590c9ea4a
commit 5520fcb59f
2 changed files with 20 additions and 3 deletions

View File

@ -1888,19 +1888,24 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
return apply_flat_fun(fun, io_tree, *py_args)
def _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args):
def _vjp_pullback_wrapper(cotangent_dtypes, cotangent_shapes,
io_tree, fun, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
if in_tree != in_tree_expected:
raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
f"primal output {in_tree_expected}.")
for arg, ct_dtype in safe_zip(args, cotangent_dtypes):
for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes):
expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(_dtype(arg))
if expected_tangent_dtype != ct_dtype:
raise TypeError(
f"Type of cotangent input to vjp pullback function ({ct_dtype}) is not "
f"the expected tangent type ({expected_tangent_dtype}) of corresponding primal output "
f"with dtype {_dtype(arg)}.")
if np.shape(arg) != cotangent_shapes:
raise ValueError(
f"Shape of cotangent input to vjp pullback function ({np.shape(arg)}) is not "
f"the expected shape ({ct_shape}) of corresponding primal output.")
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@ -2003,10 +2008,11 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
out_tree, aux_tree = out_aux_trees()
out_primal_py = tree_unflatten(out_tree, out_primal)
ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal]
ct_shapes = [np.shape(x) for x in out_primal]
# Ensure that vjp_py is a PyTree so that we can pass it from the forward to the
# backward pass in a custom VJP.
vjp_py = Partial(partial(_vjp_pullback_wrapper,
ct_dtypes,
ct_dtypes, ct_shapes,
(out_tree, in_tree)),
out_vjp)
if not has_aux:

View File

@ -1148,6 +1148,17 @@ class APITest(jtu.JaxTestCase):
"Type of cotangent input to vjp pullback.*is not the expected tangent type",
lambda: pullback((np.float16(42))))
def test_vjp_bad_cotangent_shape(self):
x = np.ones((2, 5), dtype=np.float32)
y = np.ones((5, 3), dtype=np.float32)
def f_jax(x, y):
return jnp.matmul(x, y)
res, pullback = jax.vjp(f_jax, x, y)
with self.assertRaisesRegex(
ValueError,
"Shape of cotangent input to vjp pullback function .* is not the expected shape"):
pullback(np.ones((2, 4), dtype=np.float32))
def test_jvp_jit_cached(self):
"""Bug in caching in presence of JVP and JIT."""