Change the way that batching.spec_types is updated.

There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
This commit is contained in:
Joan Puigcerver 2025-03-15 22:57:56 -07:00 committed by jax authors
parent f360e19194
commit 466ef6a132
2 changed files with 31 additions and 2 deletions

View File

@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
spec_types: set[type] = {JumbleAxis}
def unregister_vmappable(data_type: type) -> None:
spec_type, axis_size_type = vmappables.pop(data_type)
spec_types.remove(spec_type)
_, axis_size_type = vmappables.pop(data_type)
del to_elt_handlers[data_type]
del from_elt_handlers[data_type]
if axis_size_type in make_iota_handlers:
del make_iota_handlers[axis_size_type]
global spec_types
spec_types = (
{JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
)
def is_vmappable(x: Any) -> bool:
return type(x) is Jumble or type(x) in vmappables

View File

@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
def test_types_with_same_spec(self):
# We register NamedArray.
batching.register_vmappable(NamedArray, NamedMapSpec, int,
named_to_elt, named_from_elt, None)
# We then register another type that uses NamedMapSpec as the spec_type too,
# and immediately unregister it.
class Foo:
pass
batching.register_vmappable(Foo, NamedMapSpec, int,
named_to_elt, named_from_elt, None)
batching.unregister_vmappable(Foo)
# We should still be able to use vmap on NamedArray.
def f(x):
return named_mul(x, x)
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
ans = jax.jit(f)(x)
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
self.assertEqual(ans.names, expected.names)
self.assertAllClose(ans.data, expected.data)
# And unregister NamedArray without exceptions.
batching.unregister_vmappable(NamedArray)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())