mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Updated remaining `mgpu.once
usages to use
mgpu.single_thread
`
PiperOrigin-RevId: 640534664
This commit is contained in:
parent
621814bd7d
commit
69dd14e58a
@ -151,7 +151,7 @@ def lower_jaxpr_to_module(
|
||||
# TODO(slebedev): Consider enforcing this in the mgpu.BarrierArray.
|
||||
[barrier] = mgpu.BarrierArray(1, arrival_count=1)
|
||||
|
||||
with mgpu.once():
|
||||
with mgpu.single_thread():
|
||||
nvgpu_dialect.mbarrier_arrive_expect_tx(
|
||||
barrier.barrier_array.value,
|
||||
_index(
|
||||
|
@ -389,7 +389,7 @@ class LaunchContext:
|
||||
# nvgpu TMA instructions expect reversed indices...
|
||||
rev_dyn_based_indices = reversed(dyn_base_indices)
|
||||
|
||||
uniform_ctx = mgpu.once if uniform else contextlib.nullcontext
|
||||
uniform_ctx = mgpu.single_thread if uniform else contextlib.nullcontext
|
||||
|
||||
if gmem_ref is src_ref:
|
||||
with uniform_ctx():
|
||||
|
@ -424,7 +424,7 @@ class FragmentedArray:
|
||||
memref.store(warp_result, scratch, [warp_id])
|
||||
utils.commit_shared()
|
||||
zero_index = c(0, index)
|
||||
with mgpu.once():
|
||||
with mgpu.single_thread():
|
||||
scratch_vec = vector.load(
|
||||
ir.VectorType.get((4,), self.mlir_dtype),
|
||||
scratch,
|
||||
|
Loading…
x
Reference in New Issue
Block a user