mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve error message for missing vmap rule in custom_vmap.
This is a partial re-land of https://github.com/google/jax/pull/22869 after it was rolled back to fix internal users. This part of the change didn't cause the issues, and I'll follow up with the rest of the changes in a second PR.
This commit is contained in:
parent
551f72979c
commit
595ca0affa
@ -65,6 +65,11 @@ class custom_vmap:
|
||||
@traceback_util.api_boundary
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert not kwargs
|
||||
fun_name = getattr(self.fun, "__name__", str(self.fun))
|
||||
if not self.vmap_rule:
|
||||
raise AttributeError(
|
||||
f"No batching rule defined for custom_vmap function {fun_name} "
|
||||
"using def_vmap.")
|
||||
args_flat, in_tree = tree_flatten(args)
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
|
||||
|
@ -10780,7 +10780,6 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
)
|
||||
self.assertAllClose(outputs['b'], expected)
|
||||
|
||||
|
||||
def test_batch_divides_axis(self):
|
||||
def f(t):
|
||||
x, a = t
|
||||
@ -10798,6 +10797,15 @@ class CustomVmapTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(y, (x + a)**2)
|
||||
|
||||
def test_undefined_rule(self):
|
||||
@jax.custom_batching.custom_vmap
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
AttributeError, "No batching rule defined for custom_vmap function f"):
|
||||
f(0.5)
|
||||
|
||||
|
||||
class CustomApiTest(jtu.JaxTestCase):
|
||||
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user