Reverts f1b894d14a28ac22a037fb79177b991275c75a18

PiperOrigin-RevId: 716653711
This commit is contained in:
jax authors 2025-01-17 06:59:55 -08:00
parent ce85b89884
commit a527aba646
2 changed files with 92 additions and 2 deletions

View File

@ -155,6 +155,10 @@ class LoweringRuleContext:
replace = dataclasses.replace
# Property that forwards the `forward_compatible` attribute of
# `self.lowering_context`, so lowering rules can query it directly on the
# rule context.
@property
def forward_compatible(self):
return self.lowering_context.forward_compatible
def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
) -> TPUMemorySpace:
@ -1976,8 +1980,11 @@ def _convert_element_type_lowering_rule(
# This case triggers when casting signed to unsigned or vice versa.
return x
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
elif _from(floating) and _to(signed) and both_32bit:
return arith.fptosi(out_type, x)
elif _from(floating) and _to(signed):
# TODO(apaszke): Remove once a month has passed, along with the
# _convert_helper float -> signed conversion above.
if not ctx.forward_compatible or both_32bit:
return arith.fptosi(out_type, x)
elif _from(signed) and _to(floating) and both_32bit:
return arith.sitofp(out_type, x)
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:

View File

@ -40,6 +40,7 @@
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
namespace mlir::tpu {
@ -539,6 +540,87 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
return success();
}
// Canonicalizes arith.fptosi so that the float-to-integer conversion itself is
// always emitted with a 32-bit integer result.
//
// The rewrite performed is:
//   1. extend a sub-32-bit float input to f32 (arith.extf);
//   2. if the destination is narrower than 32 bits, clamp the value to the
//      signed range of the destination type (needed to match XLA);
//   3. convert to i32 (arith.fptosi);
//   4. if the destination is narrower than 32 bits, truncate to the original
//      result type (arith.trunci).
//
// Emits an op error and returns failure when:
//   * input and result disagree on being vectors;
//   * the destination element type is wider than 32 bits;
//   * the destination is narrower than 32 bits and ctx.compatibility_mode is
//     off;
//   * the destination is narrower than 32 bits and the source is wider than
//     32 bits (the clamp bounds are built as f32 and cannot be compared
//     against a wider float).
//
// All validation happens before any new operation is created, so a failure
// leaves the IR untouched. On success all uses of the original op are
// replaced and the op is erased.
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
                                  Operation &raw_op) {
  auto op = cast<arith::FPToSIOp>(raw_op);
  // New ops are inserted right before the op being rewritten.
  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
  auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
  auto dst_vty = dyn_cast<VectorType>(op.getType());
  if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
    return op.emitOpError("Vector/scalar mismatch between input and output");
  }
  const bool is_vector = static_cast<bool>(src_vty);
  unsigned src_bitwidth, dst_bitwidth;
  if (is_vector) {
    src_bitwidth = src_vty.getElementTypeBitWidth();
    dst_bitwidth = dst_vty.getElementTypeBitWidth();
  } else {
    src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth();
    dst_bitwidth = op.getType().getIntOrFloatBitWidth();
  }
  if (dst_bitwidth > 32) {
    return op.emitOpError("Target bitwidth too large");
  }
  // A narrow destination needs the clamp + truncate sequence below.
  const bool narrow_dst = dst_bitwidth < 32;
  if (narrow_dst) {
    // Checked up front (and only once) so that no helper ops are left behind
    // in the IR when the rewrite is rejected.
    if (!ctx.compatibility_mode) {
      return op.emitOpError(
          "On this target only float-to-integer conversions can only happen on "
          "32-bit values. Enable compatibility mode or upcast to float32.");
    }
    // The clamp bounds are f32 constants; comparing them against a wider
    // float would create type-mismatched arith.cmpf/arith.select ops.
    if (src_bitwidth > 32) {
      return op.emitOpError("Source bitwidth too large");
    }
  }
  Value x = op.getIn();
  // Upcast the input to f32.
  if (src_bitwidth < 32) {
    if (is_vector) {
      x = builder.create<arith::ExtFOp>(
          VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
    } else {
      x = builder.create<arith::ExtFOp>(builder.getF32Type(), x);
    }
  }
  if (narrow_dst) {
    // Need to clip values to match XLA
    auto clip = [&](Value x, Value low, Value high) {
      auto is_small =
          builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, low);
      x = builder.create<arith::SelectOp>(is_small, low, x);
      auto is_large =
          builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, high);
      x = builder.create<arith::SelectOp>(is_large, high, x);
      return x;
    };
    // Signed range of the destination integer type, expressed as f32.
    auto minval = builder.getF32FloatAttr(
        APInt::getSignedMinValue(dst_bitwidth).getSExtValue());
    auto maxval = builder.getF32FloatAttr(
        APInt::getSignedMaxValue(dst_bitwidth).getSExtValue());
    if (is_vector) {
      auto x_vty = cast<VectorType>(x.getType());
      x = clip(x, getFullVector(builder, x_vty, minval),
               getFullVector(builder, x_vty, maxval));
    } else {
      auto f32 = builder.getF32Type();
      x = clip(x, builder.create<arith::ConstantOp>(f32, minval),
               builder.create<arith::ConstantOp>(f32, maxval));
    }
  }
  // The conversion itself always produces 32-bit integers.
  if (is_vector) {
    x = builder.create<arith::FPToSIOp>(
        VectorType::get(src_vty.getShape(), builder.getI32Type()), x);
  } else {
    x = builder.create<arith::FPToSIOp>(builder.getI32Type(), x);
  }
  if (narrow_dst) {
    // Compatibility mode was already verified above; narrow to the type the
    // original op produced.
    x = builder.create<arith::TruncIOp>(op.getType(), x);
  }
  op.replaceAllUsesWith(x);
  op.erase();
  return success();
}
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
@ -574,6 +656,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
{vector::MultiDimReductionOp::getOperationName(),
canonicalize_multi_dim_reduction},
{arith::SelectOp::getOperationName(), canonicalize_select},
{arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
return *rules;
}