[random] add shaped_abstractify handler for custom PRNG key

This commit is contained in:
Jake VanderPlas 2023-10-10 16:15:19 -07:00
parent 899cc30419
commit e5c2a2c0a3

View File

@ -29,6 +29,7 @@ from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config
@ -413,6 +414,7 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)
api_util._shaped_abstractify_handlers[PRNGKeyArrayImpl] = op.attrgetter('aval')
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]
def prngkeyarrayimpl_flatten(x):