[XLA:Mosaic] Support arbitrarily aligned shapes for tpu.bitcast, and support bitcasts that change the element bitwidth.
PiperOrigin-RevId: 582524212
parent 555c569e67
commit 8a64d9af40
@@ -300,6 +300,12 @@ def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
   let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
 }
 
+def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
+  let arguments = (ins TPU_Vreg:$input);
+  let results = (outs TPU_Vreg:$output);
+  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+}
+
 def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
   let arguments = (ins Variadic<AnyType>:$input);
   let results = (outs AnyType:$output);
@@ -1254,6 +1254,52 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
+LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
+                               const ArrayRef<Layout> layouts_in,
+                               const ArrayRef<Layout> layouts_out) {
+  CHECK_EQ(layouts_in.size(), 1);
+  CHECK_EQ(layouts_out.size(), 1);
+  if (!layouts_in.front().has_value()) {
+    return op.emitOpError("Expected non-null input layout");
+  }
+  if (!layouts_out.front().has_value()) {
+    return op.emitOpError("Expected non-null output layout");
+  }
+  const VectorLayout &layout_in = *layouts_in.front();
+  const VectorLayout &layout_out = *layouts_out.front();
+  if (!layout_in.hasNativeTiling(ctx.target_shape) ||
+      !layout_out.hasNativeTiling(ctx.target_shape)) {
+    return op.emitOpError("Not implemented: unsupported tiling");
+  }
+  if (layout_in.offsets() != LayoutOffsets{0, 0} ||
+      layout_out.offsets() != LayoutOffsets{0, 0}) {
+    return op.emitOpError("Not implemented: unsupported offsets");
+  }
+  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
+      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    return op.emitOpError("Not implemented: unsupported implicit dim");
+  }
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+  auto bitcast_op = cast<tpu::BitcastOp>(op);
+  const VectorType vty = bitcast_op.getResult().getType();
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const auto native_vreg_ty,
+      getNativeVregType(vty.getElementType(), ctx.target_shape));
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const xla::Array<Value> in_tiles,
+      disassemble(builder, layout_in, bitcast_op.getInput(), ctx.target_shape));
+  xla::Array<Value> out_tiles(in_tiles.dimensions());
+  out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+    const Value in_tile = in_tiles(idxs);
+    *v = builder.create<tpu::BitcastVregOp>(native_vreg_ty, in_tile);
+  });
+  bitcast_op.replaceAllUsesWith(
+      assemble(builder, vty, layout_out, out_tiles, ctx.target_shape)
+          .getOperation());
+  bitcast_op.erase();
+  return success();
+}
+
 LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
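The per-vreg rewrite above is sound because native tiling keeps vreg boundaries on sublane offsets that are multiples of the packing factor, so bitcasting each vreg in isolation and reassembling yields the same bits as bitcasting the whole vector at once. A minimal numpy sketch of that equivalence, using uint16/uint32 as stand-ins for bf16/i32 bit patterns (numpy has no bf16 dtype); pack_rows is a hypothetical helper modeling the sublane packing, not part of the codebase:

import numpy as np

def pack_rows(x16):
  # Models a bitwidth-changing vreg bitcast (16-bit -> 32-bit):
  # out[i, j] = (x16[2*i+1, j] << 16) | x16[2*i, j].
  lo = x16[0::2].astype(np.uint32)
  hi = x16[1::2].astype(np.uint32) << 16
  return hi | lo

rng = np.random.default_rng(0)
full = rng.integers(0, 2**16, size=(32, 128), dtype=np.uint16)

# Per-vreg path: bitcast each native (16, 128) tile independently and
# reassemble, as the rule does with one BitcastVregOp per vreg.
per_vreg = np.concatenate([pack_rows(full[0:16]), pack_rows(full[16:32])])

# Whole-vector path: one packing over the full array. The two agree,
# element for element.
assert np.array_equal(per_vreg, pack_rows(full))
assert int(per_vreg[3, 7]) == (int(full[7, 7]) << 16) | int(full[6, 7])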
@@ -2701,12 +2747,12 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
                 APInt(32, 0xFFFFFFFF))));
         auto masked_tile = builder.create<arith::AndIOp>(
             store_op.getLoc(), mask,
-            builder.create<tpu::BitcastOp>(mask.getType(), tile));
+            builder.create<tpu::BitcastVregOp>(mask.getType(), tile));
         auto mask_neg = builder.create<arith::XOrIOp>(ones, mask);
         auto masked_data = builder.create<arith::AndIOp>(
             mask_neg,
-            builder.create<tpu::BitcastOp>(mask.getType(), data));
-        updated = builder.create<tpu::BitcastOp>(
+            builder.create<tpu::BitcastVregOp>(mask.getType(), data));
+        updated = builder.create<tpu::BitcastVregOp>(
             tile.getType(),
             builder.create<arith::OrIOp>(masked_data, masked_tile));
       } else {
@@ -2917,6 +2963,7 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
       {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
      {tpu::StoreOp::getOperationName(), tpu_store_rule},
+      {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
       {tpu::TraceOp::getOperationName(), tpu_trace_rule},
       {vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
       {vector::ContractionOp::getOperationName(), vector_contract_rule},
@@ -3541,7 +3588,7 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
           DenseElementsAttr::get(bits_vreg_ty, shift_bits));
   dst_tiles.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
     auto bit_tile =
-        builder.create<tpu::BitcastOp>(v.getLoc(), bits_vreg_ty, *tile);
+        builder.create<tpu::BitcastVregOp>(v.getLoc(), bits_vreg_ty, *tile);
     Operation *shift_tile;
     if (subelem_diff > 0) {
       shift_tile =
@@ -3552,7 +3599,7 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
                        shift_vreg);
     }
     *tile = builder
-                .create<tpu::BitcastOp>(v.getLoc(), tile->getType(),
+                .create<tpu::BitcastVregOp>(v.getLoc(), tile->getType(),
                                         shift_tile->getResult(0))
                 .getResult();
     return absl::OkStatus();
@@ -241,6 +241,10 @@ class VectorLayoutInferer {
       if (infer(op).failed()) {
         return failure();
       }
+    } else if (auto op = dyn_cast<tpu::BitcastOp>(any_op)) {
+      if (infer(op).failed()) {
+        return failure();
+      }
     } else if (auto op = dyn_cast<tpu::RepeatOp>(any_op)) {
       if (infer(op).failed()) {
         return failure();
@@ -690,6 +694,50 @@ class VectorLayoutInferer {
     return success();
   }
 
+  LogicalResult infer(tpu::BitcastOp op) {
+    auto src_layout = getLayout(op.getInput());
+    LayoutOffsets src_offsets = src_layout->offsets();
+    if (src_offsets[0].value_or(0) || src_offsets[1].value_or(0)) {
+      NYI("unsupported bitcast with offsets");
+    }
+    if (src_layout->implicit_dim() != ImplicitDim::kNone) {
+      NYI("unsupported bitcast with an implicit dim");
+    }
+    // Check that the input and output shapes describe the same bits: equal
+    // rank, and equal dims except the second-to-minor one, which scales
+    // inversely with the element bitwidth. Validate the types before
+    // querying their bitwidths.
+    auto in_ty = dyn_cast<VectorType>(op.getInput().getType());
+    auto out_ty = dyn_cast<VectorType>(op.getOutput().getType());
+    TPU_CHECK_OP(in_ty && out_ty && in_ty.getRank() == out_ty.getRank(),
+                 "Input and output have different rank");
+    auto in_bitwidth = in_ty.getElementTypeBitWidth();
+    auto out_bitwidth = out_ty.getElementTypeBitWidth();
+    if (out_ty.getRank() < 2) {
+      NYI("bitcast of a 1D vector");
+    }
+    for (int i = 0; i < in_ty.getRank(); ++i) {
+      auto in_dim = in_ty.getDimSize(i);
+      auto out_dim = out_ty.getDimSize(i);
+
+      // The sublane dimension is scaled down by the ratio of input element
+      // bitwidth to output element bitwidth when bitcasting. For example,
+      // bitcasting a vector<16x128xbf16> to a vector<8x128xi32> packs every 2
+      // rows of the bf16 vector into 1 row of the i32 vector. This means the
+      // bit representation of one i32 element vector[i,j] is equal to the
+      // concatenation of bf16 elements vector[2*i+1,j] and vector[2*i,j].
+      if (i == in_ty.getRank() - 2) {
+        in_dim *= in_bitwidth;
+        out_dim *= out_bitwidth;
+      }
+      TPU_CHECK_OP(in_dim == out_dim,
+                   "Input and output have incompatible shape");
+    }
+    setLayout(op,
+              VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
+                           ImplicitDim::kNone),
+              VectorLayout(out_bitwidth, src_offsets,
+                           nativeTiling(out_bitwidth), ImplicitDim::kNone));
+    return success();
+  }
+
   LogicalResult infer(tpu::RepeatOp op) {
     auto src_layout = getLayout(op.getSource());
     setLayout(op, src_layout, src_layout);
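The dimension loop above admits any equal-rank shapes that agree everywhere except the second-to-minor dimension, whose total bit count (dim times element bitwidth) must be preserved. A standalone sketch of the same constraint; bitcast_shape_ok is a hypothetical mirror of that loop, not part of the codebase:

def bitcast_shape_ok(in_shape, in_bits, out_shape, out_bits):
  # All dims must match except the second-to-minor one, where
  # dim * element_bitwidth must be preserved across the bitcast.
  if len(in_shape) != len(out_shape) or len(out_shape) < 2:
    return False
  for i, (d_in, d_out) in enumerate(zip(in_shape, out_shape)):
    if i == len(in_shape) - 2:
      d_in, d_out = d_in * in_bits, d_out * out_bits
    if d_in != d_out:
      return False
  return True

assert bitcast_shape_ok((16, 128), 16, (8, 128), 32)      # bf16 -> i32
assert not bitcast_shape_ok((16, 128), 16, (16, 64), 32)  # lane dim differs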
@@ -1394,14 +1394,14 @@ def relayout(
         ),
     )
     for idx, tile in np.ndenumerate(dst_tiles):
-      bit_tile = tpu.BitcastOp(bits_vreg_ty, tile)
+      bit_tile = tpu.BitcastVregOp(bits_vreg_ty, tile)
       if subelem_diff > 0:
         shift_tile = arith.ShLIOp(bit_tile, shift_vreg)
       elif subelem_diff < 0:
         shift_tile = arith.ShRUIOp(bit_tile, shift_vreg)
       else:
         raise AssertionError("unexpected equal subelements")
-      dst_tiles[idx] = tpu.BitcastOp(tile.type, shift_tile).result
+      dst_tiles[idx] = tpu.BitcastVregOp(tile.type, shift_tile).result
 
     # Shifting columns.
     if src.offsets[1] is REPLICATED:
@@ -2200,6 +2200,28 @@ def _tpu_store_rule(  # pylint: disable=missing-function-docstring
   return ctx.erase(op)
 
 
+@_register_rule("tpu.bitcast")
+def _tpu_bitcast_rule(  # pylint: disable=missing-function-docstring
+    ctx: RewriteContext,
+    op: tpu.BitcastOp,
+    layout_in: VectorLayout,
+    layout_out: VectorLayout,
+):
+  if not layout_in.has_native_tiling or not layout_out.has_native_tiling:
+    raise NotImplementedError("unsupported tiling")
+  if layout_in.offsets != (0, 0) or layout_out.offsets != (0, 0):
+    raise NotImplementedError("unsupported offsets")
+  if layout_in.implicit_dim is not None or layout_out.implicit_dim is not None:
+    raise NotImplementedError("unsupported implicit dim")
+  ty = ir.VectorType(op.result.type)
+  vreg = native_vreg_ty(ty.element_type)
+  in_tiles = disassemble(layout_in, op.input)
+  out_tiles = np.empty_like(in_tiles, dtype=object)
+  for idx, tile in np.ndenumerate(in_tiles):
+    out_tiles[idx] = tpu.BitcastVregOp(vreg, tile)
+  return ctx.replace(op, assemble(ty, layout_out, out_tiles))
+
+
 @_register_rule("tpu.trace")
 def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp,  # pylint: disable=missing-function-docstring
                     layout_in: Layout, layout_out: Layout):
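Note that out_tiles reuses in_tiles' grid shape even when the bitwidth changes: the bitcast scales both the vector's sublane count and the native vreg's sublane count by the same factor, so the vreg grid is unchanged. A small sketch of that invariant; vreg_grid is a hypothetical helper, and the shapes assume fully tile-aligned vectors for simplicity (the real disassemble also handles padded edge tiles):

def vreg_grid(shape, tile):
  # Number of vregs along (sublanes, lanes) for a tile-aligned vector.
  return (shape[0] // tile[0], shape[1] // tile[1])

# (32, 256) bf16 with native (16, 128) tiles -> a 2x2 vreg grid, and the
# bitcast result (16, 256) i32 with native (8, 128) tiles -> also 2x2.
assert vreg_grid((32, 256), (16, 128)) == vreg_grid((16, 256), (8, 128))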
@@ -2713,10 +2735,12 @@ def _vector_store_rule(  # pylint: disable=missing-function-docstring
             mask.type, ir.IntegerAttr.get(i32(), 0xFFFFFFFF)
         ),
     )
-    masked_tile = arith.AndIOp(mask, tpu.BitcastOp(mask.type, tile))
+    masked_tile = arith.AndIOp(mask, tpu.BitcastVregOp(mask.type, tile))
     mask_neg = arith.XOrIOp(ones, mask)
-    masked_data = arith.AndIOp(mask_neg, tpu.BitcastOp(mask.type, data))
-    updated = tpu.BitcastOp(
+    masked_data = arith.AndIOp(
+        mask_neg, tpu.BitcastVregOp(mask.type, data)
+    )
+    updated = tpu.BitcastVregOp(
         tile.type, arith.OrIOp(masked_data, masked_tile))
   else:
     updated = arith.SelectOp(mask, tile, data)