mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[pallas:mosaic_gpu] `plgpu.copy_smem_to_gmem` no longer transparently commits SMEM
Users are expected to call `plgpu.commit_smem` manually before `plgpu.copy_smem_to_gmem` instead.
PiperOrigin-RevId: 691724662
This commit is contained in:
parent
7d504cd95a
commit
85662f6dd8
@ -27,6 +27,7 @@ Functions
|
||||
|
||||
barrier_arrive
|
||||
barrier_wait
|
||||
commit_smem
|
||||
copy_gmem_to_smem
|
||||
copy_smem_to_gmem
|
||||
emit_pipeline
|
||||
|
@ -158,6 +158,7 @@ def emit_pipeline(
|
||||
)
|
||||
|
||||
# Copy the output from SMEM to GMEM.
|
||||
gpu_primitives.commit_smem()
|
||||
map(lambda bref: bref.copy_out(slot, indices), out_brefs)
|
||||
|
||||
fetch_step = step + max_concurrent_steps
|
||||
|
@ -66,7 +66,6 @@ def _copy_smem_to_gmem_lowering(
|
||||
dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
|
||||
src, src_transforms = lowering._handle_indexing(src, src_transforms)
|
||||
copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
|
||||
mgpu.commit_shared()
|
||||
ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
|
||||
return ()
|
||||
|
||||
@ -105,6 +104,7 @@ def copy_smem_to_gmem(
|
||||
|
||||
See also:
|
||||
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
|
||||
:func:`jax.experimental.mosaic.gpu.commit_smem`
|
||||
"""
|
||||
if src.memory_space is not gpu_core.SMEM:
|
||||
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
|
||||
|
@ -144,6 +144,7 @@ def attention(q, k, v, config: TuningConfig):
|
||||
# TODO(apaszke): Invert and multiply to avoid expensive divisions.
|
||||
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
|
||||
qo_smem[...] = acc.astype(dtype)
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
qo_smem, out_ref.at[pl.ds(q_seq_base, block_q)],
|
||||
)
|
||||
|
@ -239,6 +239,7 @@ class PallasCallTest(PallasTest):
|
||||
)
|
||||
def kernel(x_ref, o_ref_gmem, scratch_ref):
|
||||
scratch_ref[...] = x_ref[...] + 1
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(scratch_ref.at[indexer], o_ref_gmem.at[indexer])
|
||||
plgpu.wait_smem_to_gmem(0)
|
||||
|
||||
@ -294,6 +295,7 @@ class PallasCallTest(PallasTest):
|
||||
plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier=barrier_ref)
|
||||
plgpu.barrier_wait(barrier_ref)
|
||||
else:
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(x_ref, o_ref)
|
||||
plgpu.wait_smem_to_gmem(0)
|
||||
|
||||
@ -1046,6 +1048,7 @@ class PipelineTest(PallasTest):
|
||||
|
||||
o_smem.at[slot][...] = x_smem.at[slot][...] + 1.0
|
||||
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
o_smem.at[slot], o_gmem.at[gmem_slice, pl.ds(step * 16, 16)]
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user