mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[random] support key dtype in custom_jvp
To do this, we introduce a dtype for key tangents that cannot be used to generate random values.
This commit is contained in:
parent
6cc6d09364
commit
c0f3fa00f8
@ -586,19 +586,19 @@ deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
|
||||
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
|
||||
|
||||
def instantiate_zeros(tangent):
  """Materialize a symbolic-zero tangent; pass any other tangent through.

  Args:
    tangent: either a ``Zero`` (a symbolic zero carrying an ``aval``) or an
      already-concrete tangent value.

  Returns:
    ``tangent`` itself when it is not a ``Zero``; otherwise the value that
    ``instantiate_zeros_aval`` builds for ``tangent.aval``.
  """
  # Exact type check (not isinstance): only the symbolic Zero is instantiated.
  if type(tangent) is not Zero:
    return tangent
  # Delegate instead of calling zeros_like_aval directly, so that the
  # extended-dtype handling in instantiate_zeros_aval applies here as well.
  return instantiate_zeros_aval(tangent.aval, tangent)
|
||||
|
||||
# This function seems similar to instantiate_zeros, but it is sometimes used
|
||||
# to instantiate zero abstract units with a different aval
|
||||
def instantiate_zeros_aval(aval, tangent):
  """Materialize a symbolic-zero tangent for ``aval``.

  Args:
    aval: abstract value describing the zero to build; it must equal
      ``tangent.aval`` when ``tangent`` is a ``Zero``.
    tangent: either a ``Zero`` or an already-concrete tangent value.

  Returns:
    ``tangent`` itself when it is not a ``Zero``. For a ``Zero`` whose dtype is
    an extended dtype, the result of that dtype's ``_rules.make_tangent``;
    otherwise ``zeros_like_aval(aval)``.
  """
  if type(tangent) is not Zero:
    return tangent
  assert tangent.aval == aval
  if jax.dtypes.issubdtype(aval.dtype, jax.dtypes.extended):
    # Extended dtypes (e.g. PRNG keys) supply their own tangent value through
    # their rules object rather than going through zeros_like_aval.
    return aval.dtype._rules.make_tangent(aval.shape, aval.dtype)
  return zeros_like_aval(aval)
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def traceable(in_tree, *primals_and_tangents):
|
||||
|
@ -472,6 +472,23 @@ class KeyTyRules:
|
||||
# the outset.
|
||||
return random_wrap(key_data, impl=dtype._impl)
|
||||
|
||||
@staticmethod
def make_tangent(shape, dtype):
  """Build an all-zeros key array to serve as the tangent of a PRNG key.

  The returned array is wrapped with a placeholder ``PRNGImpl`` that keeps the
  original ``key_shape`` but whose ``seed``/``split``/``random_bits``/
  ``fold_in`` entries raise ``NotImplementedError`` when called. Its ``name``
  and ``tag`` are derived from the original impl by appending ``_tangent`` and
  ``_t`` respectively.

  Args:
    shape: logical shape of the key array (without the trailing key dims).
    dtype: key dtype whose ``_impl`` describes the underlying PRNG.

  Returns:
    The result of ``random_wrap`` applied to a zero-filled uint32 array of
    shape ``(*shape, *dtype._impl.key_shape)``.
  """
  base_impl = dtype._impl

  def not_implemented(name):
    # Factory for the stubs that stand in for the real PRNG operations.
    def func(*args):
      raise NotImplementedError(f"Cannot call {name} on tangent of PRNG key.")
    return func

  stubs = {
      op: not_implemented(op)
      for op in ('seed', 'split', 'random_bits', 'fold_in')
  }
  tangent_impl = PRNGImpl(
      key_shape=base_impl.key_shape,
      name=f"{dtype._impl.name}_tangent",
      tag=f"{dtype._impl.tag}_t",
      **stubs)
  zero_data = jnp.zeros((*shape, *base_impl.key_shape), dtype='uint32')
  return random_wrap(zero_data, impl=tangent_impl)
|
||||
|
||||
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
  """Return the abstract value of one key's underlying uint32 data.

  Args:
    dtype: key dtype; only ``dtype._impl.key_shape`` is read.

  Returns:
    A ``core.ShapedArray`` of shape ``dtype._impl.key_shape`` and dtype uint32.
  """
  element_shape = dtype._impl.key_shape
  element_dtype = jnp.dtype('uint32')
  return core.ShapedArray(element_shape, element_dtype)
|
||||
@ -594,6 +611,20 @@ class KeyTyRules:
|
||||
return random_wrap(physical_result, impl=aval.dtype._impl)
|
||||
|
||||
|
||||
class KeyTangentTy(dtypes.ExtendedDType):
  """Extended dtype that marks the tangent of a PRNG key.

  Reading ``_rules`` raises ``ValueError``, so code that looks up the dtype's
  rules fails at that point. ``name`` renders as ``key_tangent<TAG>``, where
  ``TAG`` is the tag of the wrapped ``PRNGImpl``.
  """
  # PRNG implementation whose tag is embedded in this dtype's name.
  _impl: PRNGImpl
  # Scalar type under which this dtype is classified.
  type = dtypes.prng_key

  @property
  def name(self) -> str:
    """Human-readable dtype name built from the impl's tag."""
    return f'key_tangent<{self._impl.tag}>'

  @property
  def _rules(self):
    """Always raises: no rules are provided for key tangents."""
    raise ValueError("Cannot perform operations on the tangent of a PRNGKey.")
|
||||
|
||||
|
||||
class KeyTy(dtypes.ExtendedDType):
|
||||
_impl: PRNGImpl # TODO(mattjj,frostig): protocol really
|
||||
_rules = KeyTyRules
|
||||
|
@ -1107,6 +1107,36 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'):
|
||||
jax.random.key(42, impl=A())
|
||||
|
||||
def test_keyarray_custom_vjp(self):
  """A key array can pass through a ``custom_vjp`` function under ``grad``.

  The backward rule receives a cotangent for every element of ``state``; the
  one for the key is checked to carry the key-tangent dtype name.
  """
  # Regression test for https://github.com/google/jax/issues/18442
  @jax.custom_vjp
  def f(_, state):
    return state

  def _f_fwd(_, state):
    return state, None

  def _f_bwd(_, state_bar):
    # Key tangent type. Use the TestCase assertion rather than a bare
    # `assert`, so the check is not stripped under `python -O`.
    # NOTE(review): "fry_t" presumably is the default impl's tag plus the
    # "_t" suffix added by make_tangent - confirm against the default impl.
    self.assertEqual(state_bar[1].dtype.name, "key<fry_t>")
    return state_bar

  f.defvjp(_f_fwd, _f_bwd)
  state = (8.0, jax.random.key(123))
  result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
  self.assertEqual(result, 1.0)
|
||||
|
||||
def test_keyarray_custom_vjp_symbolic_zeros(self):
  """Same as the custom_vjp key test, but with ``symbolic_zeros=True``.

  The backward rule checks that the key's cotangent is a ``SymbolicZero``
  whose dtype is a PRNG-key dtype.
  """
  @jax.custom_vjp
  def f(_, state):
    return state

  def _f_fwd(_, state):
    # NOTE(review): with symbolic_zeros=True the forward inputs appear to be
    # wrappers exposing `.value`; each leaf is unwrapped here - confirm.
    unwrapped = tree_util.tree_map(lambda x: x.value, state)
    return unwrapped, None

  def _f_bwd(_, state_bar):
    key_bar = state_bar[1]
    self.assertTrue(dtypes.issubdtype(key_bar.dtype, dtypes.prng_key))
    self.assertIsInstance(key_bar, jax.custom_derivatives.SymbolicZero)
    return state_bar

  f.defvjp(_f_fwd, _f_bwd, symbolic_zeros=True)
  state = (8.0, jax.random.key(123))

  def first_output(theta):
    return f(theta, state)[0]

  grad_value = jax.grad(first_output)(3.0)
  self.assertEqual(grad_value, 1.0)
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user