mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic] Use strided load to load one entire row more efficiently.
PiperOrigin-RevId: 564831610
This commit is contained in:
parent
c617bcb515
commit
801cbef011
@ -815,9 +815,19 @@ class VectorLayoutInferer {
|
||||
offsets[0] = tile_indices[0] % tiling[0];
|
||||
}
|
||||
offsets[1] = tile_indices[1] % target_shape_[1];
|
||||
// We can use replicated loads if we're only loading a single sublane.
|
||||
std::array<int64_t, 2> layout_tiling{tiling[0], tiling[1]};
|
||||
if (num_sublanes == 1 && bitwidth == 32 && tiling == target_shape_) {
|
||||
if (num_sublanes == 1 && bitwidth == 32 &&
|
||||
tiling[1] == target_shape_[1] &&
|
||||
tile_res_shape[1] > target_shape_[1]) {
|
||||
// We can use a strided sublane load if we're loading a single sublane
|
||||
// multiple times. Enabling this helps load one entire row from memref
|
||||
// more efficiently.
|
||||
setLayout(op, in_layout,
|
||||
VectorLayout(bitwidth, offsets, {1, layout_tiling[1]},
|
||||
ImplicitDim::kNone));
|
||||
} else if (num_sublanes == 1 && bitwidth == 32 &&
|
||||
tiling == target_shape_) {
|
||||
// We can use replicated loads if we're only loading a single sublane.
|
||||
setLayout(op, in_layout,
|
||||
VectorLayout(bitwidth, {std::nullopt, offsets[1]},
|
||||
layout_tiling, ImplicitDim::kNone));
|
||||
|
@ -2342,14 +2342,34 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
|
||||
if layout_out.implicit_dim == ImplicitDim.MINOR:
|
||||
raise NotImplementedError
|
||||
is_1d = layout_out.implicit_dim is not None
|
||||
if layout_out.tiling != get_memref_tiling(op.base):
|
||||
raise NotImplementedError
|
||||
memref_tiling = get_memref_tiling(op.base)
|
||||
if layout_out.tiling != memref_tiling:
|
||||
# Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
|
||||
# TODO(b/295393167): need to support strided load for bitwidth < 32.
|
||||
if layout_out.bitwidth != 32 or layout_out.tiling != (
|
||||
1,
|
||||
TARGET_SHAPE.lanes,
|
||||
):
|
||||
raise NotImplementedError
|
||||
# TODO(apaszke): Check that loads are from vmem!
|
||||
indices = [get_int_const(v, "vector.load index") for v in op.indices]
|
||||
for i, n, extent in zip(indices, ty.shape, memref_ty.shape):
|
||||
if i + n > extent:
|
||||
raise ValueError("reading out of bounds")
|
||||
*_, ss, _ = layout_out.implicit_shape(ty.shape)
|
||||
sublane_stride = 1
|
||||
# The stride of load should be the number of sublanes in memref tile when
|
||||
# loading a single sublane.
|
||||
if (
|
||||
layout_out.bitwidth == 32
|
||||
and layout_out.tiling
|
||||
== (
|
||||
1,
|
||||
TARGET_SHAPE.lanes,
|
||||
)
|
||||
and ss == 1
|
||||
):
|
||||
sublane_stride = memref_tiling[0]
|
||||
tiling = TargetTuple(*layout_out.tiling)
|
||||
s, l = offsets = TargetTuple(*layout_out.offsets)
|
||||
check(l is not REPLICATED, "load replicated along lanes is unsupported")
|
||||
@ -2392,7 +2412,12 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
|
||||
if bounds.mask_varies_along(SUBLANES):
|
||||
assert s is not REPLICATED # Replicated loads should never go OOB
|
||||
tile = tpu.LoadOp(
|
||||
target_ty, op.base, indices_vs, bounds.get_sublane_mask())
|
||||
target_ty,
|
||||
op.base,
|
||||
indices_vs,
|
||||
bounds.get_sublane_mask(),
|
||||
sublane_stride=sublane_stride,
|
||||
)
|
||||
else:
|
||||
if load_map is not None:
|
||||
if layout_out.bitwidth != 32:
|
||||
@ -2400,9 +2425,16 @@ def _vector_load_rule( # pylint: disable=missing-function-docstring
|
||||
tile = vector.TransferReadOp(
|
||||
target_ty, op.base, indices_vs, load_map, padding)
|
||||
else:
|
||||
assert s is not REPLICATED
|
||||
sublane_mask = ir.DenseBoolArrayAttr.get(
|
||||
[True] * TARGET_SHAPE.sublanes)
|
||||
tile = tpu.LoadOp(target_ty, op.base, indices_vs, sublane_mask)
|
||||
tile = tpu.LoadOp(
|
||||
target_ty,
|
||||
op.base,
|
||||
indices_vs,
|
||||
sublane_mask,
|
||||
sublane_stride=sublane_stride,
|
||||
)
|
||||
tiles[tile_ixs] = tile
|
||||
return ctx.replace(op, assemble(ty, layout_out, tiles))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user