mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Mosaic] Add internal scratch VMEM
- Make internal scratch size configurable.
- Pass the number of max sublanes allowed in scratch to apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.

PiperOrigin-RevId: 644184896
This commit is contained in:
parent
701c63e19a
commit
ed4958cb3e
@ -605,6 +605,12 @@ def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> {
|
||||
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
|
||||
}
|
||||
|
||||
// Takes no operands and produces a memref of the requested type; presumably
// the address of the pass-managed internal scratch VMEM buffer (per the
// commit description) — confirm against the lowering.
def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> {
  let arguments = (ins);
  let results = (outs AnyMemRef:$result);
  let assemblyFormat = [{ attr-dict `:` type($result) }];
}
|
||||
|
||||
def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
|
||||
let arguments = (ins Variadic<I32>:$seeds);
|
||||
let results = (outs);
|
||||
@ -695,6 +701,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO
|
||||
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
|
||||
Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">,
|
||||
Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">,
|
||||
Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">,
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -58,7 +58,8 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
|
||||
|
||||
// Creates the apply-vector-layout pass.
// `max_sublanes_in_scratch` bounds the internal scratch VMEM size (0 = none).
// NOTE(review): the scraped diff showed both pre- and post-change signature
// lines; this is the post-change (5-parameter + scratch) declaration.
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
    int hardware_generation = -1, int lane_count = 128, int sublane_count = 8,
    int mxu_contracting_size = 128, int mxu_noncontracting_size = 128,
    int max_sublanes_in_scratch = 0);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
@ -157,6 +158,36 @@ FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Get the address of pre-allocated internal scratch space with requested shape.
|
||||
//
|
||||
// Arguments:
|
||||
// shape: The shape of the requested scratch space.
|
||||
// elem_ty: The type of the elements in the requested scratch space.
|
||||
//
|
||||
// Returns:
|
||||
// A memref of the requested shape and type.
|
||||
FailureOr<Value> getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
|
||||
Location loc, ArrayRef<int64_t> shape,
|
||||
Type elem_ty) {
|
||||
if (shape.empty()) {
|
||||
return failure();
|
||||
}
|
||||
if (shape.back() % ctx.target_shape[1] != 0) {
|
||||
return failure();
|
||||
}
|
||||
int sublane_count =
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) /
|
||||
ctx.target_shape[1];
|
||||
if (sublane_count > ctx.max_sublanes_in_scratch) {
|
||||
return failure();
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
MemRefType scratch_ref_ty,
|
||||
inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation));
|
||||
return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Models Numpy's np.repeat, repeating each element `repeats` times along the
|
||||
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
|
||||
// 3, this will return [1, 1, 1, 2, 2, 2].
|
||||
@ -5024,12 +5055,14 @@ struct ApplyVectorLayoutPass
|
||||
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
|
||||
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
|
||||
int sublane_count_, int mxu_contracting_size_,
|
||||
int mxu_noncontracting_size_) {
|
||||
int mxu_noncontracting_size_,
|
||||
int max_sublanes_in_scratch_) {
|
||||
hardware_generation = hardware_generation_;
|
||||
sublane_count = sublane_count_;
|
||||
lane_count = lane_count_;
|
||||
mxu_contracting_size = mxu_contracting_size_;
|
||||
mxu_noncontracting_size = mxu_noncontracting_size_;
|
||||
max_sublanes_in_scratch = max_sublanes_in_scratch_;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
// Fail if hardware_generation has not been set from the default value.
|
||||
@ -5041,7 +5074,8 @@ struct ApplyVectorLayoutPass
|
||||
RewriteContext ctx{func,
|
||||
hardware_generation,
|
||||
{sublane_count, lane_count},
|
||||
{mxu_contracting_size, mxu_noncontracting_size}};
|
||||
{mxu_contracting_size, mxu_noncontracting_size},
|
||||
max_sublanes_in_scratch};
|
||||
if (failed(applyLayoutFunc(ctx, func))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
@ -5051,10 +5085,11 @@ struct ApplyVectorLayoutPass
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
|
||||
int hardware_generation, int lane_count, int sublane_count,
|
||||
int mxu_contracting_size, int mxu_noncontracting_size) {
|
||||
int mxu_contracting_size, int mxu_noncontracting_size,
|
||||
int max_sublanes_in_scratch) {
|
||||
return std::make_unique<ApplyVectorLayoutPass>(
|
||||
hardware_generation, lane_count, sublane_count, mxu_contracting_size,
|
||||
mxu_noncontracting_size);
|
||||
mxu_noncontracting_size, max_sublanes_in_scratch);
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu
|
||||
|
@ -21,6 +21,7 @@ struct RewriteContext {
|
||||
const int hardware_generation;
|
||||
const std::array<int64_t, 2> target_shape = {8, 128};
|
||||
const std::array<int64_t, 2> mxu_shape = {128, 128};
|
||||
const int max_sublanes_in_scratch = 0;
|
||||
|
||||
MLIRContext *getMLIRContext() { return func.getContext(); }
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user