partitionable threefry PRNG random bits implementation

the cost is 2x overgeneration of bits: the partitionable path runs one threefry2x32 block per output element, producing two 32-bit words of which only one is kept

Co-authored-by: Matthew Johnson <mattjj@google.com>
Roy Frostig 2022-10-25 08:13:55 -07:00 committed by Matthew Johnson
parent 9abacbdb56
commit c8b9280fb3
3 changed files with 143 additions and 4 deletions
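
To quantify the cost noted in the message (a sketch, not part of the diff): the original path hashes one counter pair per two output words, while the partitionable path hashes one pair per output word and discards half of each result.

# Back-of-the-envelope for a 32-bit request of size = 6 words:
size = 6
words_original = 2 * ((size + 1) // 2)  # 6 words generated, 6 kept
words_partitionable = 2 * size          # 12 words generated, 6 kept
assert words_partitionable == 2 * words_original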


@@ -734,6 +734,22 @@ default_prng_impl = config.define_enum_state(
    help=('Select the default PRNG implementation, used when one is not '
          'explicitly provided at seeding time.'))
threefry_partitionable = config.define_bool_state(
    name='jax_threefry_partitionable',
    default=False,
    upgrade=True,
    help=('Enables internal threefry PRNG implementation changes that '
          'render it automatically partitionable in some cases. For use '
          'with pjit and/or jax_array=True. Without this flag, using the '
          'standard jax.random pseudo-random number generation may result '
          'in extraneous communication and/or redundant distributed '
          'computation. With this flag, the communication overheads disappear '
          'in some cases.\n'
          '\n'
          'Currently, setting this flag does not change random values '
          'generated by a given PRNG key value. However, its behavior may '
          'change in the future.'))

enable_custom_vjp_by_custom_transpose = config.define_bool_state(
    name='jax_enable_custom_vjp_by_custom_transpose',
    default=False,

@@ -1049,14 +1049,57 @@ def _threefry_fold_in(key, data):

def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key."""
  return _threefry_random_bits(key, bit_width, shape)

@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  if not _is_threefry_prng_key(key):
    raise TypeError("threefry_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  if (config.jax_threefry_partitionable and bit_width == 32 and
      not any(core.is_special_dim_size(d) for d in shape)):
    return _threefry_random_bits_partitionable(key, bit_width, shape)
  else:
    return _threefry_random_bits_original(key, bit_width, shape)
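
A direct call of the wrapper above, matching how the tests below invoke it (a sketch; jax._src.prng is internal API):

import jax
import jax.numpy as jnp
from jax._src import prng  # internal module, imported the same way in the tests

jax.config.update('jax_threefry_partitionable', True)
raw_key = jnp.array([0, 0], dtype='uint32')            # raw threefry key pair
bits = prng.threefry_random_bits(raw_key, 32, (4, 4))  # uint32 array of shape (4, 4)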

def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
  if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
    raise NotImplementedError('random bits array of size exceeding 2 ** 64')

  size = prod(shape)
  n, r = divmod(bit_width * size, 32)
  if r > 0:
    n += 1
  even_size = n + (n % 2)

  if not shape:
    counts = jnp.arange(n, dtype=jnp.uint32).reshape(shape)
  else:
    iotas = [lax.broadcasted_iota(jnp.dtype('uint32'), shape, i)
             for i in range(len(shape))]
    strides = (*map(int, np.cumprod(shape[1:][::-1])[::-1]), 1)
    counts = sum(s * i for i, s in zip(iotas, strides))  # type: ignore

  circ0 = counts % (even_size // 2)
  circ1 = (circ0 + even_size // 2) % n
  k1, k2 = key
  bits_xx, bits_yy = threefry2x32_p.bind(k1, k2, circ0, circ1)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    assert n == even_size
    assert False  # broken: 64-bit support is unfinished. Unreachable for
                  # now, since the dispatcher above only routes
                  # bit_width == 32 here.
    bits_x, bits_y = bits_xx[:size // 2], bits_yy[:size // 2]
    bits_x = lax.convert_element_type(bits_x, dtype)
    bits_y = lax.convert_element_type(bits_y, dtype)
    bits = lax.shift_left(bits_x, dtype(32)) | bits_y
  else:
    bits = jnp.where(counts < even_size // 2, bits_xx, bits_yy)
    if bit_width != 32:
      assert False  # broken: sub-32-bit support is unfinished and likewise
                    # unreachable via the dispatcher above.
      bits = bits.view(dtype)
  return bits
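
To make the counter construction above concrete, here is a plain-NumPy walk-through for shape=(2, 3) and bit_width=32 (a sketch derived from the code, not part of the commit):

import numpy as np

shape = (2, 3)
size = 6       # prod(shape)
n = 6          # ceil(32 * size / 32)
even_size = 6  # n + (n % 2)

# Row-major linear index of each output element (what the iota/stride sum
# computes above):
counts = np.arange(size, dtype=np.uint32).reshape(shape)  # [[0,1,2],[3,4,5]]
circ0 = counts % (even_size // 2)                         # [[0,1,2],[0,1,2]]
circ1 = (circ0 + even_size // 2) % n                      # [[3,4,5],[3,4,5]]
# Elements j and j + 3 hash the same counter pair (j mod 3, j mod 3 + 3), so
# each threefry2x32 block runs twice; jnp.where(counts < 3, ...) keeps the
# first output word for j < 3 and the second for j >= 3. That is both the 2x
# overgeneration cost and why the flag currently reproduces the exact values
# of the original implementation.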

@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jnp.ndarray, bit_width, shape):
  size = prod(shape)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
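
The hunk is truncated here; the shape-polymorphism-friendly ceiling is the same divmod idiom used in the partitionable path above. A standalone sketch (the helper name _ceil_div_32 is hypothetical):

def _ceil_div_32(bit_width: int, size) -> int:
  # ceil(bit_width * size / 32) via divmod: exact for concrete ints and for
  # symbolic (shape-polymorphic) sizes alike, with no floating point.
  n, r = divmod(bit_width * size, 32)
  return n + 1 if r > 0 else n

assert _ceil_div_32(32, 6) == 6
assert _ceil_div_32(8, 5) == 2  # 40 bits still need two 32-bit words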


@@ -31,6 +31,7 @@ from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P
from jax._src import sharding
from jax._src import array
from jax._src import prng
from jax.experimental import maps
from jax.config import config
@@ -765,5 +766,84 @@ class ShardingTest(jtu.JaxTestCase):
    repr(out.sharding)  # doesn't crash


@jtu.with_config(jax_array=True)
class RngShardingTest(jtu.JaxTestCase):
  # tests that the PRNGs are automatically sharded as expected
@parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
def test_random_bits_is_pure_map_1d(self, num_devices):
@jax.jit
def f(x):
bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
32, x.shape)
return bits + x
mesh = jtu.create_global_mesh((num_devices,), ('x',))
s = sharding.MeshPspecSharding(mesh, P('x'))
n = num_devices ** 2
global_x = jnp.arange(n).astype('uint32')
x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])
# check computation is fully partitioned and without any communication
jax.config.update('jax_threefry_partitionable', True)
unopt_txt = f.lower(x).as_text(dialect='hlo')
opt_txt = f.lower(x).compile().as_text()
self.assertIn( f'[{n}]', unopt_txt)
self.assertNotIn(f'[{n}]', opt_txt)
self.assertNotIn('all-reduce', opt_txt)
self.assertNotIn('collective-permute', opt_txt)
# check against single-device reference
y = f(x)
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref1)
# check against single-device previous implementation reference
jax.config.update('jax_threefry_partitionable', False)
y_ref2 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref2)
  @parameterized.named_parameters(
      {"testcase_name": f"_{mesh_shape}_{pspec}",
       "mesh_shape": mesh_shape, "pspec": pspec}
      for mesh_shape in [(3, 2), (4, 2), (2, 3)]
      for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
  def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
    @jax.jit
    def f(x):
      bits = prng.threefry_random_bits(jnp.array([0, 0], dtype='uint32'),
                                       32, x.shape)
      return bits + x

    global_shape = tuple(np.square(mesh_shape))

    mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
    s = sharding.MeshPspecSharding(mesh, pspec)

    n = prod(global_shape)
    global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
    x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

    # check that the computation is fully partitioned, with no communication
    jax.config.update('jax_threefry_partitionable', True)
    unopt_txt = f.lower(x).as_text(dialect='hlo')
    opt_txt = f.lower(x).compile().as_text()
    global_shape_fmt = ','.join(str(x) for x in global_shape)
    self.assertIn(f'[{global_shape_fmt}]', unopt_txt)
    self.assertNotIn(f'[{global_shape_fmt}]', opt_txt)
    self.assertNotIn('all-reduce', opt_txt)
    self.assertNotIn('collective-permute', opt_txt)

    # check against a single-device reference
    y = f(x)
    y_ref1 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref1)

    # check against a single-device reference of the previous implementation
    jax.config.update('jax_threefry_partitionable', False)
    y_ref2 = f(jax.device_put(x, jax.devices()[0]))
    self.assertArraysEqual(y, y_ref2)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())