mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
This commit is contained in:
parent
a43edb4644
commit
543dd94762
@ -371,6 +371,7 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
|
||||
}
|
||||
|
||||
// Integer packs are always signed at the moment.
|
||||
// Float to integer packing rounds to nearest even.
|
||||
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
|
||||
let arguments = (ins
|
||||
Variadic<TPU_Vreg>:$sources,
|
||||
@ -414,6 +415,26 @@ def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Enumerates the rounding modes an op can request: rounding towards zero
// (value 0) or rounding to the nearest even value (value 1).
def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [
    I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">,
    I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">,
  ]> {
  // The C++ enum is emitted into the TPU dialect namespace.
  let cppNamespace = "::mlir::tpu";
  // No specialized attribute is generated here; the attribute is declared
  // separately through EnumAttr (TPU_RoundingModeEnum).
  let genSpecializedAttr = 0;
}

// TPU dialect attribute wrapping TPU_RoundingMode under the mnemonic
// "rounding_mode". The enum value is written between angle brackets.
def TPU_RoundingModeEnum
    : EnumAttr<TPU_Dialect, TPU_RoundingMode, "rounding_mode"> {
  let assemblyFormat = "`<` $value `>`";
}
|
||||
|
||||
// Internal operation. All arith.fptosi operations that change the bitwidth
// must be canonicalized to this operation.
//
// Takes a single vector operand plus a rounding_mode attribute and produces a
// single vector result. The rounding mode has no dedicated spelling in the
// assembly format, so it is carried through attr-dict.
def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> {
  let arguments = (ins
    AnyVectorOfAnyRank:$input,
    TPU_RoundingModeEnum:$rounding_mode
  );
  let results = (outs AnyVectorOfAnyRank:$output);
  // The canonicalizer is implemented by hand in FPToSIOp::canonicalize.
  let hasCanonicalizeMethod = 1;
  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
}
|
||||
|
||||
def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> {
|
||||
let parameters = (ins
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/include/mlir/IR/Builders.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinTypes.h"
|
||||
@ -1053,6 +1054,16 @@ LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Folds a math.roundeven that produces the input of a tpu.fptosi into the
// conversion itself: the op is replaced by a tpu.fptosi that reads the
// un-rounded value and uses the to-nearest-even rounding mode.
//
// Returns success() when the rewrite was applied and failure() when the input
// is not defined by a math.roundeven.
LogicalResult FPToSIOp::canonicalize(FPToSIOp op, PatternRewriter &rewriter) {
  auto round_even = op.getInput().getDefiningOp<mlir::math::RoundEvenOp>();
  if (!round_even) {
    // Nothing to fuse.
    return failure();
  }
  rewriter.replaceOpWithNewOp<tpu::FPToSIOp>(
      op, op.getType(), round_even.getOperand(),
      tpu::RoundingMode::kToNearestEven);
  return success();
}
|
||||
|
||||
LogicalResult ConcatenateOp::verify() {
|
||||
auto dimension = getDimension();
|
||||
if (getOperands().size() < 2) {
|
||||
|
@ -54,6 +54,7 @@
|
||||
#include "llvm/include/llvm/Support/LogicalResult.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/include/mlir/IR/Attributes.h"
|
||||
#include "mlir/include/mlir/IR/Builders.h"
|
||||
@ -859,7 +860,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
|
||||
const VectorLayout &layout_in,
|
||||
const VectorLayout &layout_out) {
|
||||
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
|
||||
auto source = cast<TypedValue<VectorType>>(op.getIn());
|
||||
auto source = cast<TypedValue<VectorType>>(op.getOperand());
|
||||
auto result_ty = cast<VectorType>(op.getResult().getType());
|
||||
auto output_vregs_shape =
|
||||
layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape);
|
||||
@ -1062,6 +1063,45 @@ LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
|
||||
*layouts_out.front());
|
||||
}
|
||||
|
||||
// Layout-application rule for tpu.fptosi.
//
// Expects exactly one input layout and one output layout, both present.
//   * Equal input/output bitwidth: handled as a plain elementwise op.
//   * Narrowing conversion: lowered through trunc_op_rule_impl. FPToSI
//     semantics require rounding towards zero, but packing instructions round
//     to nearest even, so for the towards-zero mode every input vreg is first
//     passed through math.trunc. The to-nearest-even mode already matches
//     tpu.pack_subelements and needs no extra work.
//   * Widening conversion: rejected with an op error.
LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
  TPU_ASSERT_OP(layouts_in.front().has_value());
  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
  TPU_ASSERT_OP(layouts_out.front().has_value());
  auto &layout_in = *layouts_in.front();
  auto &layout_out = *layouts_out.front();

  const auto in_bitwidth = layout_in.bitwidth();
  const auto out_bitwidth = layout_out.bitwidth();
  if (in_bitwidth == out_bitwidth) {
    return elementwise_op_rule(ctx, op, layouts_in, layouts_out);
  }
  if (in_bitwidth < out_bitwidth) {
    return op.emitOpError("Unsupported FPToSI conversion");
  }

  // Narrowing conversion from here on.
  auto fptosi_op = cast<tpu::FPToSIOp>(op);
  if (fptosi_op.getRoundingMode() == tpu::RoundingMode::kTowardsZero) {
    // Insert the explicit rounding on each vreg, then feed the reassembled
    // vector back into the op in place of its original input.
    auto input = cast<TypedValue<VectorType>>(fptosi_op.getInput());
    ImplicitLocOpBuilder builder(op.getLoc(), fptosi_op);
    FAILUREOR_ASSIGN_OR_RETURN(
        xla::Array<Value> vregs,
        disassemble(builder, layout_in, input, ctx.target_shape));
    vregs.Each([&](absl::Span<const int64_t> /*idxs*/, Value *vreg) {
      *vreg = builder.create<mlir::math::TruncOp>(op.getLoc(),
                                                  vreg->getType(), *vreg);
    });
    fptosi_op->replaceUsesOfWith(
        input, assemble(builder, input.getType(), layout_in, vregs,
                        ctx.target_shape));
  }
  return trunc_op_rule_impl(ctx, fptosi_op, layout_in, layout_out);
}
|
||||
|
||||
LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -4672,6 +4712,7 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
|
||||
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
|
||||
{tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
|
||||
{tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
|
||||
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
|
||||
{vector::ExtractOp::getOperationName(), vector_extract_rule},
|
||||
{vector::LoadOp::getOperationName(), vector_load_rule},
|
||||
|
@ -22,8 +22,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@ -36,6 +34,7 @@
|
||||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "mlir/include/mlir/IR/Operation.h"
|
||||
#include "mlir/include/mlir/IR/PatternMatch.h"
|
||||
#include "mlir/include/mlir/IR/Region.h"
|
||||
#include "mlir/include/mlir/IR/Value.h"
|
||||
#include "mlir/include/mlir/Support/LLVM.h"
|
||||
@ -540,6 +539,7 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
|
||||
return success();
|
||||
}
|
||||
|
||||
// All conversions that change bitwidth must be canonicalized to tpu.fptosi.
|
||||
LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
|
||||
Operation &raw_op) {
|
||||
auto op = cast<arith::FPToSIOp>(raw_op);
|
||||
@ -561,6 +561,24 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx,
|
||||
if (dst_bitwidth > 32) {
|
||||
return op.emitOpError("Target bitwidth too large");
|
||||
}
|
||||
if (ctx.hardware_generation >= 6 && is_vector &&
|
||||
src_vty.getElementType().isBF16() &&
|
||||
dst_vty.getElementType().isSignlessInteger(8)) {
|
||||
auto new_op = builder.create<tpu::FPToSIOp>(
|
||||
op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero);
|
||||
op.replaceAllUsesWith(new_op.getResult());
|
||||
op.erase();
|
||||
// We briefly trigger canonicalization here to potentially fuse the rounding
|
||||
// ops into the newly created tpu.fptosi.
|
||||
{
|
||||
PatternRewriter rewriter(new_op.getContext());
|
||||
rewriter.setInsertionPoint(new_op);
|
||||
// We don't care if the canonicalization pattern matched or not.
|
||||
(void)tpu::FPToSIOp::canonicalize(new_op, rewriter);
|
||||
new_op = nullptr; // Canonicalization may have erased the op!
|
||||
}
|
||||
return success();
|
||||
}
|
||||
Value x = op.getIn();
|
||||
// Upcast the input to f32.
|
||||
if (src_bitwidth < 32) {
|
||||
|
@ -165,6 +165,14 @@ class VectorLayoutInferer {
|
||||
if (inferTrunc(&any_op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<tpu::FPToSIOp>(any_op);
|
||||
op &&
|
||||
cast<VectorType>(op.getOperand().getType())
|
||||
.getElementTypeBitWidth() >
|
||||
cast<VectorType>(op.getType()).getElementTypeBitWidth()) {
|
||||
if (inferTrunc(&any_op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<arith::SelectOp>(any_op)) {
|
||||
auto true_ty = dyn_cast<VectorType>(op.getTrueValue().getType());
|
||||
auto false_ty = dyn_cast<VectorType>(op.getFalseValue().getType());
|
||||
|
@ -14,6 +14,7 @@
|
||||
"""Tests for TPU specific operations within pallas_call."""
|
||||
|
||||
import functools
|
||||
import math
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
@ -368,6 +369,33 @@ class OpsTest(PallasBaseTest):
|
||||
expected = jnp.where(mask, x, jnp.zeros_like(x))
|
||||
self.assertArraysEqual(out, expected)
|
||||
|
||||
@parameterized.product(
    target=(jnp.int8,),  # TODO(apaszke): Add int4.
    round=(False, True),
)
def test_quantize(self, target, round):
  # Compares a pallas_call kernel that casts bf16 to `target` (optionally
  # after jnp.rint) against the same computation under jax.jit.
  if not jtu.if_cloud_tpu_at_least(2025, 1, 15):
    self.skipTest("Requires libtpu built after 2025-01-15")
  if not jtu.is_device_tpu_at_least(version=6):
    self.skipTest("Requires TPUv6+")

  shape = (256, 256)
  # NOTE: 256 * 256 == 2 ** 16, so those are all bf16 values.
  all_bit_patterns = np.arange(math.prod(shape), dtype=jnp.uint16)
  x = lax.bitcast_convert_type(all_bit_patterns.reshape(shape), jnp.bfloat16)

  def quantize(values):
    # Shared by the kernel and the reference so both apply the same ops.
    if round:
      values = jnp.rint(values)
    return values.astype(target)

  def kernel(x_ref, o_ref):
    o_ref[...] = quantize(x_ref[...])

  out_shape = jax.ShapeDtypeStruct(shape, target)
  out = self.pallas_call(kernel, out_shape=out_shape)(x)

  ref = jax.jit(quantize)(x)
  np.testing.assert_array_equal(out, ref)
|
||||
|
||||
|
||||
# Subclass of OpsTest that inherits its tests and sets INTERPRET to True.
class OpsInterpretTest(OpsTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user