Slightly re-arranged Pallas Mosaic GPU pipelining logic

This change prepares for a few pipelining optimizations that will be done in a
follow-up.
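
Roughly, the restructuring can be pictured with plain-Python stand-ins (the
context manager and the copy below are sketches, not the real mgpu
primitives): the per-buffer loop moves out of fetch/store into the call
sites, which now enter a single-thread region once around all of the copies.

  from contextlib import contextmanager

  @contextmanager
  def single_thread():  # stand-in for mgpu.single_thread()
    yield

  def fetch(idx: int, step: int, slot: int) -> None:
    # Stand-in for the async GMEM->SMEM copy of one input buffer.
    print(f"fetch input {idx} for step {step} into slot {slot}")

  num_inputs, num_stages, num_steps = 2, 3, 8

  # Prologue: fill up to num_stages slots, one copy per input buffer,
  # all issued from a single single-thread region.
  with single_thread():
    for slot in range(min(num_stages, num_steps)):
      for idx in range(num_inputs):
        fetch(idx, slot, slot)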

PiperOrigin-RevId: 672530087
Sergei Lebedev 2024-09-09 06:56:03 -07:00 committed by jax authors
parent fe63b991dd
commit 05cdcb8ce5


@@ -251,7 +251,7 @@ def lower_jaxpr_to_module(
       start_indices: Sequence[ir.Value],
       step: ir.Value,
       shape: Sequence[int],
-  ) -> ir.Value:
+  ) -> Sequence[mgpu.DynamicSlice]:
     return tuple(
         mgpu.ds(
             arith_dialect.addi(
@@ -264,37 +264,35 @@ def lower_jaxpr_to_module(
         for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
     )
 
-  @mgpu.single_thread()
-  def fetch(step: ir.Value, slot: ir.Value) -> None:
-    for start_indices, b_gmem, b_smem in zip(
-        in_start_indices, in_buffers_gmem, in_buffers_smem
-    ):
-      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
-      b_smem_shape = ir.MemRefType(b_smem.type).shape[1:]
-      launch_ctx.async_copy(
-          src_ref=b_gmem,
-          dst_ref=mgpu.memref_slice(b_smem, slot),
-          gmem_slice=gmem_slice(start_indices, step, b_smem_shape),
-          barrier=barriers[slot],
-          swizzle=None,
-          arrive=True,
-          uniform=False,
-      )
+  def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
+    # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
+    launch_ctx.async_copy(
+        src_ref=in_buffers_gmem[idx],
+        dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
+        gmem_slice=gmem_slice(
+            in_start_indices[idx],
+            step,
+            ir.MemRefType(in_buffers_smem[idx].type).shape[1:],
+        ),
+        barrier=barriers[slot],
+        swizzle=None,
+        arrive=True,
+        uniform=False,
+    )
 
-  @mgpu.single_thread()
-  def store(step: ir.Value, slot: ir.Value) -> None:
-    for start_indices, b_gmem, b_smem in zip(
-        out_start_indices, out_buffers_gmem, out_buffers_smem
-    ):
-      # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
-      b_smem_shape = ir.MemRefType(b_smem.type).shape[1:]
-      launch_ctx.async_copy(
-          src_ref=mgpu.memref_slice(b_smem, slot),
-          dst_ref=b_gmem,
-          gmem_slice=gmem_slice(start_indices, step, b_smem_shape),
-          swizzle=None,
-          uniform=False,
-      )
+  def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
+    # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
+    launch_ctx.async_copy(
+        src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
+        dst_ref=out_buffers_gmem[idx],
+        gmem_slice=gmem_slice(
+            out_start_indices[idx],
+            step,
+            ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
+        ),
+        swizzle=None,
+        uniform=False,
+    )
 
   # Compute the number of steps along each sequential axis.
   if sequential_axes:
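
Taking a buffer index lets each call site decide which buffers to copy.
One plausible shape of the follow-up optimizations the message mentions
(the selection logic below is invented for illustration, not part of this
change) is to re-fetch only the inputs whose GMEM slice actually depends
on the sequential step:

  def fetch(idx: int, step: int, slot: int) -> None:
    # Stand-in for the per-buffer async copy above.
    print(f"fetch input {idx} for step {step} into slot {slot}")

  step_dependent_inputs = [0, 2]  # assumed: inputs 1 and 3 are step-invariant
  for idx in step_dependent_inputs:
    fetch(idx, step=4, slot=1)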
@@ -325,8 +323,10 @@ def lower_jaxpr_to_module(
   else:
     num_steps = 1
 
-  for slot in range(min(num_stages, num_steps)):
-    fetch(_as_index(slot), _as_index(slot))
+  with mgpu.single_thread():
+    for slot in range(min(num_stages, num_steps)):
+      for idx in range(grid_mapping.num_inputs):
+        fetch(idx, _as_index(slot), _as_index(slot))
 
   @mgpu.fori(_as_index(num_steps), ())
   def _(step, _):
@@ -341,14 +341,18 @@ def lower_jaxpr_to_module(
         [mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem],
     )
     mgpu.commit_shared()
-    store(step, slot)
+    with mgpu.single_thread():
+      for idx in range(grid_mapping.num_outputs):
+        store(idx, step, slot)
 
     next_step = arith_dialect.addi(step, _as_index(num_stages))
     next_step_in_bounds = arith_dialect.cmpi(
         arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
     )
-    with mgpu.when(next_step_in_bounds):
-      fetch(next_step, slot)
+    with mgpu.when(next_step_in_bounds), mgpu.single_thread():
+      for idx in range(grid_mapping.num_inputs):
+        fetch(idx, next_step, slot)
     return ()
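
For reference, the overall schedule after this change can be modeled in
plain Python. Prints stand in for the async copies and the lowered jaxpr
body, and the slot computation via step % num_stages is assumed (it is the
usual multi-buffering convention; the actual computation is outside the
hunks shown here):

  num_stages, num_steps, num_inputs, num_outputs = 2, 5, 1, 1

  def fetch(idx, step, slot):
    print(f"  fetch in[{idx}] step={step} -> slot {slot}")

  def store(idx, step, slot):
    print(f"  store out[{idx}] step={step} <- slot {slot}")

  # Prologue: issue copies for the first min(num_stages, num_steps) steps.
  for slot in range(min(num_stages, num_steps)):
    for idx in range(num_inputs):
      fetch(idx, slot, slot)

  # Steady state: compute on one slot while later fetches are in flight.
  for step in range(num_steps):
    slot = step % num_stages
    print(f"step {step}: compute on slot {slot}")
    for idx in range(num_outputs):
      store(idx, step, slot)
    next_step = step + num_stages
    if next_step < num_steps:  # mirrors the mgpu.when bounds check
      for idx in range(num_inputs):
        fetch(idx, next_step, slot)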