mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[random] add shaped_abstractify handler for custom PRNG key
This commit is contained in:
parent
899cc30419
commit
e5c2a2c0a3
@ -29,6 +29,7 @@ from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import api
|
||||
from jax._src import basearray
|
||||
from jax._src import config as config
|
||||
@ -413,6 +414,7 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
|
||||
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]
|
||||
|
||||
def prngkeyarrayimpl_flatten(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user