mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
custom prng: introduce mechanism to identify key arrays by dtype
This commit is contained in:
parent
1b33a4eb05
commit
7d7a536b55
@ -10,5 +10,6 @@
|
||||
canonicalize_dtype
|
||||
float0
|
||||
issubdtype
|
||||
prng_key
|
||||
result_type
|
||||
scalar_type_of
|
||||
|
@ -22,6 +22,7 @@ from typing import Union, Callable, TypeVar, Any
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
|
||||
from jax._src import api
|
||||
@ -30,7 +31,6 @@ from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import effects
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
@ -576,7 +576,7 @@ def check_nans(prim, error, enabled_errors, out):
|
||||
return error
|
||||
|
||||
def isnan(x):
|
||||
if isinstance(x, prng.PRNGKeyArray):
|
||||
if jnp.issubdtype(x.dtype, dtypes.prng_key):
|
||||
return False
|
||||
return jnp.any(jnp.isnan(x))
|
||||
|
||||
|
@ -46,6 +46,7 @@ else:
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
# TODO(jakevdp): rename opaque dtypes to something more user-friendly
|
||||
class opaque(np.generic):
|
||||
"""Scalar class for opaque dtypes.
|
||||
|
||||
@ -62,6 +63,23 @@ class opaque(np.generic):
|
||||
pass
|
||||
|
||||
|
||||
class prng_key(opaque):
|
||||
"""Scalar class for PRNG Key dtypes.
|
||||
|
||||
This is an abstract class that should never be instantiated, but rather
|
||||
exists for the sake of `jnp.issubdtype`.
|
||||
|
||||
Examples:
|
||||
>>> from jax import random
|
||||
>>> from jax import dtypes
|
||||
>>> key = random.key(0)
|
||||
>>> jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
True
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# TODO(jakevdp): rename opaque dtypes to something more user-friendly
|
||||
class OpaqueDType(metaclass=abc.ABCMeta):
|
||||
"""Abstract Base Class for opaque dtypes"""
|
||||
@property
|
||||
@ -73,7 +91,6 @@ def is_opaque_dtype(dtype: Any) -> bool:
|
||||
# TODO(vanderplas, frostig): remove in favor of inlining `issubdtype`
|
||||
return issubdtype(dtype, opaque)
|
||||
|
||||
|
||||
# fp8 support
|
||||
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
|
||||
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
|
||||
|
@ -447,7 +447,7 @@ class KeyTyRules:
|
||||
@staticmethod
|
||||
def full(shape, fill_value, dtype):
|
||||
physical_shape = (*shape, *dtype.impl.key_shape)
|
||||
if isinstance(fill_value, PRNGKeyArray):
|
||||
if hasattr(fill_value, 'dtype') and jnp.issubdtype(fill_value.dtype, dtypes.prng_key):
|
||||
key_data = jnp.broadcast_to(random_unwrap(fill_value), physical_shape)
|
||||
else:
|
||||
key_data = lax.full(physical_shape, fill_value, dtype=np.dtype('uint32'))
|
||||
@ -580,7 +580,7 @@ class KeyTyRules:
|
||||
class KeyTy(dtypes.OpaqueDType):
|
||||
impl: Hashable # prng.PRNGImpl. TODO(mattjj,frostig): protocol really
|
||||
_rules = KeyTyRules
|
||||
type = dtypes.opaque
|
||||
type = dtypes.prng_key
|
||||
|
||||
def __init__(self, impl):
|
||||
self.impl = impl
|
||||
@ -888,8 +888,8 @@ batching.primitive_batchers[random_wrap_p] = random_wrap_batch_rule
|
||||
|
||||
|
||||
def random_unwrap(keys):
|
||||
if not isinstance(keys, PRNGKeyArrayImpl):
|
||||
raise TypeError(f'random_unwrap takes key array operand, got {type(keys)}')
|
||||
if not jnp.issubdtype(keys.dtype, dtypes.prng_key):
|
||||
raise TypeError(f'random_unwrap takes key array operand, got {keys.dtype=}')
|
||||
return random_unwrap_p.bind(keys)
|
||||
|
||||
random_unwrap_p = core.Primitive('random_unwrap')
|
||||
|
@ -84,7 +84,7 @@ def _check_prng_key(key) -> tuple[prng.PRNGKeyArray, bool]:
|
||||
|
||||
def _return_prng_keys(was_wrapped, key):
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
if config.jax_enable_custom_prng:
|
||||
return key
|
||||
else:
|
||||
@ -92,7 +92,7 @@ def _return_prng_keys(was_wrapped, key):
|
||||
|
||||
|
||||
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> Array:
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
return prng.random_bits(key, bit_width=bit_width, shape=shape)
|
||||
|
||||
|
||||
@ -129,7 +129,7 @@ def resolve_prng_impl(impl_spec: Optional[str]):
|
||||
def _key(ctor_name: str, seed: Union[int, Array], impl_spec: Optional[str]
|
||||
) -> PRNGKeyArray:
|
||||
impl = resolve_prng_impl(impl_spec)
|
||||
if isinstance(seed, prng.PRNGKeyArray):
|
||||
if hasattr(seed, 'dtype') and jnp.issubdtype(seed.dtype, dtypes.prng_key):
|
||||
raise TypeError(
|
||||
f"{ctor_name} accepts a scalar seed, but was given a PRNGKeyArray.")
|
||||
if np.ndim(seed):
|
||||
@ -209,7 +209,7 @@ def unsafe_rbg_key(seed: int) -> KeyArray:
|
||||
def _fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
|
||||
# Alternative to fold_in() to use within random samplers.
|
||||
# TODO(frostig): remove and use fold_in() once we always enable_custom_prng
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
if key.ndim:
|
||||
raise TypeError("fold_in accepts a single key, but was given a key array of"
|
||||
f"shape {key.shape} != (). Use jax.vmap for batching.")
|
||||
@ -236,7 +236,7 @@ def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
|
||||
# Alternative to split() to use within random samplers.
|
||||
# TODO(frostig): remove and use split(); we no longer need to wait
|
||||
# to always enable_custom_prng
|
||||
assert isinstance(key, prng.PRNGKeyArray)
|
||||
assert jnp.issubdtype(key.dtype, dtypes.prng_key)
|
||||
if key.ndim:
|
||||
raise TypeError("split accepts a single key, but was given a key array of"
|
||||
f"shape {key.shape} != (). Use jax.vmap for batching.")
|
||||
@ -258,7 +258,7 @@ def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
|
||||
return _return_prng_keys(wrapped, _split(key, num))
|
||||
|
||||
def _key_data(keys: KeyArray) -> Array:
|
||||
assert isinstance(keys, prng.PRNGKeyArray)
|
||||
assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
|
||||
return prng.random_unwrap(keys)
|
||||
|
||||
def key_data(keys: KeyArray) -> Array:
|
||||
|
@ -22,6 +22,7 @@ from jax._src.dtypes import (
|
||||
float0 as float0,
|
||||
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
|
||||
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
|
||||
prng_key as prng_key,
|
||||
result_type as result_type,
|
||||
scalar_type_of as scalar_type_of,
|
||||
)
|
||||
|
@ -55,7 +55,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
|
||||
|
||||
def _prng_key_as_array(key):
|
||||
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
||||
if isinstance(key, jax.random.PRNGKeyArray):
|
||||
if jnp.issubdtype(key.dtype, dtypes.prng_key):
|
||||
return key.unsafe_raw_array()
|
||||
else:
|
||||
return key
|
||||
@ -63,7 +63,7 @@ def _prng_key_as_array(key):
|
||||
def _maybe_unwrap(key):
|
||||
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
||||
unwrap = prng_internal.random_unwrap
|
||||
return unwrap(key) if isinstance(key, jax.random.PRNGKeyArray) else key
|
||||
return unwrap(key) if jnp.issubdtype(key, dtypes.prng_key) else key
|
||||
|
||||
|
||||
PRNG_IMPLS = [('threefry2x32', prng.threefry_prng_impl),
|
||||
@ -207,7 +207,7 @@ KEY_CTORS = [random.key, random.PRNGKey]
|
||||
class PrngTest(jtu.JaxTestCase):
|
||||
|
||||
def check_key_has_impl(self, key, impl):
|
||||
if isinstance(key, random.PRNGKeyArray):
|
||||
if jnp.issubdtype(key.dtype, dtypes.prng_key):
|
||||
self.assertIs(key.impl, impl)
|
||||
else:
|
||||
self.assertEqual(key.dtype, jnp.dtype('uint32'))
|
||||
@ -1681,6 +1681,11 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
key = random.key(42)
|
||||
self.assertIsInstance(key, random.PRNGKeyArray)
|
||||
|
||||
def test_issubdtype(self):
|
||||
key = random.key(42)
|
||||
self.assertTrue(jnp.issubdtype(key.dtype, dtypes.prng_key))
|
||||
self.assertFalse(jnp.issubdtype(key.dtype, np.integer))
|
||||
|
||||
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
|
||||
def test_construction_upgrade_flag(self):
|
||||
key = random.PRNGKey(42)
|
||||
@ -2207,7 +2212,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.make_key(73)
|
||||
if not isinstance(key, random.PRNGKeyArray):
|
||||
if not jnp.issubdtype(key.dtype, dtypes.prng_key):
|
||||
raise SkipTest('relies on typed key arrays')
|
||||
self.assertRaisesRegex(
|
||||
ValueError, r'dtype=key<.*> is not a valid dtype for JAX type promotion.',
|
||||
|
Loading…
x
Reference in New Issue
Block a user