Merge pull request #19509 from jakevdp:array-register

PiperOrigin-RevId: 601264991
This commit is contained in:
jax authors 2024-01-24 16:14:49 -08:00
commit c7425ef967

View File

@ -29,7 +29,6 @@ from jax import tree_util
from jax._src import api_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config
from jax._src import core
from jax._src import dispatch
@ -397,7 +396,6 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
*(f"__{op}__" for op in _array_operators),
'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')