Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Mosaic] Packed loads and stores with 1D tiling should use (1, 128 * packing)
There are multiple representations for 1D tiling in vector layouts and we need to choose one consistently.

PiperOrigin-RevId: 638331061
Parent: f0eaba9e29
Commit: 8f8b976421
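
For context, the canonical 1D tiling the commit settles on is (1, 128 * packing), where packing is how many elements fit in one native 32-bit word. A minimal standalone sketch of that rule (a native bitwidth of 32 and 128 lanes are assumptions taken from the title; this is not the Mosaic API):

    #include <cstdint>
    #include <cstdio>
    #include <utility>

    constexpr int64_t kNativeBitwidth = 32;  // assumed native element width in bits
    constexpr int64_t kLanes = 128;          // assumed TARGET_SHAPE.lanes

    // Canonical 1D tiling for an element type of the given bitwidth: (1, 128 * packing).
    std::pair<int64_t, int64_t> Canonical1DTiling(int64_t bitwidth) {
      const int64_t packing = kNativeBitwidth / bitwidth;
      return {1, kLanes * packing};
    }

    int main() {
      for (int64_t bw : {32, 16, 8}) {
        auto [rows, lanes] = Canonical1DTiling(bw);
        std::printf("bitwidth %2lld -> tiling (%lld, %lld)\n",
                    (long long)bw, (long long)rows, (long long)lanes);
      }
    }

This prints (1, 128) for 32-bit, (1, 256) for 16-bit, and (1, 512) for 8-bit elements, matching the "(1, 128 * packing)" rule in the title.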
@@ -2461,7 +2461,9 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
   FAILUREOR_ASSIGN_OR_RETURN(
       Tiling memref_tiling,
       getMemRefTiling(load_op.getBase(), ctx.target_shape));
-  if (layout_out.tiling() != memref_tiling) {
+  if (memref_tiling != layout_out.tiling() &&
+      !(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 &&
+        memref_tiling[1] % layout_out.tiling()[1] == 0)) {
     // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
     // TODO(b/295393167): need to support strided load for bitwidth < 32.
     if (layout_out.bitwidth() != 32 ||
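
The relaxed check above means a mismatch between the memref tiling and the layout tiling is no longer special-cased when both are 1D-style tilings (leading dimension 1) and the memref's lane tile is a multiple of the layout's. A standalone sketch of the predicate, with a plain array standing in for the pass's Tiling type (illustrative only):

    #include <array>
    #include <cstdint>

    using Tiling = std::array<int64_t, 2>;

    // True when the op still needs the special mismatch handling; false when
    // the layout can be used against the memref tiling directly.
    bool NeedsRetiling(const Tiling &memref_tiling, const Tiling &layout_tiling) {
      return memref_tiling != layout_tiling &&
             !(memref_tiling[0] == 1 && layout_tiling[0] == 1 &&
               memref_tiling[1] % layout_tiling[1] == 0);
    }

    int main() {
      // A memref tiled (1, 256) is now compatible with a layout tiled (1, 128),
      // since 256 % 128 == 0; the old check treated any mismatch the same way.
      const Tiling memref{1, 256}, layout{1, 128};
      return NeedsRetiling(memref, layout) ? 1 : 0;  // returns 0
    }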
@@ -3659,7 +3661,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
   FAILUREOR_ASSIGN_OR_RETURN(
       const Tiling memref_tiling,
       getMemRefTiling(store_op.getBase(), ctx.target_shape));
-  if (to_store_layout.tiling() != memref_tiling) {
+  if (memref_tiling != to_store_layout.tiling() &&
+      !(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 &&
+        memref_tiling[1] % to_store_layout.tiling()[1] == 0)) {
     // Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
     // TODO(b/295393167): need to support strided store for bitwidth < 32.
     if (to_store_layout.bitwidth() != 32 ||

@@ -1019,6 +1019,10 @@ class VectorLayoutInferer {
                  "memref and vector rank mismatch");
     int64_t rank = res_ty.getRank();
     int8_t bitwidth = res_ty.getElementTypeBitWidth();
+    if (kNativeBitwidth % bitwidth != 0) {
+      return op.emitOpError("Unsupported bitwidth");
+    }
+    const int packing = kNativeBitwidth / bitwidth;
     auto maybe_tiling =
         verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(),
                            src_ty.getRank(), src_ty.getElementTypeBitWidth());
@@ -1050,12 +1054,10 @@ class VectorLayoutInferer {
     }
     if (rank == 1) {
       TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads");
+      const int64_t lane_tiling = packing * target_shape_[1];
       auto tile = tiling.front();
-      TPU_CHECK_OP(tile % target_shape_[1] == 0,
-                   "Unsupported tiling for 1D load");
+      TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported tiling for 1D load");
       CHECK_EQ(tile_offsets.size(), 1);
-      // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
-      const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
       // TODO(apaszke): We could generate replicated loads for short values.
       setLayout(op, in_layout,
                 VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
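
With this change the inferred 1D load layout always uses lane_tiling = packing * target_shape_[1] (i.e. 128 * packing lanes) and records only the offset within one such tile. A small standalone illustration of that offset arithmetic (the constants and the 300-element offset are made up for the example; VectorLayout itself is not reproduced):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t kNativeBitwidth = 32, kLanes = 128;    // assumed constants
      const int64_t bitwidth = 16;                         // e.g. a bf16 load
      const int64_t packing = kNativeBitwidth / bitwidth;  // = 2
      const int64_t lane_tiling = packing * kLanes;        // = 256
      const int64_t tile_offset = 300;                     // hypothetical 1D offset
      // The layout keeps only the offset within one (1, 256) tile: 300 % 256 = 44.
      std::printf("in-tile offset = %lld\n", (long long)(tile_offset % lane_tiling));
    }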
@@ -1372,6 +1374,10 @@ class VectorLayoutInferer {
                  "memref and vector rank mismatch");
     int64_t rank = ref_ty.getRank();
     int8_t bitwidth = store_ty.getElementTypeBitWidth();
+    if (kNativeBitwidth % bitwidth != 0) {
+      return op.emitOpError("Unsupported bitwidth");
+    }
+    const int packing = kNativeBitwidth / bitwidth;
     auto maybe_tiling =
         verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(),
                            ref_ty.getRank(), ref_ty.getElementTypeBitWidth());
@@ -1402,11 +1408,10 @@ class VectorLayoutInferer {
     }
     if (rank == 1) {
       TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D store");
+      const int64_t lane_tiling = packing * target_shape_[1];
       auto tile = tiling.front();
-      TPU_CHECK_OP(tile % target_shape_[1] == 0,
+      TPU_CHECK_OP(tile % lane_tiling == 0,
                    "Unsupported 1D tiling for 1D store");
-      // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
-      const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
       CHECK_EQ(tile_offsets.size(), 1);
       store_layout =
           VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
@ -1806,6 +1811,7 @@ class VectorLayoutInferer {
|
||||
std::optional<absl::Span<const int64_t>> verifyMemoryTiling(
|
||||
Operation *op, ArrayRef<xla::Tile> mem_tiling, int64_t rank,
|
||||
int8_t bitwidth) {
|
||||
const int packing = kNativeBitwidth / bitwidth;
|
||||
if (bitwidth == 32) {
|
||||
if (mem_tiling.size() != 1) {
|
||||
op->emitOpError("Only one-level tiling supported for 32-bit loads");
|
||||
@ -1822,7 +1828,7 @@ class VectorLayoutInferer {
|
||||
}
|
||||
auto first = mem_tiling[0].dimensions();
|
||||
auto second = mem_tiling[1].dimensions();
|
||||
if (first.size() != 1 || first[0] % target_shape_[1] != 0) {
|
||||
if (first.size() != 1 || first[0] % (packing * target_shape_[1]) != 0) {
|
||||
op->emitOpError("Invalid first-level tile in 1D memory op");
|
||||
return std::nullopt;
|
||||
}
|
||||
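
Correspondingly, verifyMemoryTiling now requires the first-level memory tile of a packed 1D op to be a multiple of packing * target_shape_[1] rather than of the lane count alone. A standalone sketch of that check with made-up inputs (the helper name and the modeling of the tile as a vector are illustrative):

    #include <cstdint>
    #include <vector>

    bool FirstLevelTileOk(const std::vector<int64_t> &first_tile, int64_t bitwidth,
                          int64_t lanes = 128, int64_t native_bitwidth = 32) {
      const int64_t packing = native_bitwidth / bitwidth;
      return first_tile.size() == 1 && first_tile[0] % (packing * lanes) == 0;
    }

    int main() {
      // For 8-bit elements the first-level tile must be a multiple of 4 * 128 = 512:
      const bool ok_512 = FirstLevelTileOk({512}, 8);  // accepted
      const bool ok_128 = FirstLevelTileOk({128}, 8);  // rejected under the new check
      return (ok_512 && !ok_128) ? 0 : 1;              // returns 0
    }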