diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 71bf9c3ff..2d3452609 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -27,6 +27,7 @@ Functions barrier_arrive barrier_wait + commit_smem copy_gmem_to_smem copy_smem_to_gmem emit_pipeline diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 8d2274f14..9a17646f0 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -158,6 +158,7 @@ def emit_pipeline( ) # Copy the output from SMEM to GMEM. + gpu_primitives.commit_smem() map(lambda bref: bref.copy_out(slot, indices), out_brefs) fetch_step = step + max_concurrent_steps diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index f87e96a30..1a5ed7f0d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -66,7 +66,6 @@ def _copy_smem_to_gmem_lowering( dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_indexing(src, src_transforms) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - mgpu.commit_shared() ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) return () @@ -105,6 +104,7 @@ def copy_smem_to_gmem( See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` + :func:`jax.experimental.mosaic.gpu.commit_smem` """ if src.memory_space is not gpu_core.SMEM: raise TypeError(f"src must be a SMEM reference, got {src.memory_space}") diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 932055096..1b240305a 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -144,6 +144,7 @@ def attention(q, k, v, config: TuningConfig): # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)], ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 8d17d8458..f60c6c7c6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -239,6 +239,7 @@ class PallasCallTest(PallasTest): ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] + 1 + plgpu.commit_smem() plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer]) plgpu.wait_smem_to_gmem(0) @@ -294,6 +295,7 @@ class PallasCallTest(PallasTest): plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref) plgpu.barrier_wait(barrier_ref) else: + plgpu.commit_smem() plgpu.copy_smem_to_gmem(x_ref, o_ref) plgpu.wait_smem_to_gmem(0) @@ -1046,6 +1048,7 @@ class PipelineTest(PallasTest): o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0 + plgpu.commit_smem() plgpu.copy_smem_to_gmem( o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)] )