mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
mhlo.rng op with distribution attr
Aligns with the XLA kRng which takes a distribution as an attribute instead of having separate ops for each distribution. PiperOrigin-RevId: 459389874
This commit is contained in:
parent
89a6766964
commit
5d379bba9e
@ -4243,7 +4243,11 @@ def _rng_uniform_lowering(ctx, a, b, *, shape):
|
||||
aval_out, = ctx.avals_out
|
||||
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
|
||||
canonicalize_types=False)
|
||||
return mhlo.RngUniformOp(a, b, shape).results
|
||||
if jax._src.lib.mlir_api_version <= 22:
|
||||
return mhlo.RngUniformOp(a, b, shape).results
|
||||
else:
|
||||
return mhlo.RngOp(a, b, shape,
|
||||
mhlo.RngDistributionAttr.get('UNIFORM')).results
|
||||
|
||||
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user