[Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128)

PiperOrigin-RevId: 696870258
This commit is contained in:
jax authors 2024-11-15 06:41:14 -08:00
parent 9a0e9e55d8
commit 1471702adc
3 changed files with 92 additions and 50 deletions

View File

@ -24,7 +24,6 @@ limitations under the License.
#include <ostream>
#include <tuple>
#include "absl/log/check.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
@ -39,6 +38,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
namespace mlir::tpu {
@ -259,18 +259,23 @@ class VectorLayout {
// Rank of the logical dims this layout describes, derived from the current
// implicit dim (delegates to a static overload not visible in this hunk).
int layout_rank() const { return layout_rank(implicit_dim_); }
bool operator==(const VectorLayout &other) const;
bool operator!=(const VectorLayout &other) const {
return !(*this == other);
}
bool operator!=(const VectorLayout &other) const { return !(*this == other); }
// How many tiles fit in each vector register.
int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
const int64_t tile_elems = tiling_[0] * tiling_[1];
const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1];
// How many (tiling[0] x tiling[1]) tiles fit into one vector register for a
// value of the given element bitwidth on a target with the given vreg shape.
// Static so callers can compute this before a VectorLayout instance exists
// (e.g. during layout inference); the instance overload delegates here.
static int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape,
const int8_t bitwidth,
const std::array<int64_t, 2> tiling) {
CHECK_NE(0, bitwidth) << "bitwidth cannot be 0";  // guards the division below
const int64_t tile_elems = tiling[0] * tiling[1];
// (32 / bitwidth) is the packing factor: how many sub-32-bit elements share
// one 32-bit lane slot.
const int64_t vreg_capacity =
(32 / bitwidth) * target_shape[0] * target_shape[1];
const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems);
CHECK_EQ(rem, 0);  // tiles are expected to evenly partition a vreg
return tiles_per_vreg;
}
// How many tiles fit in each vector register.
// Instance convenience wrapper: forwards this layout's own bitwidth_ and
// tiling_ to the static overload.
int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_);
}
int64_t sublanesPerTile(const std::array<int64_t, 2> target_shape) const {
auto [sublanes_per_tile, rem] =
@ -283,8 +288,16 @@ class VectorLayout {
//
// We never reuse the same vector register to store data of multiple rows,
// so only the minormost dimension can increase.
// The (sublane, lane) shape of the data slice covered by one vector
// register: tiling[0] rows by tilesPerVreg * tiling[1] columns. Static
// variant so it is usable without a VectorLayout instance; only the
// minormost dimension grows with the number of tiles per vreg (a vreg never
// spans multiple rows of tiles).
static std::array<int64_t, 2> vregSlice(std::array<int64_t, 2> target_shape,
const int8_t bitwidth,
const std::array<int64_t, 2> tiling) {
return {
tiling[0],
VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]};
}
std::array<int64_t, 2> vregSlice(std::array<int64_t, 2> target_shape) const {
return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_);
}
template <typename T>

View File

@ -2554,7 +2554,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(res_layout.has_value());
auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank();
if (dimension >= num_untiled_dims) {
if (res_ty.getRank() == 1 &&
res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) {
tiling_dim = 1;
} else if (dimension >= num_untiled_dims) {
tiling_dim = dimension - num_untiled_dims;
}
@ -2576,6 +2579,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
return op.emitOpError("Not implemented: result/input offsets mismatch.");
}
if (layout.implicit_dim() != res_layout->implicit_dim()) {
return op.emitOpError(
"Not implemented: result/input implicit dim mismatch.");
}
if (i > 1) {
auto curr_offsets = layout.offsets();
auto last_operand_offsets = layouts_in[i - 1]->offsets();
@ -2611,29 +2619,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
if (!tiling_dim.has_value()) {
out_vregs = concatenate(operand_vregs, dimension);
} else {
if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 &&
res_layout->implicit_dim() ==
VectorLayout::ImplicitDim::kNone;
if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor ||
is_rank1_with_no_implicit_dim) {
return op.emitOpError("Not implemented: implicit dim");
}
if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
res_layout->bitwidth() != 32) {
return op.emitOpError(
"Not implemented: only 32-bit bitwidth supported for SecondMinor "
"implicit dim");
}
if (res_layout->offsets()[tiling_dim.value()] != 0) {
return op.emitOpError("Not implemented: result non-zero offset.");
}
if (!res_layout->hasNativeTiling(ctx.target_shape)) {
if (!res_layout->hasNativeTiling(ctx.target_shape) &&
res_ty.getRank() != 1) {
return op.emitOpError("Not implemented: Non native tiling in concat.");
}
int64_t offset_at_dim = 0;
{
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
auto const &layout = *layouts_in[i];
Value operand = op.getOperand(i);
const Layout &layout = *layouts_in[i];
xla::Array<Value> vreg_array = operand_vregs[i];
std::array<int64_t, 2> vreg_slice = layout->vregSlice(ctx.target_shape);
std::array<int64_t, 2> tiling = layout->tiling();
auto vty = cast<VectorType>(operand.getType());
auto shape = vty.getShape();
VectorType vty = cast<VectorType>(operand.getType());
ArrayRef<int64_t> shape = vty.getShape();
auto starting_point = offset_at_dim;
auto offset_amount =
starting_point % layout.tiling()[tiling_dim.value()];
if (offset_amount != layout.offsets()[tiling_dim.value()]) {
int64_t starting_point = offset_at_dim;
int64_t offset_amount =
starting_point % vreg_slice[tiling_dim.value()];
if (offset_amount >= tiling[tiling_dim.value()]) {
return op.emitError(
"Not implemented: Input offsets outside of the first tile");
}
if (offset_amount != layout->offsets()[tiling_dim.value()]) {
return op.emitOpError(
"Not implemented: Relayout not called, unaligned dims "
"concatenated without proper offsets. Ensure that "
@ -2649,10 +2675,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
auto &vreg = operand_vregs[i];
const auto &layout = layouts_in[i];
if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: implicit dim");
}
const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
if (operand_offset != 0) {
// We are offset, so we must blend with the previous vreg.

View File

@ -770,14 +770,11 @@ class VectorLayoutInferer {
LogicalResult infer(tpu::ConcatenateOp op) {
TPU_CHECK_OP(!op.getSources().empty(),
"Need at least one vector to concatenate");
auto res_rank = op.getType().getRank();
auto dimension = op.getDimension();
int64_t res_rank = op.getType().getRank();
uint32_t dimension = op.getDimension();
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
"Expect a valid concatenate dimension");
if (res_rank == 1) {
NYI("Support concatenation with 1D vectors");
}
auto res_ty = op.getResult().getType();
VectorType res_ty = op.getResult().getType();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
std::optional<int64_t> tiling_dim;
@ -790,29 +787,39 @@ class VectorLayoutInferer {
if (tiling_dim.has_value()) {
int64_t starting_point = 0;
auto first_layout = getLayout(op.getSources().front());
auto op_layouts = getLayoutFromOperands(op);
Layout first_layout = getLayout(op.getSources().front());
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
SmallVector<Layout> in_layouts;
in_layouts.reserve(op.getSources().size());
auto native_tiling = nativeTiling(bitwidth);
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
std::array<int64_t, 2> tiling =
res_rank == 1 ? std::array<int64_t, 2>{1L, target_shape_[1]}
: nativeTiling(bitwidth);
ImplicitDim implicit_dim =
res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone;
std::array<int64_t, 2> vreg_slice =
VectorLayout::vregSlice(target_shape_, bitwidth, tiling);
for (int i = 0; i < op.getSources().size(); ++i) {
// Compute the offset per source.
// Ex: for a cat of (10, 128), (10, 128) on dim 0, where the
// vreg_sice for that dim is 8, the first source starts at
// vreg_slice for that dim is 8, the first source starts at
// offset 0, and overflows the vreg
// by 2, so the offset for the second input is 2.
auto op_shape =
ArrayRef<int64_t> op_shape =
cast<VectorType>(op.getSources()[i].getType()).getShape();
auto offset_amount = starting_point % native_tiling[tiling_dim.value()];
auto op_layout = op_layouts[i];
Layout op_layout = op_layouts[i];
int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()];
if (offset_amount >= tiling[tiling_dim.value()]) {
return op.emitError(
"Not implemented: Input offsets outside of the first tile");
}
SmallVector<int64_t> in_idx{op_layout->offsets()[0].value_or(0),
op_layout->offsets()[1].value_or(0)};
in_idx[tiling_dim.value()] = offset_amount;
starting_point += op_shape[dimension];
in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]},
native_tiling, ImplicitDim::kNone));
tiling, implicit_dim));
}
SmallVector<int64_t> res_layout_offsets(
{first_layout->offsets()[0].value_or(0),
@ -821,13 +828,13 @@ class VectorLayoutInferer {
// TODO(mvoz): A tiny optimization we could do here later is to
// no-op setting tiling when sublane dim size is aligned to sublane
// tiling.
auto res_layout =
VectorLayout res_layout =
VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]},
native_tiling, ImplicitDim::kNone);
tiling, implicit_dim);
setLayout(op, in_layouts, res_layout);
return success();
} else {
auto layout = getLayout(op.getSources().front());
Layout layout = getLayout(op.getSources().front());
// When concatenating vectors with replicated offsets, we want to reset
// the replicated offset to zero. Because we are not sure if the
// replicated value from each vector are same.
@ -1464,11 +1471,11 @@ class VectorLayoutInferer {
// unfolding, it's still a no-op, but we need to
// add support in apply-vector-layout.
LayoutOffsets offsets = {0, layout.offsets()[1]};
setLayout(op,
VectorLayout(layout.bitwidth(), offsets, tiling,
layout.implicit_dim()),
VectorLayout(layout.bitwidth(), offsets, tiling,
implicit_dim));
setLayout(
op,
VectorLayout(layout.bitwidth(), offsets, tiling,
layout.implicit_dim()),
VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim));
return success();
}
sublane_tiling /= 2;
@ -1845,9 +1852,9 @@ class VectorLayoutInferer {
"only 32-bit random bit generation supported");
// TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp.
LayoutOffsets offsets = {0, 0};
setOutLayout(op, VectorLayout(
kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth),
ImplicitDim::kNone));
setOutLayout(
op, VectorLayout(kNativeBitwidth, offsets,
nativeTiling(kNativeBitwidth), ImplicitDim::kNone));
return success();
}