mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Change the way that batching.spec_types is updated.
There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions. Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove). When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types. PiperOrigin-RevId: 737280270
This commit is contained in:
parent
f360e19194
commit
466ef6a132
@ -322,12 +322,15 @@ vmappables: dict[type, tuple[type, type]] = {}
|
||||
spec_types: set[type] = {JumbleAxis}
|
||||
|
||||
def unregister_vmappable(data_type: type) -> None:
|
||||
spec_type, axis_size_type = vmappables.pop(data_type)
|
||||
spec_types.remove(spec_type)
|
||||
_, axis_size_type = vmappables.pop(data_type)
|
||||
del to_elt_handlers[data_type]
|
||||
del from_elt_handlers[data_type]
|
||||
if axis_size_type in make_iota_handlers:
|
||||
del make_iota_handlers[axis_size_type]
|
||||
global spec_types
|
||||
spec_types = (
|
||||
{JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
|
||||
)
|
||||
|
||||
def is_vmappable(x: Any) -> bool:
|
||||
return type(x) is Jumble or type(x) in vmappables
|
||||
|
@ -1356,6 +1356,32 @@ class VmappableTest(jtu.JaxTestCase):
|
||||
self.assertEqual(ans.names, expected.names)
|
||||
self.assertAllClose(ans.data, expected.data)
|
||||
|
||||
def test_types_with_same_spec(self):
|
||||
# We register NamedArray.
|
||||
batching.register_vmappable(NamedArray, NamedMapSpec, int,
|
||||
named_to_elt, named_from_elt, None)
|
||||
|
||||
# We then register another type that uses NamedMapSpec as the spec_type too,
|
||||
# and immediately unregister it.
|
||||
class Foo:
|
||||
pass
|
||||
batching.register_vmappable(Foo, NamedMapSpec, int,
|
||||
named_to_elt, named_from_elt, None)
|
||||
batching.unregister_vmappable(Foo)
|
||||
|
||||
# We should still be able to use vmap on NamedArray.
|
||||
def f(x):
|
||||
return named_mul(x, x)
|
||||
|
||||
x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4))
|
||||
ans = jax.jit(f)(x)
|
||||
expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2)
|
||||
|
||||
self.assertEqual(ans.names, expected.names)
|
||||
self.assertAllClose(ans.data, expected.data)
|
||||
|
||||
# And unregister NamedArray without exceptions.
|
||||
batching.unregister_vmappable(NamedArray)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user