Remove unnecessary Array.register

This commit is contained in:
Jake VanderPlas 2024-01-24 14:59:25 -08:00
parent cfb6250158
commit 78f27dfa9d

View File

@ -29,7 +29,6 @@ from jax import tree_util
from jax._src import api_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config
from jax._src import core
from jax._src import dispatch
@ -397,7 +396,6 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
*(f"__{op}__" for op in _array_operators),
'at', 'flatten', 'ravel', 'reshape',
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')