Updated remaining `mgpu.once usages to use mgpu.single_thread`

PiperOrigin-RevId: 640534664
This commit is contained in:
Sergei Lebedev 2024-06-05 08:30:21 -07:00 committed by jax authors
parent 621814bd7d
commit 69dd14e58a
3 changed files with 3 additions and 3 deletions

View File

@ -151,7 +151,7 @@ def lower_jaxpr_to_module(
# TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray.
[barrier] = mgpu.BarrierArray(1, arrival_count=1)
with mgpu.once():
with mgpu.single_thread():
nvgpu_dialect.mbarrier_arrive_expect_tx(
barrier.barrier_array.value,
_index(

View File

@ -389,7 +389,7 @@ class LaunchContext:
# nvgpu TMA instructions expect reversed indices...
rev_dyn_based_indices = reversed(dyn_base_indices)
uniform_ctx = mgpu.once if uniform else contextlib.nullcontext
uniform_ctx = mgpu.single_thread if uniform else contextlib.nullcontext
if gmem_ref is src_ref:
with uniform_ctx():

View File

@ -424,7 +424,7 @@ class FragmentedArray:
memref.store(warp_result, scratch, [warp_id])
utils.commit_shared()
zero_index = c(0, index)
with mgpu.once():
with mgpu.single_thread():
scratch_vec = vector.load(
ir.VectorType.get((4,), self.mlir_dtype),
scratch,