Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a

PiperOrigin-RevId: 715435319
This commit is contained in:
George Necula 2025-01-14 10:29:22 -08:00 committed by jax authors
parent b6acb9cb7a
commit f1b894d14a
2 changed files with 2 additions and 99 deletions

View File

@ -112,7 +112,6 @@ class LoweringContext:
replace = dataclasses.replace
traceback_caches: mlir.TracebackCaches
for_verification: bool
forward_compatible: bool
@property
def grid_rank(self):
@ -142,10 +141,6 @@ class LoweringRuleContext:
replace = dataclasses.replace
@property
def forward_compatible(self):
return self.lowering_context.forward_compatible
def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
) -> TPUMemorySpace:
@ -559,7 +554,6 @@ def lower_jaxpr_to_module(
mosaic_grid_mapping=mosaic_grid_mapping,
name="main",
for_verification=for_verification,
forward_compatible=lowering_context.is_forward_compat(),
)
m.body.append(func_op)
sym_tab.insert(func_op)
@ -586,7 +580,6 @@ def lower_jaxpr_to_module(
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
for_verification=for_verification,
forward_compatible=lowering_context.is_forward_compat(),
)
assert mlir_func.verify(), mlir_func
block_shape = [
@ -635,7 +628,6 @@ def lower_jaxpr_to_transform_func(
name: str,
mosaic_grid_mapping: MosaicGridMapping,
for_verification: bool,
forward_compatible: bool,
) -> func.FuncOp:
num_grid = len(mosaic_grid_mapping.grid_types)
arg_types = [
@ -670,7 +662,6 @@ def lower_jaxpr_to_transform_func(
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
for_verification=for_verification,
forward_compatible=forward_compatible,
)
out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
*scalar_prefetch)
@ -698,7 +689,6 @@ def lower_jaxpr_to_func(
mosaic_grid_mapping: MosaicGridMapping,
name: str,
for_verification: bool,
forward_compatible: bool,
) -> func.FuncOp:
num_grid = len(mosaic_grid_mapping.grid_types)
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
@ -737,7 +727,6 @@ def lower_jaxpr_to_func(
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
for_verification=for_verification,
forward_compatible=forward_compatible,
)
return jaxpr_subcomp(
lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
@ -1843,10 +1832,7 @@ def _convert_element_type_lowering_rule(
# This case triggers when casting signed to unsigned or vice versa.
return x
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
elif _from(floating) and _to(signed):
# TODO(apaszke): Remove once a month has passed, along with the
# _convert_helper float -> signed conversion above.
if not ctx.forward_compatible or both_32bit:
elif _from(floating) and _to(signed) and both_32bit:
return arith.fptosi(out_type, x)
elif _from(signed) and _to(floating) and both_32bit:
return arith.sitofp(out_type, x)

View File

@ -40,7 +40,6 @@
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
namespace mlir::tpu {
@ -540,87 +539,6 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
return success();
}
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = cast<arith::FPToSIOp>(raw_op);
ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
auto dst_vty = dyn_cast<VectorType>(op.getType());
if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
return op.emitOpError("Vector/scalar mismatch between input and output");
}
bool is_vector = static_cast<bool>(src_vty);
unsigned src_bitwidth, dst_bitwidth;
if (is_vector) {
src_bitwidth = src_vty.getElementTypeBitWidth();
dst_bitwidth = dst_vty.getElementTypeBitWidth();
} else {
src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth();
dst_bitwidth = op.getType().getIntOrFloatBitWidth();
}
if (dst_bitwidth > 32) {
return op.emitOpError("Target bitwidth too large");
}
Value x = op.getIn();
// Upcast the input to f32.
if (src_bitwidth < 32) {
if (is_vector) {
x = builder.create<arith::ExtFOp>(
VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
} else {
x = builder.create<arith::ExtFOp>(builder.getF32Type(), x);
}
}
if (dst_bitwidth < 32) {
if (!ctx.compatibility_mode) {
return op.emitOpError(
"On this target only float-to-integer conversions can only happen on "
"32-bit values. Enable compatibility mode or upcast to float32.");
}
// Need to clip values to match XLA
auto clip = [&](Value x, Value low, Value high) {
auto is_small =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, low);
x = builder.create<arith::SelectOp>(is_small, low, x);
auto is_large =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, high);
x = builder.create<arith::SelectOp>(is_large, high, x);
return x;
};
auto minval = builder.getF32FloatAttr(
APInt::getSignedMinValue(dst_bitwidth).getSExtValue());
auto maxval = builder.getF32FloatAttr(
APInt::getSignedMaxValue(dst_bitwidth).getSExtValue());
if (is_vector) {
auto x_vty = cast<VectorType>(x.getType());
x = clip(x, getFullVector(builder, x_vty, minval),
getFullVector(builder, x_vty, maxval));
} else {
auto f32 = builder.getF32Type();
x = clip(x, builder.create<arith::ConstantOp>(f32, minval),
builder.create<arith::ConstantOp>(f32, maxval));
}
}
if (is_vector) {
x = builder.create<arith::FPToSIOp>(
VectorType::get(src_vty.getShape(), builder.getI32Type()), x);
} else {
x = builder.create<arith::FPToSIOp>(builder.getI32Type(), x);
}
if (dst_bitwidth < 32) {
if (!ctx.compatibility_mode) {
return op.emitOpError(
"On this target only float-to-integer conversions can only happen on "
"32-bit values. Enable compatibility mode or cast to int32 and "
"truncate later.");
}
x = builder.create<arith::TruncIOp>(op.getType(), x);
}
op.replaceAllUsesWith(x);
op.erase();
return success();
}
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
@ -656,7 +574,6 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
{vector::MultiDimReductionOp::getOperationName(),
canonicalize_multi_dim_reduction},
{arith::SelectOp::getOperationName(), canonicalize_select},
{arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
return *rules;
}