custom prng: introduce mechanism to identify key arrays by dtype

This commit is contained in:
Jake VanderPlas 2023-07-21 09:48:38 -07:00
parent 1b33a4eb05
commit 7d7a536b55
7 changed files with 41 additions and 17 deletions

View File

@ -10,5 +10,6 @@
canonicalize_dtype
float0
issubdtype
prng_key
result_type
scalar_type_of

View File

@ -22,6 +22,7 @@ from typing import Union, Callable, TypeVar, Any
import numpy as np
import jax.numpy as jnp
from jax import dtypes
from jax import lax
from jax._src import api
@ -30,7 +31,6 @@ from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
@ -576,7 +576,7 @@ def check_nans(prim, error, enabled_errors, out):
return error
def isnan(x):
if isinstance(x, prng.PRNGKeyArray):
if jnp.issubdtype(x.dtype, dtypes.prng_key):
return False
return jnp.any(jnp.isnan(x))

View File

@ -46,6 +46,7 @@ else:
FLAGS = flags.FLAGS
# TODO(jakevdp): rename opaque dtypes to something more user-friendly
class opaque(np.generic):
"""Scalar class for opaque dtypes.
@ -62,6 +63,23 @@ class opaque(np.generic):
pass
class prng_key(opaque):
"""Scalar class for PRNG Key dtypes.
This is an abstract class that should never be instantiated, but rather
exists for the sake of `jnp.issubdtype`.
Examples:
>>> from jax import random
>>> from jax import dtypes
>>> key = random.key(0)
>>> jnp.issubdtype(key.dtype, dtypes.prng_key)
True
"""
pass
# TODO(jakevdp): rename opaque dtypes to something more user-friendly
class OpaqueDType(metaclass=abc.ABCMeta):
"""Abstract Base Class for opaque dtypes"""
@property
@ -73,7 +91,6 @@ def is_opaque_dtype(dtype: Any) -> bool:
# TODO(vanderplas, frostig): remove in favor of inlining `issubdtype`
return issubdtype(dtype, opaque)
# fp8 support
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn

View File

@ -447,7 +447,7 @@ class KeyTyRules:
@staticmethod
def full(shape, fill_value, dtype):
physical_shape = (*shape, *dtype.impl.key_shape)
if isinstance(fill_value, PRNGKeyArray):
if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key):
key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape)
else:
key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32'))
@ -580,7 +580,7 @@ class KeyTyRules:
class KeyTy(dtypes.OpaqueDType):
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
_rules = KeyTyRules
type = dtypes.opaque
type = dtypes.prng_key
def __init__(self, impl):
self.impl = impl
@ -888,8 +888,8 @@ batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
def random_unwrap(keys):
if not isinstance(keys, PRNGKeyArrayImpl):
raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
if not jnp.issubdtype(keys.dtype, dtypes.prng_key):
raise TypeError(f'random_unwrap takes key array operand, got {keys.dtype=}')
return random_unwrap_p.bind(keys)
random_unwrap_p = core.Primitive('random_unwrap')

View File

@ -84,7 +84,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
def _return_prng_keys(was_wrapped, key):
# TODO(frostig): remove once we always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if config.jax_enable_custom_prng:
return key
else:
@ -92,7 +92,7 @@ def _return_prng_keys(was_wrapped, key):
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array:
assert isinstance(key, prng.PRNGKeyArray)
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
return prng.random_bits(key, bit_width=bit_width, shape=shape)
@ -129,7 +129,7 @@ def resolve_prng_impl(impl_spec: Optional[str]):
def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
) -> PRNGKeyArray:
impl = resolve_prng_impl(impl_spec)
if isinstance(seed, prng.PRNGKeyArray):
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
raise TypeError(
f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
@ -209,7 +209,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
# Alternative to fold_in() to use within random samplers.
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if key.ndim:
raise TypeError("fold_in accepts a single key, but was given a key array of"
f"shape {key.shape} != (). Use jax.vmap for batching.")
@ -236,7 +236,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng
assert isinstance(key, prng.PRNGKeyArray)
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
if key.ndim:
raise TypeError("split accepts a single key, but was given a key array of"
f"shape {key.shape} != (). Use jax.vmap for batching.")
@ -258,7 +258,7 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
return _return_prng_keys(wrapped, _split(key, num))
def _key_data(keys: KeyArray) -> Array:
assert isinstance(keys, prng.PRNGKeyArray)
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
return prng.random_unwrap(keys)
def key_data(keys: KeyArray) -> Array:

View File

@ -22,6 +22,7 @@ from jax._src.dtypes import (
float0 as float0,
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
prng_key as prng_key,
result_type as result_type,
scalar_type_of as scalar_type_of,
)

View File

@ -55,7 +55,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
def _prng_key_as_array(key):
# TODO(frostig): remove some day when we deprecate "raw" key arrays
if isinstance(key, jax.random.PRNGKeyArray):
if jnp.issubdtype(key.dtype, dtypes.prng_key):
return key.unsafe_raw_array()
else:
return key
@ -63,7 +63,7 @@ def _prng_key_as_array(key):
def _maybe_unwrap(key):
# TODO(frostig): remove some day when we deprecate "raw" key arrays
unwrap = prng_internal.random_unwrap
return unwrap(key) if isinstance(key, jax.random.PRNGKeyArray) else key
return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
@ -207,7 +207,7 @@ KEY_CTORS = [random.key, random.PRNGKey]
class PrngTest(jtu.JaxTestCase):
def check_key_has_impl(self, key, impl):
if isinstance(key, random.PRNGKeyArray):
if jnp.issubdtype(key.dtype, dtypes.prng_key):
self.assertIs(key.impl, impl)
else:
self.assertEqual(key.dtype, jnp.dtype('uint32'))
@ -1681,6 +1681,11 @@ class KeyArrayTest(jtu.JaxTestCase):
key = random.key(42)
self.assertIsInstance(key, random.PRNGKeyArray)
def test_issubdtype(self):
key = random.key(42)
self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key))
self.assertFalse(jnp.issubdtype(key.dtype, np.integer))
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_construction_upgrade_flag(self):
key = random.PRNGKey(42)
@ -2207,7 +2212,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def test_cannot_add(self):
key = self.make_key(73)
if not isinstance(key, random.PRNGKeyArray):
if not jnp.issubdtype(key.dtype, dtypes.prng_key):
raise SkipTest('relies on typed key arrays')
self.assertRaisesRegex(
ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',