Improve error message for missing vmap rule in custom_vmap.

This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
This commit is contained in:
Dan Foreman-Mackey 2024-08-08 14:08:51 +01:00
parent 551f72979c
commit 595ca0affa
2 changed files with 14 additions and 1 deletions

View File

@ -65,6 +65,11 @@ class custom_vmap:
@traceback_util.api_boundary
def __call__(self, *args, **kwargs):
assert not kwargs
fun_name = getattr(self.fun, "__name__", str(self.fun))
if not self.vmap_rule:
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]

View File

@ -10780,7 +10780,6 @@ class CustomVmapTest(jtu.JaxTestCase):
)
self.assertAllClose(outputs['b'], expected)
def test_batch_divides_axis(self):
def f(t):
x, a = t
@ -10798,6 +10797,15 @@ class CustomVmapTest(jtu.JaxTestCase):
self.assertAllClose(y, (x + a)**2)
def test_undefined_rule(self):
@jax.custom_batching.custom_vmap
def f(x): return jnp.sin(x)
with self.assertRaisesRegex(
AttributeError, "No batching rule defined for custom_vmap function f"):
f(0.5)
class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""