[JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
parent 69878c4924
commit 5b996f7680
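The heart of the change is tolerance for vreg-array slices that run past the end of the array: writes that would land out of bounds are skipped rather than asserted on, so a transpose whose trailing dims are smaller than (or not a multiple of) the transpose unit can still be assembled from whole tiles. The standalone sketch below mirrors that guarded update loop with hypothetical stand-ins (plain vectors instead of xla::Array, a toy incrementSliceIndex); it illustrates the pattern, not the Mosaic implementation.

#include <cstdint>
#include <iostream>
#include <vector>

// Advances idx through the box [starts, limits); returns false once the whole
// box has been visited. Toy analog of the incrementSliceIndex used in the diff.
bool incrementSliceIndex(std::vector<int64_t>& idx,
                         const std::vector<int64_t>& starts,
                         const std::vector<int64_t>& limits) {
  for (int64_t i = idx.size() - 1; i >= 0; --i) {
    if (++idx[i] < limits[i]) return true;
    idx[i] = starts[i];
  }
  return false;
}

int main() {
  // 3x3 destination updated from a 4x4 slice: the extra row and column are
  // out of bounds and silently skipped, as with out_of_bounds_ok slices.
  const std::vector<int64_t> dims = {3, 3};
  std::vector<int64_t> dst(dims[0] * dims[1], 0);
  const std::vector<int64_t> starts = {0, 0}, limits = {4, 4};
  std::vector<int64_t> idx = starts;
  int64_t next_value = 1;  // stand-in for the incoming data range
  do {
    const bool in_bounds = idx[0] < dims[0] && idx[1] < dims[1];
    if (in_bounds) {
      dst[idx[0] * dims[1] + idx[1]] = next_value;
    }
    ++next_value;  // the data iterator advances even for skipped writes
  } while (incrementSliceIndex(idx, starts, limits));
  for (int64_t r = 0; r < dims[0]; ++r) {
    for (int64_t c = 0; c < dims[1]; ++c) {
      std::cout << dst[r * dims[1] + c] << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}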
@@ -249,7 +249,7 @@ bool incrementIndex(const MutableArrayRef<int64_t> idx,
 }
 
 bool sliceIsEmpty(const absl::Span<const int64_t> starts,
-                  const absl::Span<const int64_t> limits) {
+                  const absl::Span<const int64_t> limits) {
   for (auto [s, l] : llvm::zip_equal(starts, limits)) {
     CHECK_LE(s, l);
     if (s == l) {
@@ -282,9 +282,19 @@ void updateSliceFromRange(xla::Array<T> &arr, Range data,
     return;
   }
   SmallVector<int64_t> idx(toArrayRef(starts));
+  auto in_bounds = [&] {
+    for (int64_t i = 0; i < idx.size(); ++i) {
+      if (idx[i] >= arr.dim(i)) {
+        return false;
+      }
+    }
+    return true;
+  };
   auto data_it = data.begin();
   do {
-    arr(idx) = *data_it;
+    if (in_bounds()) {
+      arr(idx) = *data_it;
+    }
     ++data_it;
   } while (incrementSliceIndex(idx, starts, limits));
   CHECK(data_it == data.end());
@@ -1307,7 +1317,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
   }
   tpu::LoadOp load_op = cast<tpu::LoadOp>(op);
   if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
-                                 VectorLayout::ImplicitDim::kNone)) {
+                                 VectorLayout::ImplicitDim::kNone)) {
     return op.emitOpError("Invalid output layout for ") << load_op->getName();
   }
   FAILUREOR_ASSIGN_OR_RETURN(
@@ -1863,7 +1873,7 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
   TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
   TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
   TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
-  if (layouts_in[0] !=layouts_out[0]) {
+  if (layouts_in[0] != layouts_out[0]) {
     return op.emitOpError("Expected same input and output layout");
   }
   OpBuilder builder(&op);
@@ -2622,8 +2632,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
   const LayoutOffsets offsets_in = layout_in.offsets();
   const LayoutOffsets offsets_out = layout_out.offsets();
   if (layout_in.tiling() != layout_out.tiling()) {
-    return op.emitOpError(
-        "Not implemented: Changing tiling mid-broadcast");
+    return op.emitOpError("Not implemented: Changing tiling mid-broadcast");
   }
   auto tiling = layout_in.tiling();
 
@@ -2745,8 +2754,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
       VectorType::get(ctx.target_shape, builder.getI32Type());
   auto idx_const = builder.create<arith::ConstantOp>(
       broadcast_op.getLoc(), idx_ty,
-      DenseElementsAttr::get(idx_ty,
-                             builder.getI32IntegerAttr(offset)));
+      DenseElementsAttr::get(idx_ty, builder.getI32IntegerAttr(offset)));
   int64_t sublanes_per_tile = layout_in.sublanesPerTile(ctx.target_shape);
   DenseI32ArrayAttr sublane_pattern;
   if (num_tiles != 1) {
@@ -3687,11 +3695,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
         "Not implemented: Non-native or offset layout unsupported");
   }
   const int64_t transpose_unit_size = ctx.target_shape[1];
-  for (const int64_t s : src_ty.getShape().take_back(2)) {
-    if (s % transpose_unit_size != 0) {
-      return transpose_op->emitOpError("Not implemented: Padded transpose");
-    }
-  }
   if (ctx.hardware_generation < 4 && layout_in.bitwidth() != 32) {
     return transpose_op->emitOpError(
         "Not implemented: TPUs before v4 only support 32-bit transposes");
@@ -3730,8 +3733,8 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
     src_slice_ends.append(incremented_batch_idx.begin(),
                           incremented_batch_idx.end());
     src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end});
-    xla::Array<Value> src_tile_vregs =
-        src_vregs.Slice(src_slice_starts, src_slice_ends);
+    xla::Array<Value> src_tile_vregs = src_vregs.Slice(
+        src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true);
     // Drop leading singleton (batch) dimensions to have a shape that conforms
     // with the vreg array shape specified by layout_in, as expected by assemble
     src_tile_vregs.Reshape(
@@ -3762,12 +3765,12 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   const ArrayRef<int64_t> batch_sizes =
       dst_ty.getShape().take_front(num_batch_dims);
   SmallVector<int64_t> batch_idx(num_batch_dims);
+  const int64_t tile_rows =
+      xla::CeilOfRatio(*(src_ty.getShape().end() - 2), transpose_unit_size);
+  const int64_t num_col_tiles =
+      xla::CeilOfRatio(*(src_ty.getShape().end() - 1), transpose_unit_size);
   do {
-    const int64_t tile_rows =
-        *(src_ty.getShape().end() - 2) / transpose_unit_size;
     for (int64_t src_row = 0; src_row < tile_rows; ++src_row) {
-      const int64_t num_col_tiles =
-          *(src_ty.getShape().end() - 1) / transpose_unit_size;
       if (can_batch) {
         const int64_t num_batch_tiles = num_col_tiles / 2;
         for (int64_t src_col = 0; src_col < num_batch_tiles; ++src_col) {
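With the padded-transpose rejection removed above, the tile counts switch from exact division to ceiling division, so a trailing partial tile is still visited; the out_of_bounds_ok slice in the previous hunk is what makes that partial tile safe to assemble. A small worked stand-in for the arithmetic (a local CeilOfRatio, not the xla:: helper itself):

#include <cstdint>
#include <iostream>

// Ceiling division for non-negative operands, as xla::CeilOfRatio computes.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t transpose_unit_size = 128;  // lane count of the target shape
  for (int64_t dim : {128, 192, 256, 48}) {
    // Exact division would give 1, 1, 2, 0 and drop the remainders; the
    // ceiling gives 1, 2, 2, 1 so remainder lanes get their own tile.
    std::cout << dim << " -> " << CeilOfRatio(dim, transpose_unit_size)
              << " tile(s)\n";
  }
  return 0;
}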
@@ -4307,7 +4310,7 @@ FailureOr<TypedValue<VectorType>> relayout(
        *(src_tiles.dimensions().end() - 2) == 1)) &&
       dst.offsets()[1] == 0 && src.tiling() == std::array<int64_t, 2>{1, 128} &&
       dst.tiling() == std::array<int64_t, 2>{8, 128}) {
-    xla::Array<Value> src_tiles_retiled(
+    xla::Array<Value> src_tiles_retiled(
         dst.tileArrayShape(vty.getShape(), target_shape));
     src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
       for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) {
@@ -4466,8 +4469,8 @@ FailureOr<TypedValue<VectorType>> relayout(
         v.getLoc(), bits_vreg_ty,
         DenseElementsAttr::get(bits_vreg_ty, shift_bits));
     dst_tiles.Each([&](absl::Span<const int64_t> /*idx*/, Value *tile) {
-      auto bit_tile =
-          builder.create<tpu::BitcastVregOp>(v.getLoc(), bits_vreg_ty, *tile);
+      auto bit_tile = builder.create<tpu::BitcastVregOp>(
+          v.getLoc(), bits_vreg_ty, *tile);
       Operation *shift_tile;
       if (subelem_diff > 0) {
         shift_tile =
@@ -4479,7 +4482,7 @@ FailureOr<TypedValue<VectorType>> relayout(
       }
       *tile = builder
                   .create<tpu::BitcastVregOp>(v.getLoc(), tile->getType(),
-                                              shift_tile->getResult(0))
+                                              shift_tile->getResult(0))
                   .getResult();
       return absl::OkStatus();
     });
@@ -898,8 +898,9 @@ class VectorLayoutInferer {
       if (some_layout->tiling()[0] == 1) {
         offsets[0] = std::nullopt;
       }
-      *some_layout = VectorLayout(some_layout->bitwidth(), offsets,
-                                  default_tiling_, some_layout->implicit_dim());
+      *some_layout =
+          VectorLayout(some_layout->bitwidth(), offsets, default_tiling_,
+                       some_layout->implicit_dim());
     }
     auto &layout = *some_layout;
     if (layout.implicit_dim() != ImplicitDim::kNone) {
@@ -1410,44 +1411,32 @@ class VectorLayoutInferer {
 
   LogicalResult infer(vector::TransposeOp op) {
     auto permutation = op.getPermutation();
     TPU_CHECK_OP(permutation.size() > 1,
                  "Vector and scalar transpose should be a no-op and removed");
 
     auto some_layout = getLayout(op.getVector());
     TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
     auto &layout = *some_layout;
     auto src_ty = op.getSourceVectorType();
     TPU_CHECK_OP(permutation.size() == src_ty.getRank(),
                  "Transpose permutation has incorrect rank");
-    if (layout.implicit_dim() == ImplicitDim::kNone) {
-      TPU_CHECK_OP((layout.offsets() == LayoutOffsets{0, 0}),
-                   "Padded transposes unsupported");
-      auto xlu_width = target_shape_[1];
-      for (int64_t s : src_ty.getShape().take_back(2)) {
-        TPU_CHECK_OP(s % xlu_width == 0, "Padded transposes unsupported");
-      }
-      for (auto dim : permutation.drop_back(2)) {
-        TPU_CHECK_OP(
-            dim < src_ty.getRank() - 2,
-            "Unsupported transpose permutation - minor dims into major");
-      }
-      for (auto dim : permutation.take_back(2)) {
-        TPU_CHECK_OP(
-            dim >= src_ty.getRank() - 2,
-            "Unsupported transpose permutation - major dims into minor");
-      }
-      Layout required_layout = some_layout;
-      if (permutation.size() < 2) {
-        return failure();
-      }
-      // Require native tiling if we're going to use the XLU.
-      if (permutation[permutation.size() - 1] == permutation.size() - 2) {
-        auto native_tiling = nativeTiling(layout.bitwidth());
-        required_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
-                                       native_tiling, ImplicitDim::kNone);
-      }
-      setLayout(op, required_layout, required_layout);
-      return success();
-    }
-    op.emitOpError("Unsupported transpose");
-    return failure();
+    for (auto dim : permutation.drop_back(2)) {
+      TPU_CHECK_OP(dim < src_ty.getRank() - 2,
+                   "Unsupported transpose permutation - minor dims into major");
+    }
+    for (auto dim : permutation.take_back(2)) {
+      TPU_CHECK_OP(dim >= src_ty.getRank() - 2,
+                   "Unsupported transpose permutation - major dims into minor");
+    }
+    Layout required_layout = some_layout;
+    // Require native tiling if we're going to use the XLU.
+    if (permutation[permutation.size() - 1] == permutation.size() - 2) {
+      auto native_tiling = nativeTiling(layout.bitwidth());
+      required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0},
+                                     native_tiling, ImplicitDim::kNone);
+    }
+    setLayout(op, required_layout, required_layout);
+    return success();
   }
 
   LogicalResult inferExt(Operation *op) {
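On the inference side, the rewritten infer(vector::TransposeOp) stops rejecting shapes that are not multiples of the XLU width and instead requires zero offsets plus native tiling whenever the minor two dims are swapped, so apply_vector_layout is handed a layout it supports directly rather than an unsupported relayout. A minimal sketch of that decision rule (toy types; the 8x128 native tiling for 32-bit is an assumption about the target):

#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Permutation {0, 2, 1} swaps the minor two dims of a rank-3 vector,
  // which is the case that goes through the XLU (transpose unit).
  const std::vector<int64_t> permutation = {0, 2, 1};
  const int64_t n = permutation.size();
  const bool uses_xlu = permutation[n - 1] == n - 2;

  std::array<int64_t, 2> tiling = {1, 128};  // whatever layout came in
  std::array<int64_t, 2> offsets = {0, 3};   // possibly padded/offset
  if (uses_xlu) {
    tiling = {8, 128};  // assumed nativeTiling(32) on an 8x128 target
    offsets = {0, 0};   // mirrors the new LayoutOffsets{0, 0} requirement
  }
  std::cout << "required tiling=(" << tiling[0] << ',' << tiling[1]
            << ") offsets=(" << offsets[0] << ',' << offsets[1] << ")\n";
  return 0;
}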