mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[mhlo] Add result type inference for mhlo.reduce_precision.
PiperOrigin-RevId: 452642523
This commit is contained in:
parent
6c89e90808
commit
bc877faae0
@ -3666,9 +3666,13 @@ masking.defvectorized(reduce_precision_p)
|
||||
|
||||
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
||||
aval_out, = ctx.avals_out
|
||||
return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
|
||||
mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits)).results
|
||||
if jax._src.lib.mlir_api_version >= 21:
|
||||
return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits)).results
|
||||
else:
|
||||
return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
|
||||
mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits)).results
|
||||
|
||||
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
||||
|
||||
|
@ -712,10 +712,16 @@ def _select_and_gather_add_lowering(
|
||||
def pack(a, b):
|
||||
a_dims = ir.RankedTensorType(a.type).shape
|
||||
b_dims = ir.RankedTensorType(b.type).shape
|
||||
a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
if jax._src.lib.mlir_api_version >= 21:
|
||||
a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
else:
|
||||
a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
|
||||
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
|
||||
b = mhlo.ShiftRightLogicalOp(
|
||||
|
Loading…
x
Reference in New Issue
Block a user