mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
This commit is contained in:
parent
ce85b89884
commit
a527aba646
@ -155,6 +155,10 @@ class LoweringRuleContext:
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
@property
def forward_compatible(self):
  """Returns the `forward_compatible` attribute of `self.lowering_context`."""
  lowering_ctx = self.lowering_context
  return lowering_ctx.forward_compatible
|
||||
|
||||
|
||||
def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
|
||||
) -> TPUMemorySpace:
|
||||
@ -1976,8 +1980,11 @@ def _convert_element_type_lowering_rule(
|
||||
# This case triggers when casting signed to unsigned or vice versa.
|
||||
return x
|
||||
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
|
||||
elif _from(floating) and _to(signed) and both_32bit:
|
||||
return arith.fptosi(out_type, x)
|
||||
elif _from(floating) and _to(signed):
|
||||
# TODO(apaszke): Remove once a month has passed, along with the
|
||||
# _convert_helper float -> signed conversion above.
|
||||
if not ctx.forward_compatible or both_32bit:
|
||||
return arith.fptosi(out_type, x)
|
||||
elif _from(signed) and _to(floating) and both_32bit:
|
||||
return arith.sitofp(out_type, x)
|
||||
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
|
||||
|
@ -40,6 +40,7 @@
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
@ -539,6 +540,87 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Canonicalizes `arith.fptosi` so that the conversion itself is always emitted
// with an f32 input and an i32 result.
//
// Steps:
//   * inputs narrower than 32 bits are first extended to f32 (`arith.extf`);
//   * results narrower than 32 bits are only accepted when
//     `ctx.compatibility_mode` is set. The f32 value is then clipped to the
//     signed range of the destination width (to match XLA), converted to i32
//     and finally truncated (`arith.trunci`) to the requested result type;
//   * the original op is replaced by the new value and erased.
//
// Returns failure (with an op error) when exactly one of input/result is a
// vector, when the result is wider than 32 bits, or when a sub-32-bit result
// is requested without compatibility mode.
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
                                  Operation &raw_op) {
  auto op = cast<arith::FPToSIOp>(raw_op);
  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
  auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
  auto dst_vty = dyn_cast<VectorType>(op.getType());
  if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
    return op.emitOpError("Vector/scalar mismatch between input and output");
  }
  const bool is_vector = static_cast<bool>(src_vty);
  const unsigned src_bitwidth =
      is_vector ? src_vty.getElementTypeBitWidth()
                : op.getIn().getType().getIntOrFloatBitWidth();
  const unsigned dst_bitwidth =
      is_vector ? dst_vty.getElementTypeBitWidth()
                : op.getType().getIntOrFloatBitWidth();
  if (dst_bitwidth > 32) {
    return op.emitOpError("Target bitwidth too large");
  }
  const bool narrow_dst = dst_bitwidth < 32;
  // Reject unsupported conversions before creating any new ops, so that a
  // failed canonicalization does not leave freshly built ops behind.
  if (narrow_dst && !ctx.compatibility_mode) {
    return op.emitOpError(
        "On this target only float-to-integer conversions can only happen on "
        "32-bit values. Enable compatibility mode or upcast to float32.");
  }

  // Builds `elem` for scalar ops, or a vector of `elem` with the input's shape.
  auto like_input = [&](Type elem) -> Type {
    if (is_vector) {
      return VectorType::get(src_vty.getShape(), elem);
    }
    return elem;
  };

  Value x = op.getIn();
  // Upcast the input to f32.
  if (src_bitwidth < 32) {
    x = builder.create<arith::ExtFOp>(like_input(builder.getF32Type()), x);
  }

  if (narrow_dst) {
    // Need to clip values to match XLA.
    auto minval = builder.getF32FloatAttr(
        APInt::getSignedMinValue(dst_bitwidth).getSExtValue());
    auto maxval = builder.getF32FloatAttr(
        APInt::getSignedMaxValue(dst_bitwidth).getSExtValue());
    // Materialize the bounds in separate statements (low first, then high):
    // building them inside a call's argument list would leave the order of
    // the emitted constants up to the compiler.
    Value low;
    Value high;
    if (is_vector) {
      auto x_vty = cast<VectorType>(x.getType());
      low = getFullVector(builder, x_vty, minval);
      high = getFullVector(builder, x_vty, maxval);
    } else {
      auto f32 = builder.getF32Type();
      low = builder.create<arith::ConstantOp>(f32, minval);
      high = builder.create<arith::ConstantOp>(f32, maxval);
    }
    auto is_small =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, low);
    x = builder.create<arith::SelectOp>(is_small, low, x);
    auto is_large =
        builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, high);
    x = builder.create<arith::SelectOp>(is_large, high, x);
  }

  // The conversion proper always produces i32.
  x = builder.create<arith::FPToSIOp>(like_input(builder.getI32Type()), x);

  // Compatibility mode was already verified above for narrow destinations.
  if (narrow_dst) {
    x = builder.create<arith::TruncIOp>(op.getType(), x);
  }
  op.replaceAllUsesWith(x);
  op.erase();
  return success();
}
|
||||
|
||||
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
|
||||
Operation &raw_op) {
|
||||
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
|
||||
@ -574,6 +656,7 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
|
||||
{vector::MultiDimReductionOp::getOperationName(),
|
||||
canonicalize_multi_dim_reduction},
|
||||
{arith::SelectOp::getOperationName(), canonicalize_select},
|
||||
{arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
|
||||
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
|
||||
return *rules;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user