diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 40dbe0018..03c9a9510 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {} spec_types: set[type] = {JumbleAxis} def unregister_vmappable(data_type: type) -> None: - spec_type, axis_size_type = vmappables.pop(data_type) - spec_types.remove(spec_type) + _, axis_size_type = vmappables.pop(data_type) del to_elt_handlers[data_type] del from_elt_handlers[data_type] if axis_size_type in make_iota_handlers: del make_iota_handlers[axis_size_type] + global spec_types + spec_types = ( + {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()} + ) def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables diff --git a/tests/batching_test.py b/tests/batching_test.py index bab18ce53..f2a4e8c34 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase): self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) + def test_types_with_same_spec(self): + # We register NamedArray. + batching.register_vmappable(NamedArray, NamedMapSpec, int, + named_to_elt, named_from_elt, None) + + # We then register another type that uses NamedMapSpec as the spec_type too, + # and immediately unregister it. + class Foo: + pass + batching.register_vmappable(Foo, NamedMapSpec, int, + named_to_elt, named_from_elt, None) + batching.unregister_vmappable(Foo) + + # We should still be able to use vmap on NamedArray. + def f(x): + return named_mul(x, x) + + x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) + ans = jax.jit(f)(x) + expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) + + self.assertEqual(ans.names, expected.names) + self.assertAllClose(ans.data, expected.data) + + # And unregister NamedArray without exceptions. + batching.unregister_vmappable(NamedArray) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())