[pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM

Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
This commit is contained in:
Sergei Lebedev 2024-10-31 02:20:33 -07:00 committed by jax authors
parent 7d504cd95a
commit 85662f6dd8
5 changed files with 7 additions and 1 deletions

View File

@ -27,6 +27,7 @@ Functions
barrier_arrive
barrier_wait
commit_smem
copy_gmem_to_smem
copy_smem_to_gmem
emit_pipeline

View File

@ -158,6 +158,7 @@ def emit_pipeline(
)
# Copy the output from SMEM to GMEM.
gpu_primitives.commit_smem()
map(lambda bref: bref.copy_out(slot, indices), out_brefs)
fetch_step = step + max_concurrent_steps

View File

@ -66,7 +66,6 @@ def _copy_smem_to_gmem_lowering(
dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
src, src_transforms = lowering._handle_indexing(src, src_transforms)
copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
mgpu.commit_shared()
ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
return ()
@ -105,6 +104,7 @@ def copy_smem_to_gmem(
See also:
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
"""
if src.memory_space is not gpu_core.SMEM:
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")

View File

@ -144,6 +144,7 @@ def attention(q, k, v, config: TuningConfig):
# TODO(apaszke): Invert and multiply to avoid expensive divisions.
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
qo_smem[...] = acc.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)],
)

View File

@ -239,6 +239,7 @@ class PallasCallTest(PallasTest):
)
def kernel(x_ref, o_ref_gmem, scratch_ref):
scratch_ref[...] = x_ref[...] + 1
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
plgpu.wait_smem_to_gmem(0)
@ -294,6 +295,7 @@ class PallasCallTest(PallasTest):
plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref)
plgpu.barrier_wait(barrier_ref)
else:
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(x_ref, o_ref)
plgpu.wait_smem_to_gmem(0)
@ -1046,6 +1048,7 @@ class PipelineTest(PallasTest):
o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
)