[mhlo] Add result type inference for mhlo.reduce_precision.

PiperOrigin-RevId: 452642523
Author: Xin Zhou, 2022-06-02 16:03:20 -07:00 (committed by jax authors)
parent 6c89e90808
commit bc877faae0
2 changed files with 17 additions and 7 deletions


@@ -3666,9 +3666,13 @@ masking.defvectorized(reduce_precision_p)
 def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
   aval_out, = ctx.avals_out
-  return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
-                                mlir.i32_attr(exponent_bits),
-                                mlir.i32_attr(mantissa_bits)).results
+  if jax._src.lib.mlir_api_version >= 21:
+    return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
+                                  mlir.i32_attr(mantissa_bits)).results
+  else:
+    return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
+                                  mlir.i32_attr(exponent_bits),
+                                  mlir.i32_attr(mantissa_bits)).results
 
 mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)

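For context: once mhlo.reduce_precision has result type inference (MLIR API version 21 and later), the Python op builder can derive the result type from the operand, which is why the new branch drops the explicit mlir.aval_to_ir_type(aval_out) argument. A minimal sketch of exercising this lowering from user code, assuming a jax/jaxlib recent enough to have the ahead-of-time lowering API (the helper name below is illustrative only, and the printed dialect depends on the installed jaxlib):

import jax
import jax.numpy as jnp
from jax import lax

def round_like_bf16(x):
  # Keep 8 exponent bits and 7 mantissa bits (bfloat16-like rounding) while
  # staying in float32; under jit this goes through _reduce_precision_lower.
  return lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)

x = jnp.float32(1.0) / 3.0
print(round_like_bf16(x))                           # rounded value, still float32
print(jax.jit(round_like_bf16).lower(x).as_text())  # lowered module with a reduce_precision op
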

@@ -712,10 +712,16 @@ def _select_and_gather_add_lowering(
     def pack(a, b):
       a_dims = ir.RankedTensorType(a.type).shape
       b_dims = ir.RankedTensorType(b.type).shape
-      a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-      b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
+      if jax._src.lib.mlir_api_version >= 21:
+        a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
+                                   mantissa_bits=mlir.i32_attr(nmant))
+        b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
+                                   mantissa_bits=mlir.i32_attr(nmant))
+      else:
+        a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
+                                   mantissa_bits=mlir.i32_attr(nmant))
+        b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
+                                   mantissa_bits=mlir.i32_attr(nmant))
       a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
       b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
       b = mhlo.ShiftRightLogicalOp(
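For context: the pack helper reduces each float to nexp=8 / nmant=7 precision so that its payload survives in the top 16 bits of a 32-bit word, then bit-packs two such values into one word; this hunk only switches which ReducePrecisionOp builder overload is used. A rough NumPy sketch of the packing idea, with hypothetical helper names and plain truncation standing in for ReducePrecisionOp's rounding:

import numpy as np

def pack_pair(a, b):
  # Keep the top 16 bits of each float32 (sign, 8 exponent bits, 7 mantissa
  # bits); store a's bits in the high half and b's bits in the low half.
  a_bits = np.asarray(a, dtype=np.float32).view(np.uint32) & np.uint32(0xFFFF0000)
  b_bits = np.asarray(b, dtype=np.float32).view(np.uint32) >> np.uint32(16)
  return a_bits | b_bits

def unpack_pair(word):
  # Recover the two reduced-precision float32 values from one packed word.
  word = np.asarray(word, dtype=np.uint32)
  hi = np.asarray(word & np.uint32(0xFFFF0000), dtype=np.uint32).view(np.float32)
  lo = np.asarray(word << np.uint32(16), dtype=np.uint32).view(np.float32)
  return hi, lo

w = pack_pair(1.5, -2.25)
print(unpack_pair(w))  # (1.5, -2.25): both values fit in 16 bits, so they round-trip exactly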