[Mosaic] Use strided load to load one entire row more efficiently.

PiperOrigin-RevId: 564831610
Jevin Jiang 2023-09-12 14:18:53 -07:00 committed by jax authors
parent c617bcb515
commit 801cbef011
2 changed files with 48 additions and 6 deletions
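The pattern this commit targets is a 32-bit load of a single sublane whose lane extent spans more than one vreg, i.e. reading one entire row of an array. A minimal, hypothetical Pallas sketch of such a load (the kernel, names, and shapes are illustrative and not part of this commit):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def first_row_kernel(x_ref, o_ref):
  # A (1, 1024) f32 read: a single sublane spanning 1024 lanes. With this
  # change, Mosaic can give the load a (1, 128) tiling and lower it to a
  # single strided sublane load.
  o_ref[...] = x_ref[0:1, :]

x = jnp.arange(8 * 1024, dtype=jnp.float32).reshape(8, 1024)
row = pl.pallas_call(
    first_row_kernel,
    out_shape=jax.ShapeDtypeStruct((1, 1024), jnp.float32),
)(x)

With a (1, 128) layout tiling, the eight 128-lane chunks of the row map to the eight sublanes of one vreg, so the whole row can be fetched with one strided tpu.LoadOp rather than being spread across eight vregs.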


@@ -815,9 +815,19 @@ class VectorLayoutInferer {
offsets[0] = tile_indices[0] % tiling[0];
}
offsets[1] = tile_indices[1] % target_shape_[1];
// We can use replicated loads if we're only loading a single sublane.
std::array<int64_t, 2> layout_tiling{tiling[0], tiling[1]};
if (num_sublanes == 1 && bitwidth == 32 && tiling == target_shape_) {
if (num_sublanes == 1 && bitwidth == 32 &&
tiling[1] == target_shape_[1] &&
tile_res_shape[1] > target_shape_[1]) {
// We can use a strided load of sublanes if we're loading a single sublane
// multiple times. Enabling this helps load one entire row from the memref
// more efficiently.
setLayout(op, in_layout,
VectorLayout(bitwidth, offsets, {1, layout_tiling[1]},
ImplicitDim::kNone));
} else if (num_sublanes == 1 && bitwidth == 32 &&
tiling == target_shape_) {
// We can use replicated loads if we're only loading a single sublane.
setLayout(op, in_layout,
VectorLayout(bitwidth, {std::nullopt, offsets[1]},
layout_tiling, ImplicitDim::kNone));
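Restated as a hypothetical Python helper (choose_load_tiling and TARGET_LANES are illustrative names, not Mosaic's API), the new branch only fires for 32-bit loads of a single sublane whose lane extent exceeds one vreg; everything else keeps the memref tiling:

# Illustrative sketch of the inference rule above; TARGET_LANES stands in
# for target_shape_[1] (128 lanes per vreg).
TARGET_LANES = 128

def choose_load_tiling(num_sublanes, bitwidth, memref_tiling, result_lanes):
  if (num_sublanes == 1 and bitwidth == 32
      and memref_tiling[1] == TARGET_LANES
      and result_lanes > TARGET_LANES):
    # Strided load: each vreg sublane holds a different 128-lane chunk.
    return (1, memref_tiling[1])
  # Otherwise keep the memref tiling (e.g. the replicated single-sublane case).
  return memref_tiling

assert choose_load_tiling(1, 32, (8, 128), 1024) == (1, 128)
assert choose_load_tiling(1, 32, (8, 128), 128) == (8, 128)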


@@ -2342,14 +2342,34 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
if layout_out.implicit_dim == ImplicitDim.MINOR:
raise NotImplementedError
is_1d = layout_out.implicit_dim is not None
if layout_out.tiling != get_memref_tiling(op.base):
raise NotImplementedError
memref_tiling = get_memref_tiling(op.base)
if layout_out.tiling != memref_tiling:
# We can now handle the case where the tiling is (1, TARGET_SHAPE.lanes).
# TODO(b/295393167): need to support strided load for bitwidth < 32.
if layout_out.bitwidth != 32 or layout_out.tiling != (
1,
TARGET_SHAPE.lanes,
):
raise NotImplementedError
# TODO(apaszke): Check that loads are from vmem!
indices = [get_int_const(v, "vector.load index") for v in op.indices]
for i, n, extent in zip(indices, ty.shape, memref_ty.shape):
if i + n > extent:
raise ValueError("reading out of bounds")
*_, ss, _ = layout_out.implicit_shape(ty.shape)
sublane_stride = 1
# The load stride should be the number of sublanes in the memref tile when
# loading a single sublane.
if (
layout_out.bitwidth == 32
and layout_out.tiling
== (
1,
TARGET_SHAPE.lanes,
)
and ss == 1
):
sublane_stride = memref_tiling[0]
tiling = TargetTuple(*layout_out.tiling)
s, l = offsets = TargetTuple(*layout_out.offsets)
check(l is not REPLICATED, "load replicated along lanes is unsupported")
@@ -2392,7 +2412,12 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
if bounds.mask_varies_along(SUBLANES):
assert s is not REPLICATED # Replicated loads should never go OOB
tile = tpu.LoadOp(
target_ty, op.base, indices_vs, bounds.get_sublane_mask())
target_ty,
op.base,
indices_vs,
bounds.get_sublane_mask(),
sublane_stride=sublane_stride,
)
else:
if load_map is not None:
if layout_out.bitwidth != 32:
@@ -2400,9 +2425,16 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
tile = vector.TransferReadOp(
target_ty, op.base, indices_vs, load_map, padding)
else:
assert s is not REPLICATED
sublane_mask = ir.DenseBoolArrayAttr.get(
[True] * TARGET_SHAPE.sublanes)
tile = tpu.LoadOp(target_ty, op.base, indices_vs, sublane_mask)
tile = tpu.LoadOp(
target_ty,
op.base,
indices_vs,
sublane_mask,
sublane_stride=sublane_stride,
)
tiles[tile_ixs] = tile
return ctx.replace(op, assemble(ty, layout_out, tiles))
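To see why the stride equals memref_tiling[0], consider reading row 0 of an (8, 1024) f32 memref with the usual (8, 128) tiling: consecutive 128-lane chunks of the row live in consecutive memory tiles, eight sublanes apart. An arithmetic sketch under those assumed shapes (not Mosaic code):

# Assumes a memref<8x1024xf32> with (8, 128) tiling and a vector<1x1024xf32>
# result laid out with a (1, 128) tiling.
memref_tiling = (8, 128)   # sublanes x lanes per memory tile for 32-bit data
lanes = 128                # TARGET_SHAPE.lanes
row_len = 1024

sublane_stride = memref_tiling[0]   # what _vector_load_rule passes to tpu.LoadOp
num_chunks = row_len // lanes       # eight (1, 128) chunks make up one row
touched = [chunk * sublane_stride for chunk in range(num_chunks)]
print(touched)                      # [0, 8, 16, 24, 32, 40, 48, 56]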