From 466ef6a132f06d8b6a19b2c21b7d05cc7b39172f Mon Sep 17 00:00:00 2001
From: Joan Puigcerver <jpuigcerver@google.com>
Date: Sat, 15 Mar 2025 22:57:56 -0700
Subject: [PATCH] Change the way that batching.spec_types is updated.

There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
---
 jax/_src/interpreters/batching.py |  7 +++++--
 tests/batching_test.py            | 26 ++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 40dbe0018..03c9a9510 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
 spec_types: set[type] = {JumbleAxis}
 
 def unregister_vmappable(data_type: type) -> None:
-  spec_type, axis_size_type = vmappables.pop(data_type)
-  spec_types.remove(spec_type)
+  _, axis_size_type = vmappables.pop(data_type)
   del to_elt_handlers[data_type]
   del from_elt_handlers[data_type]
   if axis_size_type in make_iota_handlers:
     del make_iota_handlers[axis_size_type]
+  global spec_types
+  spec_types = (
+      {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
+  )
 
 def is_vmappable(x: Any) -> bool:
   return type(x) is Jumble or type(x) in vmappables
diff --git a/tests/batching_test.py b/tests/batching_test.py
index bab18ce53..f2a4e8c34 100644
--- a/tests/batching_test.py
+++ b/tests/batching_test.py
@@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
       self.assertEqual(ans.names, expected.names)
       self.assertAllClose(ans.data, expected.data)
 
+  def test_types_with_same_spec(self):
+    # We register NamedArray.
+    batching.register_vmappable(NamedArray, NamedMapSpec, int,
+                              named_to_elt, named_from_elt, None)
+
+    # We then register another type that uses NamedMapSpec as the spec_type too,
+    # and immediately unregister it.
+    class Foo:
+      pass
+    batching.register_vmappable(Foo, NamedMapSpec, int,
+                                named_to_elt, named_from_elt, None)
+    batching.unregister_vmappable(Foo)
+
+    # We should still be able to use vmap on NamedArray.
+    def f(x):
+      return named_mul(x, x)
+
+    x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
+    ans = jax.jit(f)(x)
+    expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
+
+    self.assertEqual(ans.names, expected.names)
+    self.assertAllClose(ans.data, expected.data)
+
+    # And unregister NamedArray without exceptions.
+    batching.unregister_vmappable(NamedArray)
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())