mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
wip bits-changing partitionable rng based on iota raveling
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
f9e7629c3f
commit
63bfb87edf
@ -980,6 +980,50 @@ mlir.register_lowering(
|
||||
platform='rocm')
|
||||
|
||||
|
||||
def iota_32x2_shape(shape):
|
||||
if len(shape) == 0:
|
||||
return (jnp.zeros((), np.dtype('uint32')),) * 2
|
||||
return iota_32x2_shape_p.bind(shape=shape)
|
||||
|
||||
iota_32x2_shape_p = core.Primitive('iota_32x2_shape')
|
||||
iota_32x2_shape_p.multiple_results = True
|
||||
|
||||
@iota_32x2_shape_p.def_abstract_eval
|
||||
def iota_32x2_shape_abstract_eval(*, shape):
|
||||
return (core.ShapedArray(shape, np.dtype('uint32')),) * 2
|
||||
|
||||
def iota_32x2_shape_lowering(ctx, *, shape):
|
||||
def _add(x, y):
|
||||
return mlir.mhlo.AddOp(x, y).result
|
||||
|
||||
def _mul(x, y):
|
||||
x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')),
|
||||
canonicalize_types=False)
|
||||
x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
|
||||
return mlir.mhlo.MulOp(x_bcast, y).result
|
||||
|
||||
def _sum(xs):
|
||||
return reduce(_add, xs)
|
||||
|
||||
assert len(shape) > 0
|
||||
aval_out, _ = ctx.avals_out
|
||||
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
|
||||
iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
|
||||
mlir.i64_attr(dimension)).result
|
||||
for dimension in range(len(shape))]
|
||||
strides = (*map(int, np.cumprod(shape[1:][::-1])[::-1]), 1)
|
||||
counts = _sum(_mul(s, i) for i, s in zip(iotas, strides)) # type: ignore
|
||||
counts_shifted = mlir.mhlo.ShiftRightLogicalOp(
|
||||
counts, mlir.ir_constant(np.full(shape, 32, np.dtype('uint64')),
|
||||
canonicalize_types=False)).result
|
||||
counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
|
||||
counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
|
||||
counts_shifted).result
|
||||
return (counts_hi, counts_lo)
|
||||
|
||||
mlir.register_lowering(iota_32x2_shape_p, iota_32x2_shape_lowering)
|
||||
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def threefry_2x32(keypair, count):
|
||||
"""Apply the Threefry 2x32 hash.
|
||||
@ -1039,7 +1083,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
if bit_width not in (8, 16, 32, 64):
|
||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||
|
||||
if (config.jax_threefry_partitionable and bit_width == 32 and
|
||||
if (config.jax_threefry_partitionable and
|
||||
not any(core.is_special_dim_size(d) for d in shape)):
|
||||
return _threefry_random_bits_partitionable(key, bit_width, shape)
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user