Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)
[XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as last dim size is 128 and dtype is 32bit.
PiperOrigin-RevId: 614862128
This commit is contained in:
parent 63538771b5
commit 30208fa9cc
@@ -194,6 +194,33 @@ def TPU_LoadOp : TPU_Op<"load"> {
  }];
}

def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
  let arguments = (ins
    AnyMemRef:$base,
    DenseI32ArrayAttr:$indices,
    DenseI32ArrayAttr:$strides
  );
  let results = (outs AnyVector:$result);
  let assemblyFormat = [{
    $base attr-dict `:` type($base) `,` type($result)
  }];
  let hasVerifier = 1;
}

def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
  let arguments = (ins
    AnyVector:$valueToStore,
    AnyMemRef:$base,
    DenseI32ArrayAttr:$indices,
    DenseI32ArrayAttr:$strides
  );
  let results = (outs);
  let assemblyFormat = [{
    $base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
  }];
  let hasVerifier = 1;
}

def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    AnyVector:$value,
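For orientation, here is a minimal sketch of how the two new ops could be created from C++, assuming the default ODS-generated builders and an existing 32-bit memref value whose last dim is 128; the helper name, shapes, indices, and strides are invented for illustration and are not part of this change.

#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

// Hypothetical helper: load every 4th row of `base` (e.g. a memref<32x128xf32>)
// as a vector<8x128xf32>, then store it straight back with the same pattern.
void roundTripStrided(mlir::ImplicitLocOpBuilder &b, mlir::Value base) {
  auto vty = mlir::VectorType::get({8, 128}, b.getF32Type());
  auto indices = b.getDenseI32ArrayAttr({0, 0});
  auto strides = b.getDenseI32ArrayAttr({4, 1});  // rows 0, 4, ..., 28
  mlir::Value loaded =
      b.create<mlir::tpu::StridedLoadOp>(vty, base, indices, strides);
  b.create<mlir::tpu::StridedStoreOp>(loaded, base, indices, strides);
}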
@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/IRMapping.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
@@ -180,6 +181,56 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op,
  return success();
}

template <typename Op>
LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
                              VectorType vector_ty) {
  auto indices = op.getIndices();
  auto strides = op.getStrides();
  if (memref_ty.getRank() != indices.size()) {
    op.emitError("Base memref's rank and indices size do not match: ")
        << memref_ty.getRank() << " vs " << indices.size();
    return failure();
  }
  if (memref_ty.getRank() != strides.size()) {
    op.emitError("Base memref's rank and strides size do not match: ")
        << memref_ty.getRank() << " vs " << strides.size();
    return failure();
  }
  if (memref_ty.getRank() != vector_ty.getRank()) {
    op.emitError("Base memref's rank and result's rank do not match: ")
        << memref_ty.getRank() << " vs " << vector_ty.getRank();
    return failure();
  }
  for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
    if (indices[i] < 0 || indices[i] >= memref_ty.getDimSize(i)) {
op.emitError("Indices[")
|
||||
<< i << "]=" << indices[i] << " is out of range [0, "
|
||||
<< memref_ty.getDimSize(i) << ")";
|
||||
return failure();
|
||||
}
|
||||
if (strides[i] < 1) {
|
||||
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
|
||||
return failure();
|
||||
}
|
||||
if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >
|
||||
memref_ty.getDimSize(i)) {
|
||||
op.emitError() << "Strided slice is out of range at dim " << i;
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult StridedLoadOp::verify() {
|
||||
return verifyStridedOp<StridedLoadOp>(*this, getMemRefType(getBase()),
|
||||
getType());
|
||||
}
|
||||
|
||||
LogicalResult StridedStoreOp::verify() {
|
||||
return verifyStridedOp<StridedStoreOp>(*this, getMemRefType(getBase()),
|
||||
getValueToStore().getType());
|
||||
}
|
||||
|
||||
LogicalResult ReinterpretCastOp::verify() {
|
||||
auto source_type = getMemRefType(getInput());
|
||||
auto target_type = getType();
|
||||
|
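A quick worked example of the per-dimension bound that verifyStridedOp enforces; the base type, index, and stride below are made up, and the snippet only mirrors the arithmetic rather than calling any MLIR API.

#include <cstdint>
#include <cstdio>

// For dim 0 of a hypothetical memref<16x128xf32> base with indices = {2, 0}
// and strides = {3, 1}, check indices[0] + (vector_dim - 1) * strides[0].
int main() {
  const int64_t dim_size = 16;        // memref_ty.getDimSize(0)
  const int64_t index = 2;            // indices[0]
  const int64_t stride = 3;           // strides[0]
  const int64_t vec_dims[] = {5, 6};  // candidate vector dims along dim 0
  for (int64_t vec_dim : vec_dims) {
    const int64_t last = index + (vec_dim - 1) * stride;
    // vector<5x128xf32> touches rows 2, 5, 8, 11, 14 and is accepted;
    // vector<6x128xf32> would reach row 17 > 16 and is rejected.
    std::printf("vec dim %lld: last row %lld -> %s\n", (long long)vec_dim,
                (long long)last, last > dim_size ? "reject" : "accept");
  }
  return 0;
}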
@@ -1158,6 +1158,139 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
  return success();
}

LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
                                   Value base_ref, const VectorType &vty,
                                   const VectorLayout &layout,
                                   const ArrayRef<int32_t> &indices,
                                   const ArrayRef<int32_t> &strides) {
  if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
    return op.emitOpError("Not implemented: Unsupported strided op")
           << op.getName();
  }
  if (layout != VectorLayout(32, {0, 0}, ctx.target_shape,
                             VectorLayout::ImplicitDim::kNone)) {
    return op.emitOpError("Not implemented: Unsupported vector layout in ")
           << op.getName();
  }
  const auto base_ty = getMemRefType(base_ref);
  auto rank = base_ty.getRank();
  CHECK_EQ(rank, indices.size());
  CHECK_EQ(rank, strides.size());
  CHECK_EQ(rank, vty.getShape().size());
  if (rank < 2) {
    return op.emitOpError("Not implemented: Stride on 1D vector");
  }
  auto mem_layout = dyn_cast<TiledLayoutAttr>(base_ty.getLayout());
  if (!mem_layout) {
    return op.emitOpError("Expected a tiled memref");
  }
  auto tile_strides = mem_layout.getTileStrides();

  // Currently we impose the constraints that the last dim size of the memref
  // must be exactly the lane size of a native vreg, and that the memref has
  // never been sliced on the last dim. In other words, the original base
  // memref's shape needs to be (..., target_shape[1]).
  if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] ||
      tile_strides.take_back(2) != ArrayRef<int64_t>{1, 1}) {
    return op.emitOpError("Not Implemented: The last dim size is not ")
           << ctx.target_shape[1] << " in original base memref";
  }
  if (strides[rank - 1] != 1) {
    return op.emitOpError("Not Implemented: Stride on last dim is not 1");
  }
  if (indices[rank - 1] != 0) {
    return op.emitOpError("Not Implemented: Index on last dim is not 0");
  }
  ImplicitLocOpBuilder builder(op.getLoc(), &op);

  FAILUREOR_ASSIGN_OR_RETURN(
      VectorType vreg_ty,
      getNativeVregType(vty.getElementType(), ctx.target_shape));

  bool is_load_op = true;
  xla::Array<Value> tiles(
      layout.tileArrayShape(vty.getShape(), ctx.target_shape));
  if (auto store_op = dyn_cast<tpu::StridedStoreOp>(op)) {
    is_load_op = false;
    FAILUREOR_ASSIGN_OR_RETURN(
        tiles, disassemble(builder, layout, store_op.getValueToStore(),
                           ctx.target_shape));
  }

  tiles.Each([&](absl::Span<const int64_t> tile_idxs, Value *v) {
    CHECK_EQ(tile_idxs.size(), rank);
    SmallVector<Value> idxs(rank);
    for (int64_t i = 0; i < rank; ++i) {
      int64_t stride = (i < rank - 2)
                           ? strides[i]
                           : (strides[i] * ctx.target_shape[i - rank + 2]);
      idxs[i] =
          IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
    }
    SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
    int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
    if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) {
      for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) {
        sublane_mask[i] = false;
      }
    }
    const auto sublane_mask_attr =
        DenseBoolArrayAttr::get(op.getContext(), sublane_mask);
    if (is_load_op) {
      *v = builder.create<tpu::LoadOp>(
          vreg_ty, base_ref, idxs, sublane_mask_attr,
          builder.getI32IntegerAttr(strides[rank - 2]));
    } else {
      builder.create<tpu::StoreOp>(
          *v, base_ref, idxs, sublane_mask_attr,
          /*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2]));
    }
  });
  if (is_load_op) {
    op.replaceAllUsesWith(
        assemble(builder, vty, layout, std::move(tiles), ctx.target_shape));
  }
  op.erase();
  return success();
}

// TODO(jevinjiang): maybe unify with vector load?
LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
                                    const ArrayRef<Layout> layouts_in,
                                    const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_OP(llvm::none_of(layouts_in,
                              [&](const Layout &l) { return l.has_value(); }));
  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
  TPU_ASSERT_OP(layouts_out.front().has_value());
  const VectorLayout &layout_out = *layouts_out.front();
  auto load_op = cast<tpu::StridedLoadOp>(op);
  const auto base_ref = load_op.getBase();
  const auto indices = load_op.getIndices();
  const auto strides = load_op.getStrides();
  const auto vty = cast<VectorType>(load_op.getResult().getType());
  return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
                              strides);
}

// TODO(jevinjiang): maybe unify with vector store?
LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,
                                     const ArrayRef<Layout> layouts_in,
                                     const ArrayRef<Layout> layouts_out) {
  TPU_ASSERT_OP(layouts_in.front().has_value());
  TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
                              [&](const Layout &l) { return l.has_value(); }));
  TPU_ASSERT_EQ_OP(layouts_out.size(), 0);

  const VectorLayout &to_store_layout = *layouts_in.front();
  auto store_op = cast<tpu::StridedStoreOp>(op);
  const auto base_ref = store_op.getBase();
  const auto indices = store_op.getIndices();
  const auto strides = store_op.getStrides();
  const auto vty = store_op.getValueToStore().getType();
  return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
                              strides);
}

LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
                               const bool transpose_lhs,
                               const bool transpose_rhs,
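To make the per-vreg index arithmetic in strided_op_rule_impl concrete, here is a stand-alone sketch with invented numbers (target_shape = {8, 128}, a vector<20x128xf32> value, indices = {1, 0}, strides = {3, 1}); it only reproduces the arithmetic of the tiles.Each loop and is not code from the patch.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t target_shape[2] = {8, 128};  // sublanes x lanes of one vreg
  const int64_t vty_dim0 = 20;               // 2nd-minor dim of the vector
  const int64_t index0 = 1;                  // indices[rank - 2]
  const int64_t stride0 = 3;                 // strides[rank - 2]
  // tileArrayShape along the 2nd-minor dim: ceil(20 / 8) = 3 vregs.
  const int64_t num_tiles = (vty_dim0 + target_shape[0] - 1) / target_shape[0];
  const int64_t sublane_rem = vty_dim0 % target_shape[0];  // 20 % 8 = 4
  for (int64_t t = 0; t < num_tiles; ++t) {
    // On the last two dims the per-vreg step is strides[i] * target_shape[.],
    // so each successive vreg starts 3 * 8 = 24 rows further into the memref.
    const int64_t row = index0 + t * stride0 * target_shape[0];
    // Each emitted tpu.load/tpu.store gets sublane_stride = strides[rank - 2];
    // the trailing vreg masks off the sublanes past the remainder.
    const int64_t valid_sublanes =
        (sublane_rem > 0 && t == num_tiles - 1) ? sublane_rem : target_shape[0];
    std::printf("vreg %lld: base row %lld, sublane stride %lld, "
                "%lld valid sublanes\n",
                (long long)t, (long long)row, (long long)stride0,
                (long long)valid_sublanes);
  }
  return 0;
}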
@@ -3510,10 +3643,12 @@ const llvm::StringMap<rule_type> &rules() {
      {tpu::IotaOp::getOperationName(), tpu_iota_rule},
      {tpu::GatherOp::getOperationName(), tpu_gather_rule},
      {tpu::LoadOp::getOperationName(), tpu_load_rule},
      {tpu::StoreOp::getOperationName(), tpu_store_rule},
      {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule},
      {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
      {tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
      {tpu::RegionOp::getOperationName(), tpu_region_rule},
      {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
      {tpu::StoreOp::getOperationName(), tpu_store_rule},
      {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
      {tpu::TraceOp::getOperationName(), tpu_trace_rule},
      {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
@@ -232,11 +232,19 @@ class VectorLayoutInferer {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
    } else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
    } else if (auto op = dyn_cast<tpu::StridedLoadOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::StridedStoreOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
@@ -581,6 +589,37 @@ class VectorLayoutInferer {
    return success();
  }

  LogicalResult infer(tpu::StridedLoadOp op) {
    auto vty = op.getResult().getType();
    int8_t bitwidth = vty.getElementTypeBitWidth();
    if (bitwidth != 32) {
      NYI("Strided load with non 32-bit data");
    }
    if (vty.getRank() < 2) {
      NYI("Strided load with 1D vector");
    }
    SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
    setLayout(op, in_layout,
              VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                           ImplicitDim::kNone));
    return success();
  }

  LogicalResult infer(tpu::StridedStoreOp op) {
    auto vty = op.getValueToStore().getType();
    int8_t bitwidth = vty.getElementTypeBitWidth();
    if (bitwidth != 32) {
      NYI("Strided store with non 32-bit data");
    }
    if (vty.getRank() < 2) {
      NYI("Strided store with 1D vector");
    }
    auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                                     ImplicitDim::kNone);
    setInLayout(op, {store_layout, kNoLayout});
    return success();
  }

  LogicalResult infer(tpu::MatmulOp op) { return inferMatmul(op); }

  LogicalResult infer(tpu::StoreOp op) {
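Finally, since the commit allows arbitrary leading dims, here is a rough stand-alone illustration of how a rank-3 value would be carved into vregs under the inferred layout VectorLayout(32, {0, 0}, nativeTiling(32), ImplicitDim::kNone), assuming the native 32-bit tile is (8, 128) as in the lowering above; the shape is invented and this is not code from the patch.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t tiling[2] = {8, 128};         // assumed nativeTiling(32)
  const int64_t vty_shape[3] = {4, 20, 128};  // e.g. a vector<4x20x128xf32>
  // Only the last two dims are tiled into vregs; leading dims map one-to-one,
  // and strided_op_rule_impl applies a leading-dim stride unscaled while the
  // last two dims step by strides[i] * tiling[i].
  const int64_t t0 = vty_shape[0];                                // 4
  const int64_t t1 = (vty_shape[1] + tiling[0] - 1) / tiling[0];  // 3
  const int64_t t2 = (vty_shape[2] + tiling[1] - 1) / tiling[1];  // 1
  std::printf("vector<4x20x128xf32> -> %lld x %lld x %lld = %lld vregs\n",
              (long long)t0, (long long)t1, (long long)t2,
              (long long)(t0 * t1 * t2));
  return 0;
}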