mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19509 from jakevdp:array-register
PiperOrigin-RevId: 601264991
This commit is contained in:
commit
c7425ef967
@ -29,7 +29,6 @@ from jax import tree_util
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import api
|
||||
from jax._src import basearray
|
||||
from jax._src import config as config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -397,7 +396,6 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
*(f"__{op}__" for op in _array_operators),
|
||||
'at', 'flatten', 'ravel', 'reshape',
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user