mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
PiperOrigin-RevId: 715435319
This commit is contained in:
parent
b6acb9cb7a
commit
f1b894d14a
@ -112,7 +112,6 @@ class LoweringContext:
|
||||
replace = dataclasses.replace
|
||||
traceback_caches: mlir.TracebackCaches
|
||||
for_verification: bool
|
||||
forward_compatible: bool
|
||||
|
||||
@property
|
||||
def grid_rank(self):
|
||||
@ -142,10 +141,6 @@ class LoweringRuleContext:
|
||||
|
||||
replace = dataclasses.replace
|
||||
|
||||
@property
|
||||
def forward_compatible(self):
|
||||
return self.lowering_context.forward_compatible
|
||||
|
||||
|
||||
def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
|
||||
) -> TPUMemorySpace:
|
||||
@ -559,7 +554,6 @@ def lower_jaxpr_to_module(
|
||||
mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
name="main",
|
||||
for_verification=for_verification,
|
||||
forward_compatible=lowering_context.is_forward_compat(),
|
||||
)
|
||||
m.body.append(func_op)
|
||||
sym_tab.insert(func_op)
|
||||
@ -586,7 +580,6 @@ def lower_jaxpr_to_module(
|
||||
name=func_name,
|
||||
mosaic_grid_mapping=mosaic_grid_mapping,
|
||||
for_verification=for_verification,
|
||||
forward_compatible=lowering_context.is_forward_compat(),
|
||||
)
|
||||
assert mlir_func.verify(), mlir_func
|
||||
block_shape = [
|
||||
@ -635,7 +628,6 @@ def lower_jaxpr_to_transform_func(
|
||||
name: str,
|
||||
mosaic_grid_mapping: MosaicGridMapping,
|
||||
for_verification: bool,
|
||||
forward_compatible: bool,
|
||||
) -> func.FuncOp:
|
||||
num_grid = len(mosaic_grid_mapping.grid_types)
|
||||
arg_types = [
|
||||
@ -670,7 +662,6 @@ def lower_jaxpr_to_transform_func(
|
||||
mesh_context=mesh_context,
|
||||
traceback_caches=mlir.TracebackCaches(),
|
||||
for_verification=for_verification,
|
||||
forward_compatible=forward_compatible,
|
||||
)
|
||||
out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
|
||||
*scalar_prefetch)
|
||||
@ -698,7 +689,6 @@ def lower_jaxpr_to_func(
|
||||
mosaic_grid_mapping: MosaicGridMapping,
|
||||
name: str,
|
||||
for_verification: bool,
|
||||
forward_compatible: bool,
|
||||
) -> func.FuncOp:
|
||||
num_grid = len(mosaic_grid_mapping.grid_types)
|
||||
num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types)
|
||||
@ -737,7 +727,6 @@ def lower_jaxpr_to_func(
|
||||
mesh_context=mesh_context,
|
||||
traceback_caches=mlir.TracebackCaches(),
|
||||
for_verification=for_verification,
|
||||
forward_compatible=forward_compatible,
|
||||
)
|
||||
return jaxpr_subcomp(
|
||||
lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
|
||||
@ -1843,10 +1832,7 @@ def _convert_element_type_lowering_rule(
|
||||
# This case triggers when casting signed to unsigned or vice versa.
|
||||
return x
|
||||
# TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer.
|
||||
elif _from(floating) and _to(signed):
|
||||
# TODO(apaszke): Remove once a month has passed, along with the
|
||||
# _convert_helper float -> signed conversion above.
|
||||
if not ctx.forward_compatible or both_32bit:
|
||||
elif _from(floating) and _to(signed) and both_32bit:
|
||||
return arith.fptosi(out_type, x)
|
||||
elif _from(signed) and _to(floating) and both_32bit:
|
||||
return arith.sitofp(out_type, x)
|
||||
|
@ -40,7 +40,6 @@
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
@ -540,87 +539,6 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Canonicalizes an arith.fptosi op into a sequence that performs the actual
// float-to-int conversion on 32-bit values only:
//   1. inputs narrower than 32 bits are extended to f32,
//   2. for destinations narrower than 32 bits the value is clipped to the
//      destination's signed range,
//   3. the value is converted to i32,
//   4. for destinations narrower than 32 bits the i32 result is truncated to
//      the original result type.
// On success all uses of the original op are redirected to the new value and
// the original op is erased.
//
// Emits an op error and fails when:
//   - exactly one of the input/result types is a vector,
//   - the destination element type is wider than 32 bits,
//   - the destination is narrower than 32 bits and ctx.compatibility_mode is
//     not set.
// All of these checks run before any new op is created, so a rejected op
// leaves the surrounding IR untouched.
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
                                  Operation &raw_op) {
  auto op = cast<arith::FPToSIOp>(raw_op);
  ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation());
  auto src_vty = dyn_cast<VectorType>(op.getIn().getType());
  auto dst_vty = dyn_cast<VectorType>(op.getType());
  if (static_cast<bool>(src_vty) != static_cast<bool>(dst_vty)) {
    return op.emitOpError("Vector/scalar mismatch between input and output");
  }
  const bool is_vector = static_cast<bool>(src_vty);
  unsigned src_bitwidth, dst_bitwidth;
  if (is_vector) {
    src_bitwidth = src_vty.getElementTypeBitWidth();
    dst_bitwidth = dst_vty.getElementTypeBitWidth();
  } else {
    src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth();
    dst_bitwidth = op.getType().getIntOrFloatBitWidth();
  }
  if (dst_bitwidth > 32) {
    return op.emitOpError("Target bitwidth too large");
  }
  // Narrow destinations need the clip + truncate expansion below, which is
  // only permitted in compatibility mode. Reject up front, before the IR is
  // modified.
  const bool narrow_dst = dst_bitwidth < 32;
  if (narrow_dst && !ctx.compatibility_mode) {
    return op.emitOpError(
        "On this target only float-to-integer conversions can only happen on "
        "32-bit values. Enable compatibility mode or upcast to float32.");
  }
  // NOTE(review): inputs wider than 32 bits are not narrowed here; the f32
  // clip bounds below would then not match the type of x. Confirm that such
  // inputs cannot reach this rule.
  Value x = op.getIn();
  // Upcast the input to f32.
  if (src_bitwidth < 32) {
    if (is_vector) {
      x = builder.create<arith::ExtFOp>(
          VectorType::get(src_vty.getShape(), builder.getF32Type()), x);
    } else {
      x = builder.create<arith::ExtFOp>(builder.getF32Type(), x);
    }
  }
  if (narrow_dst) {
    // Need to clip values to match XLA
    // Both comparisons use ordered predicates, so a NaN compares false twice
    // and is passed through unchanged.
    auto clip = [&](Value v, Value low, Value high) {
      auto is_small =
          builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, v, low);
      v = builder.create<arith::SelectOp>(is_small, low, v);
      auto is_large =
          builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, v, high);
      v = builder.create<arith::SelectOp>(is_large, high, v);
      return v;
    };
    // Signed range of the destination integer type, expressed as f32.
    auto minval = builder.getF32FloatAttr(
        APInt::getSignedMinValue(dst_bitwidth).getSExtValue());
    auto maxval = builder.getF32FloatAttr(
        APInt::getSignedMaxValue(dst_bitwidth).getSExtValue());
    if (is_vector) {
      auto x_vty = cast<VectorType>(x.getType());
      x = clip(x, getFullVector(builder, x_vty, minval),
               getFullVector(builder, x_vty, maxval));
    } else {
      auto f32 = builder.getF32Type();
      x = clip(x, builder.create<arith::ConstantOp>(f32, minval),
               builder.create<arith::ConstantOp>(f32, maxval));
    }
  }
  // The conversion itself always targets i32.
  if (is_vector) {
    x = builder.create<arith::FPToSIOp>(
        VectorType::get(src_vty.getShape(), builder.getI32Type()), x);
  } else {
    x = builder.create<arith::FPToSIOp>(builder.getI32Type(), x);
  }
  // Narrow the i32 result back to the type the original op produced.
  if (narrow_dst) {
    x = builder.create<arith::TruncIOp>(op.getType(), x);
  }
  op.replaceAllUsesWith(x);
  op.erase();
  return success();
}
|
||||
|
||||
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
|
||||
Operation &raw_op) {
|
||||
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
|
||||
@ -656,7 +574,6 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
|
||||
{vector::MultiDimReductionOp::getOperationName(),
|
||||
canonicalize_multi_dim_reduction},
|
||||
{arith::SelectOp::getOperationName(), canonicalize_select},
|
||||
{arith::FPToSIOp::getOperationName(), canonicalize_fptosi},
|
||||
{tpu::RepeatOp::getOperationName(), canonicalize_repeat}};
|
||||
return *rules;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user