diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 3ebc4deee..aabc2b6c7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -112,7 +112,6 @@ class LoweringContext: replace = dataclasses.replace traceback_caches: mlir.TracebackCaches for_verification: bool - forward_compatible: bool @property def grid_rank(self): @@ -142,10 +141,6 @@ class LoweringRuleContext: replace = dataclasses.replace - @property - def forward_compatible(self): - return self.lowering_context.forward_compatible - def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None ) -> TPUMemorySpace: @@ -559,7 +554,6 @@ def lower_jaxpr_to_module( mosaic_grid_mapping=mosaic_grid_mapping, name="main", for_verification=for_verification, - forward_compatible=lowering_context.is_forward_compat(), ) m.body.append(func_op) sym_tab.insert(func_op) @@ -586,7 +580,6 @@ def lower_jaxpr_to_module( name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, for_verification=for_verification, - forward_compatible=lowering_context.is_forward_compat(), ) assert mlir_func.verify(), mlir_func block_shape = [ @@ -635,7 +628,6 @@ def lower_jaxpr_to_transform_func( name: str, mosaic_grid_mapping: MosaicGridMapping, for_verification: bool, - forward_compatible: bool, ) -> func.FuncOp: num_grid = len(mosaic_grid_mapping.grid_types) arg_types = [ @@ -670,7 +662,6 @@ def lower_jaxpr_to_transform_func( mesh_context=mesh_context, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, - forward_compatible=forward_compatible, ) out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices, *scalar_prefetch) @@ -698,7 +689,6 @@ def lower_jaxpr_to_func( mosaic_grid_mapping: MosaicGridMapping, name: str, for_verification: bool, - forward_compatible: bool, ) -> func.FuncOp: num_grid = len(mosaic_grid_mapping.grid_types) num_scalar_prefetch = len(mosaic_grid_mapping.scalar_prefetch_types) @@ -737,7 +727,6 @@ def lower_jaxpr_to_func( mesh_context=mesh_context, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, - forward_compatible=forward_compatible, ) return jaxpr_subcomp( lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch @@ -1843,11 +1832,8 @@ def _convert_element_type_lowering_rule( # This case triggers when casting signed to unsigned or vice versa. return x # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. - elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. - if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) + elif _from(floating) and _to(signed) and both_32bit: + return arith.fptosi(out_type, x) elif _from(signed) and _to(floating) and both_32bit: return arith.sitofp(out_type, x) elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 11a587e5d..b05a91826 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -40,7 +40,6 @@ #include "mlir/include/mlir/IR/Value.h" #include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "jaxlib/mosaic/dialect/tpu/vreg_util.h" namespace mlir::tpu { @@ -540,87 +539,6 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx, return success(); } -LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, - Operation &raw_op) { - auto op = cast(raw_op); - ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); - auto src_vty = dyn_cast(op.getIn().getType()); - auto dst_vty = dyn_cast(op.getType()); - if (static_cast(src_vty) != static_cast(dst_vty)) { - return op.emitOpError("Vector/scalar mismatch between input and output"); - } - bool is_vector = static_cast(src_vty); - unsigned src_bitwidth, dst_bitwidth; - if (is_vector) { - src_bitwidth = src_vty.getElementTypeBitWidth(); - dst_bitwidth = dst_vty.getElementTypeBitWidth(); - } else { - src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth(); - dst_bitwidth = op.getType().getIntOrFloatBitWidth(); - } - if (dst_bitwidth > 32) { - return op.emitOpError("Target bitwidth too large"); - } - Value x = op.getIn(); - // Upcast the input to f32. - if (src_bitwidth < 32) { - if (is_vector) { - x = builder.create( - VectorType::get(src_vty.getShape(), builder.getF32Type()), x); - } else { - x = builder.create(builder.getF32Type(), x); - } - } - if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. Enable compatibility mode or upcast to float32."); - } - // Need to clip values to match XLA - auto clip = [&](Value x, Value low, Value high) { - auto is_small = - builder.create(arith::CmpFPredicate::OLT, x, low); - x = builder.create(is_small, low, x); - auto is_large = - builder.create(arith::CmpFPredicate::OGT, x, high); - x = builder.create(is_large, high, x); - return x; - }; - auto minval = builder.getF32FloatAttr( - APInt::getSignedMinValue(dst_bitwidth).getSExtValue()); - auto maxval = builder.getF32FloatAttr( - APInt::getSignedMaxValue(dst_bitwidth).getSExtValue()); - if (is_vector) { - auto x_vty = cast(x.getType()); - x = clip(x, getFullVector(builder, x_vty, minval), - getFullVector(builder, x_vty, maxval)); - } else { - auto f32 = builder.getF32Type(); - x = clip(x, builder.create(f32, minval), - builder.create(f32, maxval)); - } - } - if (is_vector) { - x = builder.create( - VectorType::get(src_vty.getShape(), builder.getI32Type()), x); - } else { - x = builder.create(builder.getI32Type(), x); - } - if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. Enable compatibility mode or cast to int32 and " - "truncate later."); - } - x = builder.create(op.getType(), x); - } - op.replaceAllUsesWith(x); - op.erase(); - return success(); -} - LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, Operation &raw_op) { auto op = dyn_cast(raw_op); @@ -656,7 +574,6 @@ const llvm::StringMap &rules() { {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, {arith::SelectOp::getOperationName(), canonicalize_select}, - {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; return *rules; }