[XLA:Mosaic] Support strided load/store of memrefs with arbitrary shape, as long as the last dim size is 128 and the dtype is 32-bit.

PiperOrigin-RevId: 614862128
Jevin Jiang 2024-03-11 18:21:30 -07:00 committed by jax authors
parent 63538771b5
commit 30208fa9cc
4 changed files with 255 additions and 3 deletions


@@ -194,6 +194,33 @@ def TPU_LoadOp : TPU_Op<"load"> {
}];
}
def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
let arguments = (ins
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs AnyVector:$result);
let assemblyFormat = [{
$base attr-dict `:` type($base) `,` type($result)
}];
let hasVerifier = 1;
}
def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
let arguments = (ins
AnyVector:$valueToStore,
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs);
let assemblyFormat = [{
$base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
}];
let hasVerifier = 1;
}
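For reference, a hypothetical textual form of the two new ops under the assembly formats declared above (the memref/vector types and attribute values are illustrative, chosen to satisfy the 32-bit and last-dim-128 constraints): load every other row of a (16, 128) f32 memref, then store the result back densely.

// Illustrative example only; types and values are assumptions.
%v = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 2, 1>}
    : memref<16x128xf32>, vector<8x128xf32>
tpu.strided_store %base, %v {indices = array<i32: 0, 0>, strides = array<i32: 1, 1>}
    : memref<16x128xf32>, vector<8x128xf32>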
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
let arguments = (ins
AnyVector:$value,


@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/IRMapping.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
@@ -180,6 +181,56 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op,
return success();
}
template <typename Op>
LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
VectorType vector_ty) {
auto indices = op.getIndices();
auto strides = op.getStrides();
if (memref_ty.getRank() != indices.size()) {
op.emitError("Base memref's rank and indices size do not match: ")
<< memref_ty.getRank() << " vs " << indices.size();
return failure();
}
if (memref_ty.getRank() != strides.size()) {
op.emitError("Base memref's rank and strides size do not match: ")
<< memref_ty.getRank() << " vs " << strides.size();
return failure();
}
if (memref_ty.getRank() != vector_ty.getRank()) {
op.emitError("Base memref's rank and result's rank do not match: ")
<< memref_ty.getRank() << " vs " << vector_ty.getRank();
return failure();
}
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
if (indices[i] < 0 || indices[i] >= memref_ty.getDimSize(i)) {
op.emitError("Indices[")
<< i << "]=" << indices[i] << " is out of range [0, "
<< memref_ty.getDimSize(i) << ")";
return failure();
}
if (strides[i] < 1) {
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
return failure();
}
if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >=
memref_ty.getDimSize(i)) {
op.emitError() << "Strided slice is out of range at dim " << i;
return failure();
}
}
return success();
}
LogicalResult StridedLoadOp::verify() {
return verifyStridedOp<StridedLoadOp>(*this, getMemRefType(getBase()),
getType());
}
LogicalResult StridedStoreOp::verify() {
return verifyStridedOp<StridedStoreOp>(*this, getMemRefType(getBase()),
getValueToStore().getType());
}
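For instance, under the bounds check above, a load like the following would be rejected because 0 + (8 - 1) * 3 = 21 falls outside the 16 rows of the base memref (types and values are illustrative):

// error: Strided slice is out of range at dim 0
%v = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 3, 1>}
    : memref<16x128xf32>, vector<8x128xf32>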
LogicalResult ReinterpretCastOp::verify() {
auto source_type = getMemRefType(getInput());
auto target_type = getType();


@@ -1158,6 +1158,139 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
return success();
}
LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
Value base_ref, const VectorType &vty,
const VectorLayout &layout,
const ArrayRef<int32_t> &indices,
const ArrayRef<int32_t> &strides) {
if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
return op.emitOpError("Not implemented: Unsupported strided op")
<< op.getName();
}
if (layout != VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone)) {
return op.emitOpError("Not implemented: Unsupported vector layout in ")
<< op.getName();
}
const auto base_ty = getMemRefType(base_ref);
auto rank = base_ty.getRank();
CHECK_EQ(rank, indices.size());
CHECK_EQ(rank, strides.size());
CHECK_EQ(rank, vty.getShape().size());
if (rank < 2) {
return op.emitOpError("Not implemented: Stride on 1D vector");
}
auto mem_layout = dyn_cast<TiledLayoutAttr>(base_ty.getLayout());
if (!mem_layout) {
return op.emitOpError("Expected a tiled memref");
}
auto tile_strides = mem_layout.getTileStrides();
// We currently require that the last dim size of the memref exactly matches
// the lane count of a native vreg and that the memref has never been sliced
// along the last dim. In other words, the original base memref's shape needs
// to be (..., target_shape[1]).
if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] ||
tile_strides.take_back(2) != ArrayRef<int64_t>{1, 1}) {
return op.emitOpError("Not Implemented: The last dim size is not ")
<< ctx.target_shape[1] << " in original base memref";
}
if (strides[rank - 1] != 1) {
return op.emitOpError("Not Implemented: Stride on last dim is not 1");
}
if (indices[rank - 1] != 0) {
return op.emitOpError("Not Implemented: Index on last dim is not 0");
}
ImplicitLocOpBuilder builder(op.getLoc(), &op);
FAILUREOR_ASSIGN_OR_RETURN(
VectorType vreg_ty,
getNativeVregType(vty.getElementType(), ctx.target_shape));
bool is_load_op = true;
xla::Array<Value> tiles(
layout.tileArrayShape(vty.getShape(), ctx.target_shape));
if (auto store_op = dyn_cast<tpu::StridedStoreOp>(op)) {
is_load_op = false;
FAILUREOR_ASSIGN_OR_RETURN(
tiles, disassemble(builder, layout, store_op.getValueToStore(),
ctx.target_shape));
}
tiles.Each([&](absl::Span<const int64_t> tile_idxs, Value *v) {
CHECK_EQ(tile_idxs.size(), rank);
SmallVector<Value> idxs(rank);
for (int64_t i = 0; i < rank; ++i) {
int64_t stride = (i < rank - 2)
? strides[i]
: (strides[i] * ctx.target_shape[i - rank + 2]);
idxs[i] =
IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
}
SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) {
for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) {
sublane_mask[i] = false;
}
}
const auto sublane_mask_attr =
DenseBoolArrayAttr::get(op.getContext(), sublane_mask);
if (is_load_op) {
*v = builder.create<tpu::LoadOp>(
vreg_ty, base_ref, idxs, sublane_mask_attr,
builder.getI32IntegerAttr(strides[rank - 2]));
} else {
builder.create<tpu::StoreOp>(
*v, base_ref, idxs, sublane_mask_attr,
/*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2]));
}
});
if (is_load_op) {
op.replaceAllUsesWith(
assemble(builder, vty, layout, std::move(tiles), ctx.target_shape));
}
op.erase();
return success();
}
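For intuition, here is a minimal standalone sketch (with assumed values; not part of the change) of the per-vreg index computation in strided_op_rule_impl above: along the two minormost dims, each vreg tile spans 8 sublanes by 128 lanes, so a tile's base index advances by stride * 8 (respectively stride * 128) per step in the tile array.

#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone illustration (assumed values): compute the base element indices
// for the vreg tile at position (1, 2), given op indices {4, 0}, strides
// {3, 1}, and a native vreg shape of 8 sublanes x 128 lanes.
int main() {
  const std::vector<int64_t> target_shape = {8, 128};  // sublanes, lanes
  const std::vector<int32_t> indices = {4, 0};         // op start indices
  const std::vector<int32_t> strides = {3, 1};         // op element strides
  const std::vector<int64_t> tile_idxs = {1, 2};       // position in the vreg tile array
  const int64_t rank = 2;
  std::vector<int64_t> idxs(rank);
  for (int64_t i = 0; i < rank; ++i) {
    // Only the two minormost dims are tiled into vregs, so their per-tile step
    // is the element stride scaled by the vreg extent (8 or 128).
    const int64_t step = (i < rank - 2)
                             ? strides[i]
                             : strides[i] * target_shape[i - rank + 2];
    idxs[i] = indices[i] + tile_idxs[i] * step;
  }
  // Prints: vreg tile (1, 2) starts at element (28, 256)
  std::printf("vreg tile (1, 2) starts at element (%lld, %lld)\n",
              static_cast<long long>(idxs[0]), static_cast<long long>(idxs[1]));
  return 0;
}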
// TODO(jevinjiang): maybe unify with vector load?
LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_OP(llvm::none_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }));
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout_out = *layouts_out.front();
auto load_op = cast<tpu::StridedLoadOp>(op);
const auto base_ref = load_op.getBase();
const auto indices = load_op.getIndices();
const auto strides = load_op.getStrides();
const auto vty = cast<VectorType>(load_op.getResult().getType());
return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
strides);
}
// TODO(jevinjiang): maybe unify with vector store?
LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_OP(layouts_in.front().has_value());
TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
[&](const Layout &l) { return l.has_value(); }));
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
const VectorLayout &to_store_layout = *layouts_in.front();
auto store_op = cast<tpu::StridedStoreOp>(op);
const auto base_ref = store_op.getBase();
const auto indices = store_op.getIndices();
const auto strides = store_op.getStrides();
const auto vty = store_op.getValueToStore().getType();
return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
strides);
}
LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
const bool transpose_lhs,
const bool transpose_rhs,
@@ -3510,10 +3643,12 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
{tpu::LoadOp::getOperationName(), tpu_load_rule},
+{tpu::StoreOp::getOperationName(), tpu_store_rule},
+{tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule},
+{tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
{tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
{tpu::RegionOp::getOperationName(), tpu_region_rule},
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
-{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},


@@ -232,11 +232,19 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
-} else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
+} else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
-} else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
+} else if (auto op = dyn_cast<tpu::StridedLoadOp>(any_op)) {
+if (infer(op).failed()) {
+return failure();
+}
+} else if (auto op = dyn_cast<tpu::StridedStoreOp>(any_op)) {
+if (infer(op).failed()) {
+return failure();
+}
+} else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
@@ -581,6 +589,37 @@ class VectorLayoutInferer {
return success();
}
LogicalResult infer(tpu::StridedLoadOp op) {
auto vty = op.getResult().getType();
int8_t bitwidth = vty.getElementTypeBitWidth();
if (bitwidth != 32) {
NYI("Strided load with non 32-bit data");
}
if (vty.getRank() < 2) {
NYI("Strided load with 1D vector");
}
SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
setLayout(op, in_layout,
VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone));
return success();
}
LogicalResult infer(tpu::StridedStoreOp op) {
auto vty = op.getValueToStore().getType();
int8_t bitwidth = vty.getElementTypeBitWidth();
if (bitwidth != 32) {
NYI("Strided store with non 32-bit data");
}
if (vty.getRank() < 2) {
NYI("Strided store with 1D vector");
}
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setInLayout(op, {store_layout, kNoLayout});
return success();
}
LogicalResult infer(tpu::MatmulOp op) { return inferMatmul(op); }
LogicalResult infer(tpu::StoreOp op) {