[XLA:Mosaic] Support arbitrary aligned shapes for tpu.bitcast, including bitcasts that change the element bitwidth.

PiperOrigin-RevId: 582524212
Jevin Jiang 2023-11-14 20:25:08 -08:00 committed by jax authors
parent 555c569e67
commit 8a64d9af40
4 changed files with 135 additions and 10 deletions
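In short: tpu.bitcast may now change the element bitwidth, with the second-to-last (sublane) dimension rescaled by the bitwidth ratio. A hedged sketch of the newly legal case, using the MLIR Python bindings that appear later in this diff (an active ir.Context and insertion point are assumed, and `x` stands for any value of type vector<16x128xbf16>):

src_ty = ir.VectorType.get([16, 128], ir.BF16Type.get())
dst_ty = ir.VectorType.get([8, 128], ir.IntegerType.get_signless(32))
# 16-bit -> 32-bit doubles the element width, so the sublane dim halves.
y = tpu.BitcastOp(dst_ty, x)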


@@ -300,6 +300,12 @@ def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
}
+def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
+  let arguments = (ins TPU_Vreg:$input);
+  let results = (outs TPU_Vreg:$output);
+  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+}
+
def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
let arguments = (ins Variadic<AnyType>:$input);
let results = (outs AnyType:$output);

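The new tpu.bitcast_vreg op reinterprets the bits of a single native vreg and does no layout reasoning of its own; apply-vector-layouts lowers tpu.bitcast to one bitcast_vreg per vreg. A minimal, hypothetical builder call mirroring its use in the rules below (`vreg_ty` and `tile` are placeholders for a native vreg type and a vreg value):

new_tile = tpu.BitcastVregOp(vreg_ty, tile).result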

@@ -1254,6 +1254,52 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
return success();
}
+LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
+                               const ArrayRef<Layout> layouts_in,
+                               const ArrayRef<Layout> layouts_out) {
+  CHECK_EQ(layouts_in.size(), 1);
+  CHECK_EQ(layouts_out.size(), 1);
+  if (!layouts_in.front().has_value()) {
+    return op.emitOpError("Expected non-null input layout");
+  }
+  if (!layouts_out.front().has_value()) {
+    return op.emitOpError("Expected non-null output layout");
+  }
+  const VectorLayout &layout_in = *layouts_in.front();
+  const VectorLayout &layout_out = *layouts_out.front();
+  if (!layout_in.hasNativeTiling(ctx.target_shape) ||
+      !layout_out.hasNativeTiling(ctx.target_shape)) {
+    return op.emitOpError("Not implemented: unsupported tiling");
+  }
+  if (layout_in.offsets() != LayoutOffsets{0, 0} ||
+      layout_out.offsets() != LayoutOffsets{0, 0}) {
+    return op.emitOpError("Not implemented: unsupported offsets");
+  }
+  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
+      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    return op.emitOpError("Not implemented: unsupported implicit dim");
+  }
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+  auto bitcast_op = cast<tpu::BitcastOp>(op);
+  const VectorType vty = bitcast_op.getResult().getType();
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const auto native_vreg_ty,
+      getNativeVregType(vty.getElementType(), ctx.target_shape));
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const xla::Array<Value> in_tiles,
+      disassemble(builder, layout_in, bitcast_op.getInput(), ctx.target_shape));
+  xla::Array<Value> out_tiles(in_tiles.dimensions());
+  out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+    const Value in_tile = in_tiles(idxs);
+    *v = builder.create<tpu::BitcastVregOp>(native_vreg_ty, in_tile);
+  });
+  bitcast_op.replaceAllUsesWith(
+      assemble(builder, vty, layout_out, out_tiles, ctx.target_shape)
+          .getOperation());
+  bitcast_op.erase();
+  return success();
+}
+
LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@@ -2701,12 +2747,12 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
APInt(32, 0xFFFFFFFF))));
auto masked_tile = builder.create<arith::AndIOp>(
store_op.getLoc(), mask,
-builder.create<tpu::BitcastOp>(mask.getType(), tile));
+builder.create<tpu::BitcastVregOp>(mask.getType(), tile));
auto mask_neg = builder.create<arith::XOrIOp>(ones, mask);
auto masked_data = builder.create<arith::AndIOp>(
mask_neg,
-builder.create<tpu::BitcastOp>(mask.getType(), data));
-updated = builder.create<tpu::BitcastOp>(
+builder.create<tpu::BitcastVregOp>(mask.getType(), data));
+updated = builder.create<tpu::BitcastVregOp>(
tile.getType(),
builder.create<arith::OrIOp>(masked_data, masked_tile));
} else {
@@ -2917,6 +2963,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
{tpu::StoreOp::getOperationName(), tpu_store_rule},
+{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::ContractionOp::getOperationName(), vector_contract_rule},
@@ -3541,7 +3588,7 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
DenseElementsAttr::get(bits_vreg_ty, shift_bits));
dst_tiles.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
auto bit_tile =
-builder.create<tpu::BitcastOp>(v.getLoc(), bits_vreg_ty, *tile);
+builder.create<tpu::BitcastVregOp>(v.getLoc(), bits_vreg_ty, *tile);
Operation *shift_tile;
if (subelem_diff > 0) {
shift_tile =
@@ -3552,7 +3599,7 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
shift_vreg);
}
*tile = builder
-.create<tpu::BitcastOp>(v.getLoc(), tile->getType(),
+.create<tpu::BitcastVregOp>(v.getLoc(), tile->getType(),
shift_tile->getResult(0))
.getResult();
return absl::OkStatus();

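For intuition on the relayout shift above: when sub-element offsets differ, each tile is reinterpreted as 32-bit lanes (now via bitcast_vreg), shifted, and reinterpreted back, which moves values between the positions packed within one lane. A rough numpy sketch of the case with two 16-bit sub-elements per lane (values invented):

import numpy as np

lo, hi = np.uint32(0x1111), np.uint32(0x2222)
lane = (hi << np.uint32(16)) | lo  # one 32-bit lane holding two 16-bit values
assert lane << np.uint32(16) == lo << np.uint32(16)  # lo moved to the high half
assert lane >> np.uint32(16) == hi  # logical shift: hi moved to the low half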

@@ -241,6 +241,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
+} else if (auto op = dyn_cast<tpu::BitcastOp>(any_op)) {
+if (infer(op).failed()) {
+return failure();
+}
} else if (auto op = dyn_cast<tpu::RepeatOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@@ -690,6 +694,50 @@ class VectorLayoutInferer {
return success();
}
+
+LogicalResult infer(tpu::BitcastOp op) {
+  auto src_layout = getLayout(op.getInput());
+  LayoutOffsets src_offsets = src_layout->offsets();
+  if (src_offsets[0].value_or(0) || src_offsets[1].value_or(0)) {
+    NYI("unsupported bitcast with offsets");
+  }
+  if (src_layout->implicit_dim() != ImplicitDim::kNone) {
+    NYI("unsupported bitcast with an implicit dim");
+  }
+  // Check that the input and output shapes are bit-compatible.
+  auto in_ty = dyn_cast<VectorType>(op.getInput().getType());
+  auto out_ty = dyn_cast<VectorType>(op.getOutput().getType());
+  TPU_CHECK_OP(in_ty && out_ty && in_ty.getRank() == out_ty.getRank(),
+               "Input and output are not vectors of the same rank");
+  auto in_bitwidth = in_ty.getElementTypeBitWidth();
+  auto out_bitwidth = out_ty.getElementTypeBitWidth();
+  if (out_ty.getRank() < 2) {
+    NYI("bitcast with a 1D vector");
+  }
+  for (int i = 0; i < in_ty.getRank(); ++i) {
+    auto in_dim = in_ty.getDimSize(i);
+    auto out_dim = out_ty.getDimSize(i);
+    // The sublane dimension is scaled down by the ratio of input element
+    // bitwidth to output element bitwidth when bitcasting. For example,
+    // bitcasting a vector<16x128xbf16> to a vector<8x128xi32> packs every 2
+    // rows in the bf16 vector into 1 row in the i32 vector. This means the
+    // bit representation of one i32 element vector[i,j] is equal to
+    // concatenating bf16 elements vector[2*i+1,j] and vector[2*i,j].
+    if (i == in_ty.getRank() - 2) {
+      in_dim *= in_bitwidth;
+      out_dim *= out_bitwidth;
+    }
+    TPU_CHECK_OP(in_dim == out_dim,
+                 "Input and output have incompatible shape");
+  }
+  setLayout(op,
+            VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
+                         ImplicitDim::kNone),
+            VectorLayout(out_bitwidth, src_offsets,
+                         nativeTiling(out_bitwidth), ImplicitDim::kNone));
+  return success();
+}
+
LogicalResult infer(tpu::RepeatOp op) {
auto src_layout = getLayout(op.getSource());
setLayout(op, src_layout, src_layout);

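A small numpy check (bit patterns invented) of the packing rule described in the comment above, for bitcasting vector<16x128xbf16> to vector<8x128xi32>: bf16 row 2*i+1 supplies the high 16 bits of i32 row i, and row 2*i the low 16 bits.

import numpy as np

bf16_bits = np.arange(16 * 128, dtype=np.uint16).reshape(16, 128)  # raw bf16 bits
i32_bits = (bf16_bits[1::2].astype(np.uint32) << 16) | bf16_bits[0::2]
assert i32_bits.shape == (8, 128)  # sublane dimension scaled by 16/32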

@@ -1394,14 +1394,14 @@ def relayout(
),
)
for idx, tile in np.ndenumerate(dst_tiles):
-bit_tile = tpu.BitcastOp(bits_vreg_ty, tile)
+bit_tile = tpu.BitcastVregOp(bits_vreg_ty, tile)
if subelem_diff > 0:
shift_tile = arith.ShLIOp(bit_tile, shift_vreg)
elif subelem_diff < 0:
shift_tile = arith.ShRUIOp(bit_tile, shift_vreg)
else:
raise AssertionError("unexpected equal subelements")
-dst_tiles[idx] = tpu.BitcastOp(tile.type, shift_tile).result
+dst_tiles[idx] = tpu.BitcastVregOp(tile.type, shift_tile).result
# Shifting columns.
if src.offsets[1] is REPLICATED:
@@ -2200,6 +2200,28 @@ def _tpu_store_rule( # pylint: disable=missing-function-docstring
return ctx.erase(op)
@_register_rule("tpu.bitcast")
def _tpu_bitcast_rule( # pylint: disable=missing-function-docstring
ctx: RewriteContext,
op: tpu.BitcastOp,
layout_in: VectorLayout,
layout_out: VectorLayout,
):
if not layout_in.has_native_tiling or not layout_out.has_native_tiling:
raise NotImplementedError("unsupported tiling")
if layout_in.offsets != (0, 0) or layout_out.offsets != (0, 0):
raise NotImplementedError("unsupported offsets")
if layout_in.implicit_dim is not None or layout_out.implicit_dim is not None:
raise NotImplementedError("unsupported implicit dim")
ty = ir.VectorType(op.result.type)
vreg = native_vreg_ty(ty.element_type)
in_tiles = disassemble(layout_in, op.input)
out_tiles = np.empty_like(in_tiles, dtype=object)
for idx, tile in np.ndenumerate(in_tiles):
out_tiles[idx] = tpu.BitcastVregOp(vreg, tile)
return ctx.replace(op, assemble(ty, layout_out, out_tiles))
@_register_rule("tpu.trace")
def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp, # pylint: disable=missing-function-docstring
layout_in: Layout, layout_out: Layout):
@@ -2713,10 +2735,12 @@ def _vector_store_rule( # pylint: disable=missing-function-docstring
mask.type, ir.IntegerAttr.get(i32(), 0xFFFFFFFF)
),
)
-masked_tile = arith.AndIOp(mask, tpu.BitcastOp(mask.type, tile))
+masked_tile = arith.AndIOp(mask, tpu.BitcastVregOp(mask.type, tile))
mask_neg = arith.XOrIOp(ones, mask)
-masked_data = arith.AndIOp(mask_neg, tpu.BitcastOp(mask.type, data))
-updated = tpu.BitcastOp(
+masked_data = arith.AndIOp(
+    mask_neg, tpu.BitcastVregOp(mask.type, data)
+)
+updated = tpu.BitcastVregOp(
tile.type, arith.OrIOp(masked_data, masked_tile))
else:
updated = arith.SelectOp(mask, tile, data)
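The bitwise path above computes the same result as the SelectOp fallback: with an all-ones-or-all-zeros mask per lane, updated = (mask & tile) | (~mask & data) selects tile wherever the mask is set. A quick numpy check of that identity (arbitrary values):

import numpy as np

rng = np.random.default_rng(0)
mask = rng.integers(0, 2, 8, dtype=np.uint32) * np.uint32(0xFFFFFFFF)
tile = rng.integers(0, 2**32, 8, dtype=np.uint32)
data = rng.integers(0, 2**32, 8, dtype=np.uint32)
assert np.array_equal((mask & tile) | (~mask & data),
                      np.where(mask.astype(bool), tile, data))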