[random] support key dtype in custom_jvp

To do this, we introduce a dtype for key tangents which cannot be used
to generate random values
This commit is contained in:
Jake VanderPlas 2023-11-10 11:16:23 -08:00
parent 6cc6d09364
commit c0f3fa00f8
3 changed files with 68 additions and 7 deletions

View File

@ -586,19 +586,19 @@ deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
def instantiate_zeros(tangent):
if type(tangent) is Zero:
return zeros_like_aval(tangent.aval)
else:
if type(tangent) is not Zero:
return tangent
return instantiate_zeros_aval(tangent.aval, tangent)
# This function seems similar to instantiate_zeros, but it is sometimes used
# to instantiate zero abstract units with a different aval
def instantiate_zeros_aval(aval, tangent):
if type(tangent) is Zero:
assert tangent.aval == aval
return zeros_like_aval(aval)
else:
if type(tangent) is not Zero:
return tangent
assert tangent.aval == aval
if jax.dtypes.issubdtype(aval.dtype, jax.dtypes.extended):
return aval.dtype._rules.make_tangent(aval.shape, aval.dtype)
return zeros_like_aval(aval)
@lu.transformation_with_aux
def traceable(in_tree, *primals_and_tangents):

View File

@ -472,6 +472,23 @@ class KeyTyRules:
# the outset.
return random_wrap(key_data, impl=dtype._impl)
@staticmethod
def make_tangent(shape, dtype):
physical_shape = (*shape, *dtype._impl.key_shape)
def not_implemented(name):
def func(*args):
raise NotImplementedError(f"Cannot call {name} on tangent of PRNG key.")
return func
impl = PRNGImpl(
key_shape=dtype._impl.key_shape,
seed=not_implemented('seed'),
split=not_implemented('split'),
random_bits=not_implemented('random_bits'),
fold_in=not_implemented('fold_in'),
name=f"{dtype._impl.name}_tangent",
tag=f"{dtype._impl.tag}_t")
return random_wrap(jnp.zeros(physical_shape, dtype='uint32'), impl=impl)
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray(dtype._impl.key_shape, jnp.dtype('uint32'))
@ -594,6 +611,20 @@ class KeyTyRules:
return random_wrap(physical_result, impl=aval.dtype._impl)
class KeyTangentTy(dtypes.ExtendedDType):
"""A dtype to use for the tangent of a PRNGKey"""
_impl: PRNGImpl
type = dtypes.prng_key
@property
def _rules(self):
raise ValueError("Cannot perform operations on the tangent of a PRNGKey.")
@property
def name(self) -> str:
return f'key_tangent<{self._impl.tag}>'
class KeyTy(dtypes.ExtendedDType):
_impl: PRNGImpl # TODO(mattjj,frostig): protocol really
_rules = KeyTyRules

View File

@ -1107,6 +1107,36 @@ class KeyArrayTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'):
jax.random.key(42, impl=A())
def test_keyarray_custom_vjp(self):
# Regression test for https://github.com/google/jax/issues/18442
@jax.custom_vjp
def f(_, state):
return state
def _f_fwd(_, state):
return state, None
def _f_bwd(_, state_bar):
assert state_bar[1].dtype.name == "key<fry_t>" # key tangent type
return state_bar
f.defvjp(_f_fwd, _f_bwd)
state = (8.0, jax.random.key(123))
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
self.assertEqual(result, 1.0)
def test_keyarray_custom_vjp_symbolic_zeros(self):
@jax.custom_vjp
def f(_, state):
return state
def _f_fwd(_, state):
return tree_util.tree_map(lambda x: x.value, state), None
def _f_bwd(_, state_bar):
self.assertTrue(dtypes.issubdtype(state_bar[1].dtype, dtypes.prng_key))
self.assertIsInstance(state_bar[1], jax.custom_derivatives.SymbolicZero)
return state_bar
f.defvjp(_f_fwd, _f_bwd, symbolic_zeros=True)
state = (8.0, jax.random.key(123))
result = jax.grad(lambda theta: f(theta, state)[0])(3.0)
self.assertEqual(result, 1.0)
# TODO(frostig,mattjj): more polymorphic primitives tests