mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128)
PiperOrigin-RevId: 696870258
This commit is contained in:
parent
9a0e9e55d8
commit
1471702adc
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <ostream>
|
||||
#include <tuple>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
@ -39,6 +38,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/log/check.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
@ -259,18 +259,23 @@ class VectorLayout {
|
||||
int layout_rank() const { return layout_rank(implicit_dim_); }
|
||||
|
||||
bool operator==(const VectorLayout &other) const;
|
||||
bool operator!=(const VectorLayout &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
bool operator!=(const VectorLayout &other) const { return !(*this == other); }
|
||||
|
||||
// How many tiles fit in each vector register.
|
||||
int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
|
||||
const int64_t tile_elems = tiling_[0] * tiling_[1];
|
||||
const int64_t vreg_capacity = packing() * target_shape[0] * target_shape[1];
|
||||
static int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape,
|
||||
const int8_t bitwidth,
|
||||
const std::array<int64_t, 2> tiling) {
|
||||
CHECK_NE(0, bitwidth) << "bitwidth cannot be 0";
|
||||
const int64_t tile_elems = tiling[0] * tiling[1];
|
||||
const int64_t vreg_capacity =
|
||||
(32 / bitwidth) * target_shape[0] * target_shape[1];
|
||||
const auto [tiles_per_vreg, rem] = std::div(vreg_capacity, tile_elems);
|
||||
CHECK_EQ(rem, 0);
|
||||
return tiles_per_vreg;
|
||||
}
|
||||
// How many tiles fit in each vector register.
|
||||
int64_t tilesPerVreg(const std::array<int64_t, 2> target_shape) const {
|
||||
return VectorLayout::tilesPerVreg(target_shape, bitwidth_, tiling_);
|
||||
}
|
||||
|
||||
int64_t sublanesPerTile(const std::array<int64_t, 2> target_shape) const {
|
||||
auto [sublanes_per_tile, rem] =
|
||||
@ -283,8 +288,16 @@ class VectorLayout {
|
||||
//
|
||||
// We never reuse the same vector register to store data of multiple rows,
|
||||
// so only the minormost dimension can increase.
|
||||
static std::array<int64_t, 2> vregSlice(std::array<int64_t, 2> target_shape,
|
||||
const int8_t bitwidth,
|
||||
const std::array<int64_t, 2> tiling) {
|
||||
return {
|
||||
tiling[0],
|
||||
VectorLayout::tilesPerVreg(target_shape, bitwidth, tiling) * tiling[1]};
|
||||
}
|
||||
|
||||
std::array<int64_t, 2> vregSlice(std::array<int64_t, 2> target_shape) const {
|
||||
return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
|
||||
return VectorLayout::vregSlice(target_shape, bitwidth_, tiling_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -2554,7 +2554,10 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
|
||||
TPU_ASSERT_OP(res_layout.has_value());
|
||||
auto num_untiled_dims = res_ty.getRank() - res_layout->layout_rank();
|
||||
|
||||
if (dimension >= num_untiled_dims) {
|
||||
if (res_ty.getRank() == 1 &&
|
||||
res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor) {
|
||||
tiling_dim = 1;
|
||||
} else if (dimension >= num_untiled_dims) {
|
||||
tiling_dim = dimension - num_untiled_dims;
|
||||
}
|
||||
|
||||
@ -2576,6 +2579,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
|
||||
return op.emitOpError("Not implemented: result/input offsets mismatch.");
|
||||
}
|
||||
|
||||
if (layout.implicit_dim() != res_layout->implicit_dim()) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: result/input implicit dim mismatch.");
|
||||
}
|
||||
|
||||
if (i > 1) {
|
||||
auto curr_offsets = layout.offsets();
|
||||
auto last_operand_offsets = layouts_in[i - 1]->offsets();
|
||||
@ -2611,29 +2619,47 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
|
||||
if (!tiling_dim.has_value()) {
|
||||
out_vregs = concatenate(operand_vregs, dimension);
|
||||
} else {
|
||||
if (res_layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
|
||||
bool is_rank1_with_no_implicit_dim = res_ty.getRank() == 1 &&
|
||||
res_layout->implicit_dim() ==
|
||||
VectorLayout::ImplicitDim::kNone;
|
||||
if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kMinor ||
|
||||
is_rank1_with_no_implicit_dim) {
|
||||
return op.emitOpError("Not implemented: implicit dim");
|
||||
}
|
||||
if (res_layout->implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
|
||||
res_layout->bitwidth() != 32) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: only 32-bit bitwidth supported for SecondMinor "
|
||||
"implicit dim");
|
||||
}
|
||||
if (res_layout->offsets()[tiling_dim.value()] != 0) {
|
||||
return op.emitOpError("Not implemented: result non-zero offset.");
|
||||
}
|
||||
if (!res_layout->hasNativeTiling(ctx.target_shape)) {
|
||||
if (!res_layout->hasNativeTiling(ctx.target_shape) &&
|
||||
res_ty.getRank() != 1) {
|
||||
return op.emitOpError("Not implemented: Non native tiling in concat.");
|
||||
}
|
||||
|
||||
int64_t offset_at_dim = 0;
|
||||
{
|
||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto operand = op.getOperand(i);
|
||||
auto const &layout = *layouts_in[i];
|
||||
Value operand = op.getOperand(i);
|
||||
const Layout &layout = *layouts_in[i];
|
||||
xla::Array<Value> vreg_array = operand_vregs[i];
|
||||
std::array<int64_t, 2> vreg_slice = layout->vregSlice(ctx.target_shape);
|
||||
std::array<int64_t, 2> tiling = layout->tiling();
|
||||
|
||||
auto vty = cast<VectorType>(operand.getType());
|
||||
auto shape = vty.getShape();
|
||||
VectorType vty = cast<VectorType>(operand.getType());
|
||||
ArrayRef<int64_t> shape = vty.getShape();
|
||||
|
||||
auto starting_point = offset_at_dim;
|
||||
auto offset_amount =
|
||||
starting_point % layout.tiling()[tiling_dim.value()];
|
||||
if (offset_amount != layout.offsets()[tiling_dim.value()]) {
|
||||
int64_t starting_point = offset_at_dim;
|
||||
int64_t offset_amount =
|
||||
starting_point % vreg_slice[tiling_dim.value()];
|
||||
if (offset_amount >= tiling[tiling_dim.value()]) {
|
||||
return op.emitError(
|
||||
"Not implemented: Input offsets outside of the first tile");
|
||||
}
|
||||
if (offset_amount != layout->offsets()[tiling_dim.value()]) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: Relayout not called, unaligned dims "
|
||||
"concatenated without proper offsets. Ensure that "
|
||||
@ -2649,10 +2675,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
|
||||
auto &vreg = operand_vregs[i];
|
||||
const auto &layout = layouts_in[i];
|
||||
|
||||
if (layout->implicit_dim() != VectorLayout::ImplicitDim::kNone) {
|
||||
return op.emitOpError("Not implemented: implicit dim");
|
||||
}
|
||||
|
||||
const int64_t operand_offset = *layout->offsets()[tiling_dim.value()];
|
||||
if (operand_offset != 0) {
|
||||
// We are offset, so we must blend with the previous vreg.
|
||||
|
@ -770,14 +770,11 @@ class VectorLayoutInferer {
|
||||
LogicalResult infer(tpu::ConcatenateOp op) {
|
||||
TPU_CHECK_OP(!op.getSources().empty(),
|
||||
"Need at least one vector to concatenate");
|
||||
auto res_rank = op.getType().getRank();
|
||||
auto dimension = op.getDimension();
|
||||
int64_t res_rank = op.getType().getRank();
|
||||
uint32_t dimension = op.getDimension();
|
||||
TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
|
||||
"Expect a valid concatenate dimension");
|
||||
if (res_rank == 1) {
|
||||
NYI("Support concatenation with 1D vectors");
|
||||
}
|
||||
auto res_ty = op.getResult().getType();
|
||||
VectorType res_ty = op.getResult().getType();
|
||||
int8_t bitwidth = res_ty.getElementTypeBitWidth();
|
||||
|
||||
std::optional<int64_t> tiling_dim;
|
||||
@ -790,29 +787,39 @@ class VectorLayoutInferer {
|
||||
if (tiling_dim.has_value()) {
|
||||
int64_t starting_point = 0;
|
||||
|
||||
auto first_layout = getLayout(op.getSources().front());
|
||||
auto op_layouts = getLayoutFromOperands(op);
|
||||
Layout first_layout = getLayout(op.getSources().front());
|
||||
SmallVector<Layout, 4> op_layouts = getLayoutFromOperands(op);
|
||||
SmallVector<Layout> in_layouts;
|
||||
in_layouts.reserve(op.getSources().size());
|
||||
|
||||
auto native_tiling = nativeTiling(bitwidth);
|
||||
|
||||
// Set implicit dim to treat 1D as (1, N) and tile it as (1, 128)
|
||||
std::array<int64_t, 2> tiling =
|
||||
res_rank == 1 ? std::array<int64_t, 2>{1L, target_shape_[1]}
|
||||
: nativeTiling(bitwidth);
|
||||
ImplicitDim implicit_dim =
|
||||
res_rank == 1 ? ImplicitDim::kSecondMinor : ImplicitDim::kNone;
|
||||
std::array<int64_t, 2> vreg_slice =
|
||||
VectorLayout::vregSlice(target_shape_, bitwidth, tiling);
|
||||
for (int i = 0; i < op.getSources().size(); ++i) {
|
||||
// Compute the offset per source.
|
||||
// Ex: for a cat of (10, 128), (10, 128) on dim 0, where the
|
||||
// vreg_sice for that dim is 8, the first source starts at
|
||||
// vreg_slice for that dim is 8, the first source starts at
|
||||
// offset 0, and overflows the vreg
|
||||
// by 2, so the offset for the second input is 2.
|
||||
auto op_shape =
|
||||
ArrayRef<int64_t> op_shape =
|
||||
cast<VectorType>(op.getSources()[i].getType()).getShape();
|
||||
auto offset_amount = starting_point % native_tiling[tiling_dim.value()];
|
||||
auto op_layout = op_layouts[i];
|
||||
Layout op_layout = op_layouts[i];
|
||||
int64_t offset_amount = starting_point % vreg_slice[tiling_dim.value()];
|
||||
if (offset_amount >= tiling[tiling_dim.value()]) {
|
||||
return op.emitError(
|
||||
"Not implemented: Input offsets outside of the first tile");
|
||||
}
|
||||
SmallVector<int64_t> in_idx{op_layout->offsets()[0].value_or(0),
|
||||
op_layout->offsets()[1].value_or(0)};
|
||||
in_idx[tiling_dim.value()] = offset_amount;
|
||||
starting_point += op_shape[dimension];
|
||||
in_layouts.push_back(VectorLayout(bitwidth, {in_idx[0], in_idx[1]},
|
||||
native_tiling, ImplicitDim::kNone));
|
||||
tiling, implicit_dim));
|
||||
}
|
||||
SmallVector<int64_t> res_layout_offsets(
|
||||
{first_layout->offsets()[0].value_or(0),
|
||||
@ -821,13 +828,13 @@ class VectorLayoutInferer {
|
||||
// TODO(mvoz): A tiny optimization we could do here later is to
|
||||
// no-op setting tiling when sublane dim size is aligned to sublane
|
||||
// tiling.
|
||||
auto res_layout =
|
||||
VectorLayout res_layout =
|
||||
VectorLayout(bitwidth, {res_layout_offsets[0], res_layout_offsets[1]},
|
||||
native_tiling, ImplicitDim::kNone);
|
||||
tiling, implicit_dim);
|
||||
setLayout(op, in_layouts, res_layout);
|
||||
return success();
|
||||
} else {
|
||||
auto layout = getLayout(op.getSources().front());
|
||||
Layout layout = getLayout(op.getSources().front());
|
||||
// When concatenating vectors with replicated offsets, we want to reset
|
||||
// the replicated offset to zero. Because we are not sure if the
|
||||
// replicated value from each vector are same.
|
||||
@ -1464,11 +1471,11 @@ class VectorLayoutInferer {
|
||||
// unfolding, it's still a no-op, but we need to
|
||||
// add support in apply-vector-layout.
|
||||
LayoutOffsets offsets = {0, layout.offsets()[1]};
|
||||
setLayout(op,
|
||||
VectorLayout(layout.bitwidth(), offsets, tiling,
|
||||
layout.implicit_dim()),
|
||||
VectorLayout(layout.bitwidth(), offsets, tiling,
|
||||
implicit_dim));
|
||||
setLayout(
|
||||
op,
|
||||
VectorLayout(layout.bitwidth(), offsets, tiling,
|
||||
layout.implicit_dim()),
|
||||
VectorLayout(layout.bitwidth(), offsets, tiling, implicit_dim));
|
||||
return success();
|
||||
}
|
||||
sublane_tiling /= 2;
|
||||
@ -1845,9 +1852,9 @@ class VectorLayoutInferer {
|
||||
"only 32-bit random bit generation supported");
|
||||
// TODO: b/342054464 - Support implicit dims for PRNGRandomBitsOp.
|
||||
LayoutOffsets offsets = {0, 0};
|
||||
setOutLayout(op, VectorLayout(
|
||||
kNativeBitwidth, offsets, nativeTiling(kNativeBitwidth),
|
||||
ImplicitDim::kNone));
|
||||
setOutLayout(
|
||||
op, VectorLayout(kNativeBitwidth, offsets,
|
||||
nativeTiling(kNativeBitwidth), ImplicitDim::kNone));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user