mhlo.rng op with distribution attr

Aligns with the XLA kRng which takes a distribution as an attribute
instead of having separate ops for each distribution.

PiperOrigin-RevId: 459389874
This commit is contained in:
Anish Tondwalkar 2022-07-06 18:02:33 -07:00 committed by jax authors
parent 89a6766964
commit 5d379bba9e

View File

@ -4243,7 +4243,11 @@ def _rng_uniform_lowering(ctx, a, b, *, shape):
aval_out, = ctx.avals_out
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
canonicalize_types=False)
return mhlo.RngUniformOp(a, b, shape).results
if jax._src.lib.mlir_api_version <= 22:
return mhlo.RngUniformOp(a, b, shape).results
else:
return mhlo.RngOp(a, b, shape,
mhlo.RngDistributionAttr.get('UNIFORM')).results
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)