mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Slightly re-arranged Pallas Mosaic GPU pipelining logic
This change prepares a few pipelining optimizations which will be done in a follow up. PiperOrigin-RevId: 672530087
This commit is contained in:
parent
fe63b991dd
commit
05cdcb8ce5
@ -251,7 +251,7 @@ def lower_jaxpr_to_module(
|
||||
start_indices: Sequence[ir.Value],
|
||||
step: ir.Value,
|
||||
shape: Sequence[int],
|
||||
) -> ir.Value:
|
||||
) -> Sequence[mgpu.DynamicSlice]:
|
||||
return tuple(
|
||||
mgpu.ds(
|
||||
arith_dialect.addi(
|
||||
@ -264,37 +264,35 @@ def lower_jaxpr_to_module(
|
||||
for axis, (start_index, dim) in enumerate(zip(start_indices, shape))
|
||||
)
|
||||
|
||||
@mgpu.single_thread()
|
||||
def fetch(step: ir.Value, slot: ir.Value) -> None:
|
||||
for start_indices, b_gmem, b_smem in zip(
|
||||
in_start_indices, in_buffers_gmem, in_buffers_smem
|
||||
):
|
||||
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
|
||||
b_smem_shape = ir.MemRefType(b_smem.type).shape[1:]
|
||||
launch_ctx.async_copy(
|
||||
src_ref=b_gmem,
|
||||
dst_ref=mgpu.memref_slice(b_smem, slot),
|
||||
gmem_slice=gmem_slice(start_indices, step, b_smem_shape),
|
||||
barrier=barriers[slot],
|
||||
swizzle=None,
|
||||
arrive=True,
|
||||
uniform=False,
|
||||
)
|
||||
def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None:
|
||||
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
|
||||
launch_ctx.async_copy(
|
||||
src_ref=in_buffers_gmem[idx],
|
||||
dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot),
|
||||
gmem_slice=gmem_slice(
|
||||
in_start_indices[idx],
|
||||
step,
|
||||
ir.MemRefType(in_buffers_smem[idx].type).shape[1:],
|
||||
),
|
||||
barrier=barriers[slot],
|
||||
swizzle=None,
|
||||
arrive=True,
|
||||
uniform=False,
|
||||
)
|
||||
|
||||
@mgpu.single_thread()
|
||||
def store(step: ir.Value, slot: ir.Value) -> None:
|
||||
for start_indices, b_gmem, b_smem in zip(
|
||||
out_start_indices, out_buffers_gmem, out_buffers_smem
|
||||
):
|
||||
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
|
||||
b_smem_shape = ir.MemRefType(b_smem.type).shape[1:]
|
||||
launch_ctx.async_copy(
|
||||
src_ref=mgpu.memref_slice(b_smem, slot),
|
||||
dst_ref=b_gmem,
|
||||
gmem_slice=gmem_slice(start_indices, step, b_smem_shape),
|
||||
swizzle=None,
|
||||
uniform=False,
|
||||
)
|
||||
def store(idx: int, step: ir.Value, slot: ir.Value) -> None:
|
||||
# TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls.
|
||||
launch_ctx.async_copy(
|
||||
src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot),
|
||||
dst_ref=out_buffers_gmem[idx],
|
||||
gmem_slice=gmem_slice(
|
||||
out_start_indices[idx],
|
||||
step,
|
||||
ir.MemRefType(out_buffers_smem[idx].type).shape[1:],
|
||||
),
|
||||
swizzle=None,
|
||||
uniform=False,
|
||||
)
|
||||
|
||||
# Compute the number of steps along each sequential axis.
|
||||
if sequential_axes:
|
||||
@ -325,8 +323,10 @@ def lower_jaxpr_to_module(
|
||||
else:
|
||||
num_steps = 1
|
||||
|
||||
for slot in range(min(num_stages, num_steps)):
|
||||
fetch(_as_index(slot), _as_index(slot))
|
||||
with mgpu.single_thread():
|
||||
for slot in range(min(num_stages, num_steps)):
|
||||
for idx in range(grid_mapping.num_inputs):
|
||||
fetch(idx, _as_index(slot), _as_index(slot))
|
||||
|
||||
@mgpu.fori(_as_index(num_steps), ())
|
||||
def _(step, _):
|
||||
@ -341,14 +341,18 @@ def lower_jaxpr_to_module(
|
||||
[mgpu.memref_slice(b_smem, slot) for b_smem in buffers_smem],
|
||||
)
|
||||
mgpu.commit_shared()
|
||||
store(step, slot)
|
||||
|
||||
with mgpu.single_thread():
|
||||
for idx in range(grid_mapping.num_outputs):
|
||||
store(idx, step, slot)
|
||||
|
||||
next_step = arith_dialect.addi(step, _as_index(num_stages))
|
||||
next_step_in_bounds = arith_dialect.cmpi(
|
||||
arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps)
|
||||
)
|
||||
with mgpu.when(next_step_in_bounds):
|
||||
fetch(next_step, slot)
|
||||
with mgpu.when(next_step_in_bounds), mgpu.single_thread():
|
||||
for idx in range(grid_mapping.num_inputs):
|
||||
fetch(idx, next_step, slot)
|
||||
|
||||
return ()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user