wip bits-changing partitionable rng based on iota raveling

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Roy Frostig 2022-10-28 14:17:34 -07:00
parent f9e7629c3f
commit 63bfb87edf

View File

@ -980,6 +980,50 @@ mlir.register_lowering(
platform='rocm')
def iota_32x2_shape(shape):
  """Return a pair of uint32 arrays for ``shape``.

  For a non-empty ``shape`` the pair is produced by binding
  ``iota_32x2_shape_p`` with ``shape`` as a static parameter. For a rank-0
  ``shape`` the primitive is bypassed and both entries are the same scalar
  uint32 zero.

  Args:
    shape: sequence of dimension sizes; only its length is inspected here.

  Returns:
    A 2-tuple of uint32 arrays.
  """
  rank = len(shape)
  if rank > 0:
    return iota_32x2_shape_p.bind(shape=shape)
  # Scalar case: no indices to enumerate, so skip the primitive entirely.
  scalar_zero = jnp.zeros((), np.dtype('uint32'))
  return (scalar_zero, scalar_zero)
# Primitive behind `iota_32x2_shape`; it is bound with a single keyword
# parameter, `shape`, and takes no array operands.
iota_32x2_shape_p = core.Primitive('iota_32x2_shape')
# The primitive yields a sequence of outputs (a pair) rather than one value.
iota_32x2_shape_p.multiple_results = True
@iota_32x2_shape_p.def_abstract_eval
def iota_32x2_shape_abstract_eval(*, shape):
  """Abstract evaluation rule: two uint32 outputs, each of shape ``shape``.

  Args:
    shape: the static ``shape`` parameter the primitive was bound with.

  Returns:
    A 2-tuple holding the same ``ShapedArray`` (uint32, ``shape``) twice.
  """
  out_aval = core.ShapedArray(shape, np.dtype('uint32'))
  return (out_aval, out_aval)
def iota_32x2_shape_lowering(ctx, *, shape):
  """MLIR lowering rule for ``iota_32x2_shape_p``.

  Builds one uint64 iota per dimension of ``shape``, combines them into a
  single uint64 counter array as ``sum(iota_d * stride_d)`` where
  ``stride_d`` is the product of the dimension sizes after ``d`` (so the
  last dimension has stride 1), and then emits two values converted to the
  element type of the first output aval: the counter shifted right by 32
  bits, and the unshifted counter.

  Args:
    ctx: lowering context; only ``ctx.avals_out`` is read.
    shape: static, non-empty sequence of dimension sizes.

  Returns:
    ``(counts_hi, counts_lo)`` as MLIR values.
  """
  u64 = np.dtype('uint64')

  def _bcast_u64(value):
    # Emit a *scalar* uint64 constant and broadcast it to `shape` inside the
    # program. This keeps the constant O(1) in size; materializing a dense
    # `np.full(shape, ...)` array would cost O(prod(shape)) host memory and
    # module size at lowering time.
    scalar = mlir.ir_constant(np.array(value, u64), canonicalize_types=False)
    return mlir.mhlo.BroadcastOp(
        scalar, mlir.dense_int_elements(shape)).result

  def _add(x, y):
    return mlir.mhlo.AddOp(x, y).result

  def _mul(x, y):
    # `x` is a Python int (a stride), `y` is a uint64 value of shape `shape`.
    return mlir.mhlo.MulOp(_bcast_u64(x), y).result

  def _sum(xs):
    return reduce(_add, xs)

  # Rank-0 shapes are expected to be handled before binding the primitive.
  assert len(shape) > 0
  aval_out, _ = ctx.avals_out
  aval_u64 = core.ShapedArray(shape, u64)
  iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
                            mlir.i64_attr(dimension)).result
           for dimension in range(len(shape))]

  # Strides with the last dimension varying fastest. Computed with Python
  # ints so the products are exact regardless of NumPy's default integer
  # width on the host platform.
  strides = [1]
  for dim_size in reversed(shape[1:]):
    strides.append(strides[-1] * int(dim_size))
  strides.reverse()

  counts = _sum(_mul(s, i) for i, s in zip(iotas, strides))  # type: ignore
  counts_shifted = mlir.mhlo.ShiftRightLogicalOp(
      counts, _bcast_u64(32)).result
  # NOTE(review): the uint64 -> output-dtype conversion is presumably relied
  # on to keep the low 32 bits; confirm against MHLO convert semantics.
  out_type = mlir.aval_to_ir_type(aval_out)
  counts_lo = mlir.mhlo.ConvertOp(out_type, counts).result
  counts_hi = mlir.mhlo.ConvertOp(out_type, counts_shifted).result
  return (counts_hi, counts_lo)
# Register the lowering rule for the primitive; no `platform` argument is
# passed in this call.
mlir.register_lowering(iota_32x2_shape_p, iota_32x2_shape_lowering)
@partial(jit, inline=True)
def threefry_2x32(keypair, count):
"""Apply the Threefry 2x32 hash.
@ -1039,7 +1083,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
if (config.jax_threefry_partitionable and bit_width == 32 and
if (config.jax_threefry_partitionable and
not any(core.is_special_dim_size(d) for d in shape)):
return _threefry_random_bits_partitionable(key, bit_width, shape)
else: