[XLA:Mosaic] Add internal scratch VMEM

- Make internal scratch size configurable.
- Pass the maximum number of sublanes allowed in scratch to the apply-vector-layout pass.
- Create a helper function to fetch the internal scratch VMEM address.

PiperOrigin-RevId: 644184896
This commit is contained in:
Jevin Jiang 2024-06-17 17:28:49 -07:00 committed by jax authors
parent 701c63e19a
commit ed4958cb3e
4 changed files with 49 additions and 5 deletions

View File

@ -605,6 +605,12 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> {
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
}
// Op with mnemonic `internal_scratch`: takes no operands and produces a single
// memref result. The result type is not constrained beyond being a memref
// (AnyMemRef), so the creator of the op chooses the shape and element type.
// NOTE(review): per the commit description this is meant to expose the
// pre-allocated internal scratch VMEM; how the result is bound to that memory
// happens outside this definition — confirm in the lowering.
def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> {
// No operands.
let arguments = (ins);
let results = (outs AnyMemRef:$result);
// Textual form: optional attribute dictionary, then `:` and the result type.
let assemblyFormat = [{ attr-dict `:` type($result) }];
}
def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
let arguments = (ins Variadic<I32>:$seeds);
let results = (outs);
@ -695,6 +701,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">,
Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">,
Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">,
];
}

View File

@ -58,7 +58,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation = -1, int lane_count = 128, int sublane_count = 8,
int mxu_contracting_size = 128, int mxu_noncontracting_size = 128);
int mxu_contracting_size = 128, int mxu_noncontracting_size = 128,
int max_sublanes_in_scratch = 0);
std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);

View File

@ -7,6 +7,7 @@
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <tuple>
#include <utility>
@ -157,6 +158,36 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
.getResult();
}
// Get the address of pre-allocated internal scratch space with requested shape.
//
// The request is validated against the scratch budget carried by `ctx` before
// a tpu::GetInternalScratchOp is created.
//
// Arguments:
//   ctx: Rewrite context. Supplies the lane count (target_shape[1]), the
//     hardware generation and the budget (max_sublanes_in_scratch).
//   builder: Builder used to create the tpu::GetInternalScratchOp.
//   loc: Location attached to the created op.
//   shape: The shape of the requested scratch space. Must be non-empty, must
//     not contain negative extents, and its minor-most dimension must be a
//     multiple of the lane count.
//   elem_ty: The type of the elements in the requested scratch space.
//
// Returns:
//   A memref of the requested shape and type, or failure if the shape is
//   rejected by the checks above or needs more sublanes than
//   ctx.max_sublanes_in_scratch allows.
FailureOr<Value> getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
                                    Location loc, ArrayRef<int64_t> shape,
                                    Type elem_ty) {
  if (shape.empty()) {
    return failure();
  }
  const int64_t lane_count = ctx.target_shape[1];
  if (shape.back() % lane_count != 0) {
    return failure();
  }
  // A negative extent (e.g. a dynamic-dimension sentinel) cannot be sized
  // against the static scratch budget, and would poison the product below.
  for (const int64_t dim : shape) {
    if (dim < 0) {
      return failure();
    }
  }
  // Accumulate in int64_t. With an `int` init value and std::multiplies<int>
  // every partial product is truncated to 32 bits, so a large request could
  // wrap around to a small number and slip past the budget check.
  const int64_t num_elements = std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  // NOTE(review): the count below does not depend on elem_ty, i.e. it treats
  // every element as occupying one lane slot — confirm this is intended for
  // element types narrower or wider than 32 bits.
  const int64_t sublane_count = num_elements / lane_count;
  if (sublane_count > ctx.max_sublanes_in_scratch) {
    return failure();
  }
  // Presumably returns early with failure when inferMemref fails (behavior of
  // the FAILUREOR_ASSIGN_OR_RETURN macro is defined elsewhere).
  FAILUREOR_ASSIGN_OR_RETURN(
      MemRefType scratch_ref_ty,
      inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation));
  return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
      .getResult();
}
// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
@ -5024,12 +5055,14 @@ struct ApplyVectorLayoutPass
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
int sublane_count_, int mxu_contracting_size_,
int mxu_noncontracting_size_) {
int mxu_noncontracting_size_,
int max_sublanes_in_scratch_) {
hardware_generation = hardware_generation_;
sublane_count = sublane_count_;
lane_count = lane_count_;
mxu_contracting_size = mxu_contracting_size_;
mxu_noncontracting_size = mxu_noncontracting_size_;
max_sublanes_in_scratch = max_sublanes_in_scratch_;
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
@ -5041,7 +5074,8 @@ struct ApplyVectorLayoutPass
RewriteContext ctx{func,
hardware_generation,
{sublane_count, lane_count},
{mxu_contracting_size, mxu_noncontracting_size}};
{mxu_contracting_size, mxu_noncontracting_size},
max_sublanes_in_scratch};
if (failed(applyLayoutFunc(ctx, func))) {
signalPassFailure();
return;
@ -5051,10 +5085,11 @@ struct ApplyVectorLayoutPass
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation, int lane_count, int sublane_count,
int mxu_contracting_size, int mxu_noncontracting_size) {
int mxu_contracting_size, int mxu_noncontracting_size,
int max_sublanes_in_scratch) {
return std::make_unique<ApplyVectorLayoutPass>(
hardware_generation, lane_count, sublane_count, mxu_contracting_size,
mxu_noncontracting_size);
mxu_noncontracting_size, max_sublanes_in_scratch);
}
} // namespace mlir::tpu

View File

@ -21,6 +21,7 @@ struct RewriteContext {
const int hardware_generation;
const std::array<int64_t, 2> target_shape = {8, 128};
const std::array<int64_t, 2> mxu_shape = {128, 128};
const int max_sublanes_in_scratch = 0;
MLIRContext *getMLIRContext() { return func.getContext(); }
};