Custom PRNG: make KeyArray compatible with custom_jvp

This commit is contained in:
Jake VanderPlas 2023-05-16 17:05:21 -07:00
parent 613559d4e5
commit 6ef4e5f01a
2 changed files with 28 additions and 0 deletions

View File

@ -25,6 +25,7 @@ from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_util
from jax._src import api
from jax._src import basearray
from jax._src import config as config_lib
@ -317,6 +318,8 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
basearray.Array.register(PRNGKeyArrayImpl)
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]
# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:

View File

@ -1910,6 +1910,31 @@ class KeyArrayTest(jtu.JaxTestCase):
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
self.assertArraysEqual(result, keys)
def test_key_array_custom_jvp(self):
def f_raw(x, key):
return x * jax.random.normal(key, ())
f = jax.custom_jvp(f_raw)
@f.defjvp
def f_jvp(primals, tangents):
nonlocal key_dot
x, key = primals
x_dot, key_dot = tangents
rand = jax.random.normal(key, ())
tangent_out = x_dot * rand
primal_out = x * rand
return primal_out, tangent_out
key_dot = None
key = self.make_keys()
default_result = jax.grad(f_raw)(0.0, key)
custom_result = jax.grad(f)(0.0, key)
self.assertAllClose(default_result, custom_result)
self.assertIsInstance(key_dot, jax.random.PRNGKeyArray)
self.assertArraysEqual(jax.random.key_data(key_dot), np.uint32(0))
# TODO(frostig,mattjj): more polymorphic primitives tests