mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
accept general shape
option in jax.random.split
Several PRNG implementations (notably partitionable threefry) support splitting to arbitrary shapes, rather than only to a 1-D vector of keys. This change: * Upgrades `jax.random.split` to accept a general shape as an argument. * Updates the internal PRNG interface, and our various PRNG implementations, to accept and handle such a shape argument. This change keeps the argument name `num`. We can still think on whether and how we'd like to upgrade to `shape`. Note that we could have supported arbitrary shapes by reduction to the previous API (with a flat split count), using reshapes. We'd like to avoid that, so as not to hide this structure from the underlying implementation. For instance, partitionable threefry hashes a *shaped* iota in order to split keys, and we don't want to flatten and reshape around that for no reason. Co-authored-by: Jake Vanderplas <jakevdp@google.com>
This commit is contained in:
parent
a29d4bcd33
commit
df2891ff13
@ -64,6 +64,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
Device = xc.Device
|
||||
Shard = Any # TODO(jakevdp): fix circular imports and import Shard
|
||||
Shape = tuple[int, ...]
|
||||
|
||||
UINT_DTYPES = {
|
||||
8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} # type: ignore[has-type]
|
||||
@ -80,14 +81,14 @@ class PRNGImpl(NamedTuple):
|
||||
|
||||
seed :: int[] -> K
|
||||
fold_in :: K -> int[] -> K
|
||||
split[n] :: K -> K[n]
|
||||
random_bits[shape, bit_width] :: K -> uint<bit_width>[shape]
|
||||
split[shape] :: K -> K[*shape]
|
||||
random_bits[shape, bit_width] :: K -> uint<bit_width>[*shape]
|
||||
|
||||
A PRNG implementation is adapted to an array-like object of keys
|
||||
``K`` by the ``PRNGKeyArray`` class, which should be created via the
|
||||
``seed_with_impl`` function.
|
||||
"""
|
||||
key_shape: core.Shape
|
||||
key_shape: Shape
|
||||
seed: Callable
|
||||
split: Callable
|
||||
random_bits: Callable
|
||||
@ -717,31 +718,31 @@ def random_seed_lowering(ctx, seeds, *, impl):
|
||||
mlir.register_lowering(random_seed_p, random_seed_lowering)
|
||||
|
||||
|
||||
def random_split(keys, count):
|
||||
return random_split_p.bind(keys, count=count)
|
||||
def random_split(keys, shape: Shape):
|
||||
return random_split_p.bind(keys, shape=shape)
|
||||
|
||||
random_split_p = core.Primitive('random_split')
|
||||
ad.defjvp_zero(random_split_p)
|
||||
batching.defvectorized(random_split_p)
|
||||
|
||||
@random_split_p.def_abstract_eval
|
||||
def random_split_abstract_eval(keys_aval, *, count):
|
||||
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, count))
|
||||
def random_split_abstract_eval(keys_aval, *, shape):
|
||||
return keys_shaped_array(keys_aval.dtype.impl, (*keys_aval.shape, *shape))
|
||||
|
||||
@random_split_p.def_impl
|
||||
def random_split_impl(keys, *, count):
|
||||
def random_split_impl(keys, *, shape):
|
||||
base_arr = random_split_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
|
||||
keys.impl, keys.unsafe_raw_array(), keys.ndim, shape=shape)
|
||||
return PRNGKeyArrayImpl(keys.impl, base_arr)
|
||||
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
|
||||
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
|
||||
split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, shape))
|
||||
return split(base_arr)
|
||||
|
||||
def random_split_lowering(ctx, keys, *, count):
|
||||
def random_split_lowering(ctx, keys, *, shape):
|
||||
aval, = ctx.avals_in
|
||||
impl = aval.dtype.impl
|
||||
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, count))
|
||||
split = iterated_vmap_unary(aval.ndim, lambda k: impl.split(k, shape))
|
||||
split_lowering = mlir.lower_fun(split, multiple_results=False)
|
||||
return mlir.delegate_lowering(
|
||||
ctx, split_lowering, keys,
|
||||
@ -1249,28 +1250,29 @@ def threefry_2x32(keypair, count):
|
||||
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
||||
|
||||
|
||||
def threefry_split(key: typing.Array, num: core.DimSize) -> typing.Array:
|
||||
num = core.concrete_dim_or_error(num)
|
||||
return _threefry_split(key, num)
|
||||
def threefry_split(key: typing.Array, shape: Shape) -> typing.Array:
|
||||
shape = tuple(unsafe_map(core.concrete_dim_or_error, shape))
|
||||
return _threefry_split(key, shape)
|
||||
|
||||
@partial(jit, static_argnums=(1,))
|
||||
def _threefry_split(key, num) -> typing.Array:
|
||||
def _threefry_split(key, shape) -> typing.Array:
|
||||
if config.jax_threefry_partitionable:
|
||||
return _threefry_split_foldlike(key, num) # type: ignore
|
||||
return _threefry_split_foldlike(key, shape) # type: ignore
|
||||
else:
|
||||
return _threefry_split_original(key, num) # type: ignore
|
||||
return _threefry_split_original(key, shape) # type: ignore
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_original(key, num) -> typing.Array:
|
||||
def _threefry_split_original(key, shape) -> typing.Array:
|
||||
num = math.prod(shape)
|
||||
counts = lax.iota(np.uint32, num * 2)
|
||||
return lax.reshape(threefry_2x32(key, counts), (num, 2))
|
||||
return lax.reshape(threefry_2x32(key, counts), (*shape, 2))
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_foldlike(key, num) -> typing.Array:
|
||||
def _threefry_split_foldlike(key, shape) -> typing.Array:
|
||||
k1, k2 = key
|
||||
counts1, counts2 = iota_2x32_shape((num,))
|
||||
counts1, counts2 = iota_2x32_shape(shape)
|
||||
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
|
||||
return jnp.stack([bits1, bits2], axis=1)
|
||||
return jnp.stack([bits1, bits2], axis=bits1.ndim)
|
||||
|
||||
|
||||
def threefry_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
|
||||
@ -1329,7 +1331,7 @@ def _threefry_random_bits_original(key: typing.Array, bit_width, shape):
|
||||
if not nblocks:
|
||||
bits = threefry_2x32(key, lax.iota(np.uint32, rem))
|
||||
else:
|
||||
keys = threefry_split(key, nblocks + 1)
|
||||
keys = threefry_split(key, (nblocks + 1,))
|
||||
subkeys, last_key = keys[:-1], keys[-1]
|
||||
blocks = vmap(threefry_2x32, in_axes=(0, None))(subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
|
||||
last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
|
||||
@ -1378,13 +1380,15 @@ def _rbg_seed(seed: typing.Array) -> typing.Array:
|
||||
halfkey = threefry_seed(seed)
|
||||
return jnp.concatenate([halfkey, halfkey])
|
||||
|
||||
def _rbg_split(key: typing.Array, num: int) -> typing.Array:
|
||||
def _rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
|
||||
if config.jax_threefry_partitionable:
|
||||
_threefry_split = _threefry_split_foldlike
|
||||
else:
|
||||
_threefry_split = _threefry_split_original
|
||||
halfkeys = key.reshape(2, 2)
|
||||
return vmap(
|
||||
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
|
||||
_threefry_split, (0, None), len(shape))(halfkeys, shape).reshape(
|
||||
*shape, 4)
|
||||
|
||||
def _rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
|
||||
assert not data.shape
|
||||
@ -1407,10 +1411,12 @@ rbg_prng_impl = PRNGImpl(
|
||||
fold_in=_rbg_fold_in,
|
||||
tag='rbg')
|
||||
|
||||
def _unsafe_rbg_split(key: typing.Array, num: int) -> typing.Array:
|
||||
def _unsafe_rbg_split(key: typing.Array, shape: Shape) -> typing.Array:
|
||||
# treat 10 iterations of random bits as a 'hash function'
|
||||
num = math.prod(shape)
|
||||
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
|
||||
return lax.slice_in_dim(keys, start_index=None, limit_index=None, stride=10)
|
||||
return lax.slice_in_dim(
|
||||
keys, start_index=None, limit_index=None, stride=10).reshape(*shape, 4)
|
||||
|
||||
def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array:
|
||||
assert not data.shape
|
||||
|
@ -232,7 +232,7 @@ def fold_in(key: KeyArray, data: IntegerArray) -> KeyArray:
|
||||
key, wrapped = _check_prng_key(key)
|
||||
return _return_prng_keys(wrapped, _fold_in(key, data))
|
||||
|
||||
def _split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
def _split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
|
||||
# Alternative to split() to use within random samplers.
|
||||
# TODO(frostig): remove and use split(); we no longer need to wait
|
||||
# to always enable_custom_prng
|
||||
@ -240,15 +240,16 @@ def _split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
if key.ndim:
|
||||
raise TypeError("split accepts a single key, but was given a key array of"
|
||||
f"shape {key.shape} != (). Use jax.vmap for batching.")
|
||||
return prng.random_split(key, count=num)
|
||||
shape = tuple(num) if isinstance(num, Sequence) else (num,)
|
||||
return prng.random_split(key, shape=shape)
|
||||
|
||||
def split(key: KeyArray, num: int = 2) -> KeyArray:
|
||||
def split(key: KeyArray, num: Union[int, tuple[int, ...]] = 2) -> KeyArray:
|
||||
"""Splits a PRNG key into `num` new keys by adding a leading axis.
|
||||
|
||||
Args:
|
||||
key: a PRNG key (from ``PRNGKey``, ``split``, ``fold_in``).
|
||||
num: optional, a positive integer indicating the number of keys to produce
|
||||
(default 2).
|
||||
num: optional, a positive integer (or tuple of integers) indicating
|
||||
the number (or shape) of keys to produce. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
An array-like object of `num` new PRNG keys.
|
||||
|
@ -2575,18 +2575,18 @@ def _random_seed_impl(seeds: TfVal, *, impl, _in_avals, _out_aval):
|
||||
tf_impl_with_avals[prng.random_seed_p] = _random_seed_impl
|
||||
|
||||
|
||||
def _random_split_impl(keys: TfVal, *, count, _in_avals, _out_aval):
|
||||
def _random_split_impl(keys: TfVal, *, shape, _in_avals, _out_aval):
|
||||
keys_aval, = _in_avals
|
||||
|
||||
def impl_wrapper(keys: TfVal, *, count):
|
||||
def impl_wrapper(keys: TfVal, *, shape):
|
||||
return prng.random_split_impl_base(
|
||||
keys_aval.dtype.impl, keys, keys_aval.ndim, count=count)
|
||||
keys_aval.dtype.impl, keys, keys_aval.ndim, shape=shape)
|
||||
|
||||
converted_impl = _convert_jax_impl(
|
||||
impl_wrapper, multiple_results=False, with_physical_avals=True,
|
||||
extra_name_stack="random_split")
|
||||
return converted_impl(
|
||||
keys, count=count, _in_avals=_in_avals, _out_aval=_out_aval)
|
||||
keys, shape=shape, _in_avals=_in_avals, _out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[prng.random_split_p] = _random_split_impl
|
||||
|
||||
|
@ -612,6 +612,23 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
def make_key(self, seed):
|
||||
return random.threefry2x32_key(seed)
|
||||
|
||||
@jtu.sample_product(
|
||||
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
|
||||
)
|
||||
def test_split_size_shape(self, num):
|
||||
key = self.make_key(0)
|
||||
if num is None:
|
||||
key_split = jax.random.split(key)
|
||||
else:
|
||||
key_split = jax.random.split(key, num)
|
||||
|
||||
if num is None:
|
||||
self.assertEqual(key_split.shape, (2, *key.shape))
|
||||
elif type(num) is tuple:
|
||||
self.assertEqual(key_split.shape, (*num, *key.shape))
|
||||
else:
|
||||
self.assertEqual(key_split.shape, (num, *key.shape))
|
||||
|
||||
@jtu.sample_product(dtype=jtu.dtypes.floating)
|
||||
def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
|
||||
bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
|
||||
@ -2039,12 +2056,9 @@ def _double_threefry_seed(seed):
|
||||
return jnp.vstack([threefry_seed(s1),
|
||||
threefry_seed(s2)])
|
||||
|
||||
def _double_threefry_split(key, num):
|
||||
split0 = threefry_split(key[0], num)
|
||||
split1 = threefry_split(key[1], num)
|
||||
merge = jnp.vstack([jnp.expand_dims(split0, axis=0),
|
||||
jnp.expand_dims(split1, axis=0)])
|
||||
return merge.transpose((1, 0, 2))
|
||||
def _double_threefry_split(key, shape):
|
||||
return vmap(
|
||||
threefry_split, (0, None), len(shape))(key, shape)
|
||||
|
||||
def _double_threefry_random_bits(key, bit_width, shape):
|
||||
bits0 = threefry_random_bits(key[0], bit_width, shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user