mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Custom PRNG: make KeyArray compatible with custom_jvp
This commit is contained in:
parent
613559d4e5
commit
6ef4e5f01a
@ -25,6 +25,7 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import basearray
|
||||
from jax._src import config as config_lib
|
||||
@ -317,6 +318,8 @@ _set_array_base_attributes(PRNGKeyArrayImpl, include=[
|
||||
'squeeze', 'swapaxes', 'take', 'transpose', 'T'])
|
||||
basearray.Array.register(PRNGKeyArrayImpl)
|
||||
|
||||
ad_util.jaxval_zeros_likers[PRNGKeyArrayImpl] = jnp.zeros_like # type: ignore[has-type]
|
||||
|
||||
|
||||
# TODO(frostig): remove, rerouting callers directly to random_seed
|
||||
def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
|
||||
|
@ -1910,6 +1910,31 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
|
||||
self.assertArraysEqual(result, keys)
|
||||
|
||||
def test_key_array_custom_jvp(self):
|
||||
def f_raw(x, key):
|
||||
return x * jax.random.normal(key, ())
|
||||
|
||||
f = jax.custom_jvp(f_raw)
|
||||
|
||||
@f.defjvp
|
||||
def f_jvp(primals, tangents):
|
||||
nonlocal key_dot
|
||||
x, key = primals
|
||||
x_dot, key_dot = tangents
|
||||
rand = jax.random.normal(key, ())
|
||||
tangent_out = x_dot * rand
|
||||
primal_out = x * rand
|
||||
return primal_out, tangent_out
|
||||
|
||||
key_dot = None
|
||||
key = self.make_keys()
|
||||
default_result = jax.grad(f_raw)(0.0, key)
|
||||
custom_result = jax.grad(f)(0.0, key)
|
||||
|
||||
self.assertAllClose(default_result, custom_result)
|
||||
self.assertIsInstance(key_dot, jax.random.PRNGKeyArray)
|
||||
self.assertArraysEqual(jax.random.key_data(key_dot), np.uint32(0))
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user