hide keys attribute of PRNGKeyArray in favor of unsafe_raw_array

This commit is contained in:
Roy Frostig 2021-10-11 21:21:37 -07:00
parent cfa0f78bed
commit ba370f8c86
4 changed files with 29 additions and 19 deletions

View File

@ -97,7 +97,7 @@ class PRNGKeyArray:
"""
impl: PRNGImpl
keys: jnp.ndarray
_keys: jnp.ndarray
def __init__(self, impl, key_data: jnp.ndarray):
# key_data might be a placeholder python `object` or `bool`
@ -107,10 +107,18 @@ class PRNGKeyArray:
raise TypeError(
f'Invalid PRNG key data {key_data} for PRNG implementation {impl}')
self.impl = impl
self.keys = key_data
self._keys = key_data
def tree_flatten(self):
return (self.keys,), self.impl
return (self._keys,), self.impl
def unsafe_raw_array(self):
"""Access the raw numerical array that carries underlying key data.
Returns:
A uint32 JAX array whose leading dimensions are ``self.shape``.
"""
return self._keys
@classmethod
def tree_unflatten(cls, impl, keys):
@ -137,26 +145,26 @@ class PRNGKeyArray:
'deprecated `shape` attribute of PRNG key arrays. In a future version '
'of JAX this attribute will be removed or its value may change.',
FutureWarning)
return self.keys.shape
return self._keys.shape
@property
def _shape(self):
base_ndim = len(self.impl.key_shape)
return self.keys.shape[:-base_ndim]
return self._keys.shape[:-base_ndim]
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
return self.keys.ndim == base_ndim
return self._keys.ndim == base_ndim
def __len__(self):
if self._is_scalar():
raise TypeError('len() of unsized object')
return len(self.keys)
return len(self._keys)
def __iter__(self) -> Iterator['PRNGKeyArray']:
if self._is_scalar():
raise TypeError('iteration over a 0-d single PRNG key')
return (PRNGKeyArray(self.impl, k) for k in iter(self.keys))
return (PRNGKeyArray(self.impl, k) for k in iter(self._keys))
def __getitem__(self, idx) -> 'PRNGKeyArray':
if not isinstance(idx, tuple):
@ -166,21 +174,21 @@ class PRNGKeyArray:
'PRNGKeyArray only supports indexing with integer indices. '
f'Cannot index at {idx}')
base_ndim = len(self.impl.key_shape)
ndim = self.keys.ndim - base_ndim
ndim = self._keys.ndim - base_ndim
if len(idx) > ndim:
raise IndexError(
f'too many indices for PRNGKeyArray: array is {ndim}-dimensional '
f'but {len(idx)} were indexed')
return PRNGKeyArray(self.impl, self.keys[idx])
return PRNGKeyArray(self.impl, self._keys[idx])
def _fold_in(self, data: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.fold_in(self.keys, data))
return PRNGKeyArray(self.impl, self.impl.fold_in(self._keys, data))
def _random_bits(self, bit_width, shape) -> jnp.ndarray:
return self.impl.random_bits(self.keys, bit_width, shape)
return self.impl.random_bits(self._keys, bit_width, shape)
def _split(self, num: int) -> 'PRNGKeyArray':
return PRNGKeyArray(self.impl, self.impl.split(self.keys, num))
return PRNGKeyArray(self.impl, self.impl.split(self._keys, num))
def __repr__(self):
arr_shape = self._shape

View File

@ -81,7 +81,7 @@ def _return_prng_keys(was_wrapped, key):
if config.jax_enable_custom_prng:
return key
else:
return key.keys if was_wrapped else key
return key.unsafe_raw_array() if was_wrapped else key
def _random_bits(key: prng.PRNGKeyArray, bit_width, shape) -> jnp.ndarray:
key, _ = _check_prng_key(key)
@ -987,7 +987,7 @@ def gamma(key: KeyArray,
if key.impl is not prng.threefry_prng_impl:
raise NotImplementedError(
f'`gamma` is only implemented for the threefry2x32 RNG, not {key.impl}')
return gamma_threefry2x32(key.keys, a, shape, dtype)
return gamma_threefry2x32(key.unsafe_raw_array(), a, shape, dtype)
def gamma_threefry2x32(key: jnp.ndarray, # raw ndarray form of a 2x32 key
a: RealArray,

View File

@ -523,7 +523,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# TODO(frostig): remove once we always enable_custom_prng
def _prng_key_as_array(key):
return key.keys if config.jax_enable_custom_prng else key
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
# TODO(frostig): remove once we always enable_custom_prng
def _array_as_prng_key(arr):

View File

@ -48,7 +48,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
def _prng_key_as_array(key):
# TODO(frostig): remove once we upgrade to always enable_custom_prng
return key.keys if config.jax_enable_custom_prng else key
return key.unsafe_raw_array() if config.jax_enable_custom_prng else key
@jtu.with_config(jax_numpy_rank_promotion="raise")
@ -1173,7 +1173,8 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
vmapped_keys = vmap(lambda _: random.split(key))(jnp.zeros(3,))
self.assertEqual(vmapped_keys.shape, (3, 2))
for vk in vmapped_keys:
self.assertArraysEqual(vk.keys, single_split_key.keys)
self.assertArraysEqual(vk.unsafe_raw_array(),
single_split_key.unsafe_raw_array())
def test_vmap_split_mapped_key(self):
key = self.seed_prng(73)
@ -1182,7 +1183,8 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(fk.keys, vk.keys)
self.assertArraysEqual(fk.unsafe_raw_array(),
vk.unsafe_raw_array())
def test_vmap_random_bits(self):
rand_fun = lambda key: random.randint(key, (), 0, 100)