From 85662f6dd83baa92ae2e18620923d25fbdcac420 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev@google.com>
Date: Thu, 31 Oct 2024 02:20:33 -0700
Subject: [PATCH] [pallas:mosaic_gpu] `plgpu.copy_smem_to_gmem` no longer
 transparently commits SMEM

Users are expected to call `plgpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
---
 docs/jax.experimental.pallas.mosaic_gpu.rst       | 1 +
 jax/_src/pallas/mosaic_gpu/pipeline.py            | 1 +
 jax/_src/pallas/mosaic_gpu/primitives.py          | 2 +-
 jax/experimental/pallas/ops/gpu/attention_mgpu.py | 1 +
 tests/pallas/mosaic_gpu_test.py                   | 3 +++
 5 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst
index 71bf9c3ff..2d3452609 100644
--- a/docs/jax.experimental.pallas.mosaic_gpu.rst
+++ b/docs/jax.experimental.pallas.mosaic_gpu.rst
@@ -27,6 +27,7 @@ Functions
 
    barrier_arrive
    barrier_wait
+   commit_smem
    copy_gmem_to_smem
    copy_smem_to_gmem
    emit_pipeline
diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py
index 8d2274f14..9a17646f0 100644
--- a/jax/_src/pallas/mosaic_gpu/pipeline.py
+++ b/jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -158,6 +158,7 @@ def emit_pipeline(
         )
 
       # Copy the output from SMEM to GMEM.
+      gpu_primitives.commit_smem()
       map(lambda bref: bref.copy_out(slot, indices), out_brefs)
 
       fetch_step = step + max_concurrent_steps
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index f87e96a30..1a5ed7f0d 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -66,7 +66,6 @@ def _copy_smem_to_gmem_lowering(
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
   src, src_transforms = lowering._handle_indexing(src, src_transforms)
   copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
-  mgpu.commit_shared()
   ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
   return ()
 
@@ -105,6 +104,7 @@ def copy_smem_to_gmem(
 
   See also:
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
+    :func:`jax.experimental.mosaic.gpu.commit_smem`
   """
   if src.memory_space is not gpu_core.SMEM:
     raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 932055096..1b240305a 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -144,6 +144,7 @@ def attention(q, k, v, config: TuningConfig):
       # TODO(apaszke): Invert and multiply to avoid expensive divisions.
       acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
       qo_smem[...] = acc.astype(dtype)
+      plgpu.commit_smem()
       plgpu.copy_smem_to_gmem(
           qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)],
       )
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 8d17d8458..f60c6c7c6 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -239,6 +239,7 @@ class PallasCallTest(PallasTest):
     )
     def kernel(x_ref, o_ref_gmem, scratch_ref):
       scratch_ref[...] = x_ref[...] + 1
+      plgpu.commit_smem()
       plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
       plgpu.wait_smem_to_gmem(0)
 
@@ -294,6 +295,7 @@ class PallasCallTest(PallasTest):
         plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref)
         plgpu.barrier_wait(barrier_ref)
       else:
+        plgpu.commit_smem()
         plgpu.copy_smem_to_gmem(x_ref, o_ref)
         plgpu.wait_smem_to_gmem(0)
 
@@ -1046,6 +1048,7 @@ class PipelineTest(PallasTest):
 
         o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0
 
+        plgpu.commit_smem()
         plgpu.copy_smem_to_gmem(
             o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
         )