From 7d7a536b55ce55d91cccc53bd243363596e35c6b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 21 Jul 2023 09:48:38 -0700 Subject: [PATCH] custom prng: introduce mechanism to identify key arrays by dtype --- docs/jax.dtypes.rst | 1 + jax/_src/checkify.py | 4 ++-- jax/_src/dtypes.py | 19 ++++++++++++++++++- jax/_src/prng.py | 8 ++++---- jax/_src/random.py | 12 ++++++------ jax/dtypes.py | 1 + tests/random_test.py | 13 +++++++++---- 7 files changed, 41 insertions(+), 17 deletions(-) diff --git a/docs/jax.dtypes.rst b/docs/jax.dtypes.rst index f59e95d4e..151d80f90 100644 --- a/docs/jax.dtypes.rst +++ b/docs/jax.dtypes.rst @@ -10,5 +10,6 @@ canonicalize_dtype float0 issubdtype + prng_key result_type scalar_type_of diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index fc940e996..dfdbfd7e1 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -22,6 +22,7 @@ from typing import Union, Callable, TypeVar, Any import numpy as np import jax.numpy as jnp +from jax import dtypes from jax import lax from jax._src import api @@ -30,7 +31,6 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import effects from jax._src import pjit -from jax._src import prng from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util @@ -576,7 +576,7 @@ def check_nans(prim, error, enabled_errors, out): return error def isnan(x): - if isinstance(x, prng.PRNGKeyArray): + if jnp.issubdtype(x.dtype, dtypes.prng_key): return False return jnp.any(jnp.isnan(x)) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 8cb07b91a..a18d9cc94 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -46,6 +46,7 @@ else: FLAGS = flags.FLAGS +# TODO(jakevdp): rename opaque dtypes to something more user-friendly class opaque(np.generic): """Scalar class for opaque dtypes. @@ -62,6 +63,23 @@ class opaque(np.generic): pass +class prng_key(opaque): + """Scalar class for PRNG Key dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of `jnp.issubdtype`. + + Examples: + >>> from jax import random + >>> from jax import dtypes + >>> key = random.key(0) + >>> jnp.issubdtype(key.dtype, dtypes.prng_key) + True + """ + pass + + +# TODO(jakevdp): rename opaque dtypes to something more user-friendly class OpaqueDType(metaclass=abc.ABCMeta): """Abstract Base Class for opaque dtypes""" @property @@ -73,7 +91,6 @@ def is_opaque_dtype(dtype: Any) -> bool: # TODO(vanderplas, frostig): remove in favor of inlining `issubdtype` return issubdtype(dtype, opaque) - # fp8 support float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 5b14afa46..8a50214bc 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -447,7 +447,7 @@ class KeyTyRules: @staticmethod def full(shape, fill_value, dtype): physical_shape = (*shape, *dtype.impl.key_shape) - if isinstance(fill_value, PRNGKeyArray): + if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key): key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape) else: key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32')) @@ -580,7 +580,7 @@ class KeyTyRules: class KeyTy(dtypes.OpaqueDType): impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really _rules = KeyTyRules - type = dtypes.opaque + type = dtypes.prng_key def __init__(self, impl): self.impl = impl @@ -888,8 +888,8 @@ batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule def random_unwrap(keys): - if not isinstance(keys, PRNGKeyArrayImpl): - raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}') + if not jnp.issubdtype(keys.dtype, dtypes.prng_key): + raise TypeError(f'random_unwrap takes key array operand, got {keys.dtype=}') return random_unwrap_p.bind(keys) random_unwrap_p = core.Primitive('random_unwrap') diff --git a/jax/_src/random.py b/jax/_src/random.py index dda136454..1cd19bf18 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -84,7 +84,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]: def _return_prng_keys(was_wrapped, key): # TODO(frostig): remove once we always enable_custom_prng - assert isinstance(key, prng.PRNGKeyArray) + assert jnp.issubdtype(key.dtype, dtypes.prng_key) if config.jax_enable_custom_prng: return key else: @@ -92,7 +92,7 @@ def _return_prng_keys(was_wrapped, key): def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array: - assert isinstance(key, prng.PRNGKeyArray) + assert jnp.issubdtype(key.dtype, dtypes.prng_key) return prng.random_bits(key, bit_width=bit_width, shape=shape) @@ -129,7 +129,7 @@ def resolve_prng_impl(impl_spec: Optional[str]): def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str] ) -> PRNGKeyArray: impl = resolve_prng_impl(impl_spec) - if isinstance(seed, prng.PRNGKeyArray): + if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key): raise TypeError( f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.") if np.ndim(seed): @@ -209,7 +209,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray: def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray: # Alternative to fold_in() to use within random samplers. # TODO(frostig): remove and use fold_in() once we always enable_custom_prng - assert isinstance(key, prng.PRNGKeyArray) + assert jnp.issubdtype(key.dtype, dtypes.prng_key) if key.ndim: raise TypeError("fold_in accepts a single key, but was given a key array of" f"shape {key.shape} != (). Use jax.vmap for batching.") @@ -236,7 +236,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: # Alternative to split() to use within random samplers. # TODO(frostig): remove and use split(); we no longer need to wait # to always enable_custom_prng - assert isinstance(key, prng.PRNGKeyArray) + assert jnp.issubdtype(key.dtype, dtypes.prng_key) if key.ndim: raise TypeError("split accepts a single key, but was given a key array of" f"shape {key.shape} != (). Use jax.vmap for batching.") @@ -258,7 +258,7 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray: return _return_prng_keys(wrapped, _split(key, num)) def _key_data(keys: KeyArray) -> Array: - assert isinstance(keys, prng.PRNGKeyArray) + assert jnp.issubdtype(keys.dtype, dtypes.prng_key) return prng.random_unwrap(keys) def key_data(keys: KeyArray) -> Array: diff --git a/jax/dtypes.py b/jax/dtypes.py index aca0f19d3..fbc3f158b 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -22,6 +22,7 @@ from jax._src.dtypes import ( float0 as float0, iinfo, # TODO(phawkins): switch callers to jnp.iinfo? issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype? + prng_key as prng_key, result_type as result_type, scalar_type_of as scalar_type_of, ) diff --git a/tests/random_test.py b/tests/random_test.py index 98e17bee4..bb74e0b9d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -55,7 +55,7 @@ uint_dtypes = jtu.dtypes.all_unsigned def _prng_key_as_array(key): # TODO(frostig): remove some day when we deprecate "raw" key arrays - if isinstance(key, jax.random.PRNGKeyArray): + if jnp.issubdtype(key.dtype, dtypes.prng_key): return key.unsafe_raw_array() else: return key @@ -63,7 +63,7 @@ def _prng_key_as_array(key): def _maybe_unwrap(key): # TODO(frostig): remove some day when we deprecate "raw" key arrays unwrap = prng_internal.random_unwrap - return unwrap(key) if isinstance(key, jax.random.PRNGKeyArray) else key + return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl), @@ -207,7 +207,7 @@ KEY_CTORS = [random.key, random.PRNGKey] class PrngTest(jtu.JaxTestCase): def check_key_has_impl(self, key, impl): - if isinstance(key, random.PRNGKeyArray): + if jnp.issubdtype(key.dtype, dtypes.prng_key): self.assertIs(key.impl, impl) else: self.assertEqual(key.dtype, jnp.dtype('uint32')) @@ -1681,6 +1681,11 @@ class KeyArrayTest(jtu.JaxTestCase): key = random.key(42) self.assertIsInstance(key, random.PRNGKeyArray) + def test_issubdtype(self): + key = random.key(42) + self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key)) + self.assertFalse(jnp.issubdtype(key.dtype, np.integer)) + @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag') def test_construction_upgrade_flag(self): key = random.PRNGKey(42) @@ -2207,7 +2212,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest): def test_cannot_add(self): key = self.make_key(73) - if not isinstance(key, random.PRNGKeyArray): + if not jnp.issubdtype(key.dtype, dtypes.prng_key): raise SkipTest('relies on typed key arrays') self.assertRaisesRegex( ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',