accept general shape option in jax.random.split

Several PRNG implementations (notably partitionable threefry) support
splitting to arbitrary shapes, rather than only to a 1-D vector of
keys. This change:

* Upgrades `jax.random.split` to accept a general shape as an
  argument.
* Updates the internal PRNG interface, and our various PRNG
  implementations, to accept and handle such a shape argument.

This change keeps the argument name `num`. Whether and how to rename
it to `shape` is left open as a follow-up decision.

Note that we could have supported arbitrary shapes by reduction to the
previous API (with a flat split count), using reshapes. We'd like to
avoid that, so as not to hide this structure from the underlying
implementation. For instance, partitionable threefry hashes a *shaped*
iota in order to split keys, and we don't want to flatten and reshape
around that for no reason.

Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
Roy Frostig 2023-07-06 13:23:48 -07:00
parent a29d4bcd33
commit df2891ff13
4 changed files with 65 additions and 44 deletions

View File

@ -64,6 +64,7 @@ zip, unsafe_zip = safe_zip, zip
Device = xc.Device
Shard = Any # TODO(jakevdp): fix circular imports and import Shard
Shape = tuple[int, ...]
UINT_DTYPES = {
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
@ -80,14 +81,14 @@ class PRNGImpl(NamedTuple):
seed :: int[] -> K
fold_in :: K -> int[] -> K
split[n] :: K -> K[n]
random_bits[shape, bit_width] :: K -> uint<bit_width>[shape]
split[shape] :: K -> K[*shape]
random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]
A PRNG implementation is adapted to an array-like object of keys
``K`` by the ``PRNGKeyArray`` class, which should be created via the
``seed_with_impl`` function.
"""
key_shape: core.Shape
key_shape: Shape
seed: Callable
split: Callable
random_bits: Callable
@ -717,31 +718,31 @@ def random_seed_lowering(ctx, seeds, *, impl):
mlir.register_lowering(random_seed_p, random_seed_lowering)
def random_split(keys, count):
return random_split_p.bind(keys, count=count)
def random_split(keys, shape: Shape):
return random_split_p.bind(keys, shape=shape)
random_split_p = core.Primitive('random_split')
ad.defjvp_zero(random_split_p)
batching.defvectorized(random_split_p)
@random_split_p.def_abstract_eval
def random_split_abstract_eval(keys_aval, *, count):
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
def random_split_abstract_eval(keys_aval, *, shape):
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, *shape))
@random_split_p.def_impl
def random_split_impl(keys, *, count):
def random_split_impl(keys, *, shape):
base_arr = random_split_impl_base(
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
keys.impl, keys.unsafe_raw_array(), keys.ndim, shape=shape)
return PRNGKeyArrayImpl(keys.impl, base_arr)
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
return split(base_arr)
def random_split_lowering(ctx, keys, *, count):
def random_split_lowering(ctx, keys, *, shape):
aval, = ctx.avals_in
impl = aval.dtype.impl
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, shape))
split_lowering = mlir.lower_fun(split, multiple_results=False)
return mlir.delegate_lowering(
ctx, split_lowering, keys,
@ -1249,28 +1250,29 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
def threefry_split(key: typing.Array, num: core.DimSize) -> typing.Array:
num = core.concrete_dim_or_error(num)
return _threefry_split(key, num)
def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
shape = tuple(unsafe_map(core.concrete_dim_or_error, shape))
return _threefry_split(key, shape)
@partial(jit, static_argnums=(1,))
def _threefry_split(key, num) -> typing.Array:
def _threefry_split(key, shape) -> typing.Array:
if config.jax_threefry_partitionable:
return _threefry_split_foldlike(key, num) # type: ignore
return _threefry_split_foldlike(key, shape) # type: ignore
else:
return _threefry_split_original(key, num) # type: ignore
return _threefry_split_original(key, shape) # type: ignore
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_original(key, num) -> typing.Array:
def _threefry_split_original(key, shape) -> typing.Array:
num = math.prod(shape)
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))
return lax.reshape(threefry_2x32(key, counts), (*shape, 2))
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_foldlike(key, num) -> typing.Array:
def _threefry_split_foldlike(key, shape) -> typing.Array:
k1, k2 = key
counts1, counts2 = iota_2x32_shape((num,))
counts1, counts2 = iota_2x32_shape(shape)
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
return jnp.stack([bits1, bits2], axis=1)
return jnp.stack([bits1, bits2], axis=bits1.ndim)
def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
@ -1329,7 +1331,7 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
if not nblocks:
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
else:
keys = threefry_split(key, nblocks + 1)
keys = threefry_split(key, (nblocks + 1,))
subkeys, last_key = keys[:-1], keys[-1]
blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
@ -1378,13 +1380,15 @@ def _rbg_seed(seed: typing.Array) -> typing.Array:
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: typing.Array, num: int) -> typing.Array:
def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
if config.jax_threefry_partitionable:
_threefry_split = _threefry_split_foldlike
else:
_threefry_split = _threefry_split_original
halfkeys = key.reshape(2, 2)
return vmap(
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
_threefry_split, (0, None), len(shape))(halfkeys, shape).reshape(
*shape, 4)
def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape
@ -1407,10 +1411,12 @@ rbg_prng_impl = PRNGImpl(
fold_in=_rbg_fold_in,
tag='rbg')
def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array:
def _unsafe_rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
# treat 10 iterations of random bits as a 'hash function'
num = math.prod(shape)
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
return lax.slice_in_dim(keys, start_index=None, limit_index=None, stride=10)
return lax.slice_in_dim(
keys, start_index=None, limit_index=None, stride=10).reshape(*shape, 4)
def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
assert not data.shape

View File

@ -232,7 +232,7 @@ def fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
key, wrapped = _check_prng_key(key)
return _return_prng_keys(wrapped, _fold_in(key, data))
def _split(key: KeyArray, num: int = 2) -> KeyArray:
def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
# Alternative to split() to use within random samplers.
# TODO(frostig): remove and use split(); we no longer need to wait
# to always enable_custom_prng
@ -240,15 +240,16 @@ def _split(key: KeyArray, num: int = 2) -> KeyArray:
if key.ndim:
raise TypeError("split accepts a single key, but was given a key array of"
f"shape {key.shape} != (). Use jax.vmap for batching.")
return prng.random_split(key, count=num)
shape = tuple(num) if isinstance(num, Sequence) else (num,)
return prng.random_split(key, shape=shape)
def split(key: KeyArray, num: int = 2) -> KeyArray:
def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
"""Splits a PRNG key into `num` new keys by adding a leading axis.
Args:
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
num: optional, a positive integer indicating the number of keys to produce
(default 2).
num: optional, a positive integer (or tuple of integers) indicating
the number (or shape) of keys to produce. Defaults to 2.
Returns:
An array-like object of `num` new PRNG keys.

View File

@ -2575,18 +2575,18 @@ def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
def _random_split_impl(keys: TfVal, *, shape, _in_avals, _out_aval):
keys_aval, = _in_avals
def impl_wrapper(keys: TfVal, *, count):
def impl_wrapper(keys: TfVal, *, shape):
return prng.random_split_impl_base(
keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
keys_aval.dtype.impl, keys, keys_aval.ndim, shape=shape)
converted_impl = _convert_jax_impl(
impl_wrapper, multiple_results=False, with_physical_avals=True,
extra_name_stack="random_split")
return converted_impl(
keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
keys, shape=shape, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[prng.random_split_p] = _random_split_impl

View File

@ -612,6 +612,23 @@ class LaxRandomTest(jtu.JaxTestCase):
def make_key(self, seed):
return random.threefry2x32_key(seed)
@jtu.sample_product(
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
)
def test_split_size_shape(self, num):
key = self.make_key(0)
if num is None:
key_split = jax.random.split(key)
else:
key_split = jax.random.split(key, num)
if num is None:
self.assertEqual(key_split.shape, (2, *key.shape))
elif type(num) is tuple:
self.assertEqual(key_split.shape, (*num, *key.shape))
else:
self.assertEqual(key_split.shape, (num, *key.shape))
@jtu.sample_product(dtype=jtu.dtypes.floating)
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
@ -2039,12 +2056,9 @@ def _double_threefry_seed(seed):
return jnp.vstack([threefry_seed(s1),
threefry_seed(s2)])
def _double_threefry_split(key, num):
split0 = threefry_split(key[0], num)
split1 = threefry_split(key[1], num)
merge = jnp.vstack([jnp.expand_dims(split0, axis=0),
jnp.expand_dims(split1, axis=0)])
return merge.transpose((1, 0, 2))
def _double_threefry_split(key, shape):
return vmap(
threefry_split, (0, None), len(shape))(key, shape)
def _double_threefry_random_bits(key, bit_width, shape):
bits0 = threefry_random_bits(key[0], bit_width, shape)