mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
hide keys
attribute of PRNGKeyArray
in favor of unsafe_raw_array
This commit is contained in:
parent
cfa0f78bed
commit
ba370f8c86
@ -97,7 +97,7 @@ class PRNGKeyArray:
|
||||
"""
|
||||
|
||||
impl: PRNGImpl
|
||||
keys: jnp.ndarray
|
||||
_keys: jnp.ndarray
|
||||
|
||||
def __init__(self, impl, key_data: jnp.ndarray):
|
||||
# key_data might be a placeholder python `object` or `bool`
|
||||
@ -107,10 +107,18 @@ class PRNGKeyArray:
|
||||
raise TypeError(
|
||||
f'Invalid PRNG key data {key_data} for PRNG implementation {impl}')
|
||||
self.impl = impl
|
||||
self.keys = key_data
|
||||
self._keys = key_data
|
||||
|
||||
def tree_flatten(self):
|
||||
return (self.keys,), self.impl
|
||||
return (self._keys,), self.impl
|
||||
|
||||
def unsafe_raw_array(self):
|
||||
"""Access the raw numerical array that carries underlying key data.
|
||||
|
||||
Returns:
|
||||
A uint32 JAX array whose leading dimensions are ``self.shape``.
|
||||
"""
|
||||
return self._keys
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, impl, keys):
|
||||
@ -137,26 +145,26 @@ class PRNGKeyArray:
|
||||
'deprecated `shape` attribute of PRNG key arrays. In a future version '
|
||||
'of JAX this attribute will be removed or its value may change.',
|
||||
FutureWarning)
|
||||
return self.keys.shape
|
||||
return self._keys.shape
|
||||
|
||||
@property
|
||||
def _shape(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
return self.keys.shape[:-base_ndim]
|
||||
return self._keys.shape[:-base_ndim]
|
||||
|
||||
def _is_scalar(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
return self.keys.ndim == base_ndim
|
||||
return self._keys.ndim == base_ndim
|
||||
|
||||
def __len__(self):
|
||||
if self._is_scalar():
|
||||
raise TypeError('len() of unsized object')
|
||||
return len(self.keys)
|
||||
return len(self._keys)
|
||||
|
||||
def __iter__(self) -> Iterator['PRNGKeyArray']:
|
||||
if self._is_scalar():
|
||||
raise TypeError('iteration over a 0-d single PRNG key')
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self.keys))
|
||||
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))
|
||||
|
||||
def __getitem__(self, idx) -> 'PRNGKeyArray':
|
||||
if not isinstance(idx, tuple):
|
||||
@ -166,21 +174,21 @@ class PRNGKeyArray:
|
||||
'PRNGKeyArray only supports indexing with integer indices. '
|
||||
f'Cannot index at {idx}')
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
ndim = self.keys.ndim - base_ndim
|
||||
ndim = self._keys.ndim - base_ndim
|
||||
if len(idx) > ndim:
|
||||
raise IndexError(
|
||||
f'too many indices for PRNGKeyArray: array is {ndim}-dimensional '
|
||||
f'but {len(idx)} were indexed')
|
||||
return PRNGKeyArray(self.impl, self.keys[idx])
|
||||
return PRNGKeyArray(self.impl, self._keys[idx])
|
||||
|
||||
def _fold_in(self, data: int) -> 'PRNGKeyArray':
|
||||
return PRNGKeyArray(self.impl, self.impl.fold_in(self.keys, data))
|
||||
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
|
||||
|
||||
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
|
||||
return self.impl.random_bits(self.keys, bit_width, shape)
|
||||
return self.impl.random_bits(self._keys, bit_width, shape)
|
||||
|
||||
def _split(self, num: int) -> 'PRNGKeyArray':
|
||||
return PRNGKeyArray(self.impl, self.impl.split(self.keys, num))
|
||||
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
|
||||
|
||||
def __repr__(self):
|
||||
arr_shape = self._shape
|
||||
|
@ -81,7 +81,7 @@ def _return_prng_keys(was_wrapped, key):
|
||||
if config.jax_enable_custom_prng:
|
||||
return key
|
||||
else:
|
||||
return key.keys if was_wrapped else key
|
||||
return key.unsafe_raw_array() if was_wrapped else key
|
||||
|
||||
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
|
||||
key, _ = _check_prng_key(key)
|
||||
@ -987,7 +987,7 @@ def gamma(key: KeyArray,
|
||||
if key.impl is not prng.threefry_prng_impl:
|
||||
raise NotImplementedError(
|
||||
f'`gamma` is only implemented for the threefry2x32 RNG, not {key.impl}')
|
||||
return gamma_threefry2x32(key.keys, a, shape, dtype)
|
||||
return gamma_threefry2x32(key.unsafe_raw_array(), a, shape, dtype)
|
||||
|
||||
def gamma_threefry2x32(key: jnp.ndarray, # raw ndarray form of a 2x32 key
|
||||
a: RealArray,
|
||||
|
@ -523,7 +523,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _prng_key_as_array(key):
|
||||
return key.keys if config.jax_enable_custom_prng else key
|
||||
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
|
||||
|
||||
# TODO(frostig): remove once we always enable_custom_prng
|
||||
def _array_as_prng_key(arr):
|
||||
|
@ -48,7 +48,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
|
||||
|
||||
def _prng_key_as_array(key):
|
||||
# TODO(frostig): remove once we upgrade to always enable_custom_prng
|
||||
return key.keys if config.jax_enable_custom_prng else key
|
||||
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_rank_promotion="raise")
|
||||
@ -1173,7 +1173,8 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2))
|
||||
for vk in vmapped_keys:
|
||||
self.assertArraysEqual(vk.keys, single_split_key.keys)
|
||||
self.assertArraysEqual(vk.unsafe_raw_array(),
|
||||
single_split_key.unsafe_raw_array())
|
||||
|
||||
def test_vmap_split_mapped_key(self):
|
||||
key = self.seed_prng(73)
|
||||
@ -1182,7 +1183,8 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2))
|
||||
for fk, vk in zip(forloop_keys, vmapped_keys):
|
||||
self.assertArraysEqual(fk.keys, vk.keys)
|
||||
self.assertArraysEqual(fk.unsafe_raw_array(),
|
||||
vk.unsafe_raw_array())
|
||||
|
||||
def test_vmap_random_bits(self):
|
||||
rand_fun = lambda key: random.randint(key, (), 0, 100)
|
||||
|
Loading…
x
Reference in New Issue
Block a user