Added a new Pallas Triton primitive -- `plgpu.debug_barrier`

Closes #23400.

PiperOrigin-RevId: 670636723
This commit is contained in:
Sergei Lebedev 2024-09-03 11:26:22 -07:00 committed by jax authors
parent ff702cb249
commit 9030aec097
3 changed files with 37 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from collections.abc import Sequence
import jax
from jax import core as jax_core
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas.triton import lowering
from jax.interpreters import mlir
@ -120,3 +121,22 @@ def _elementwise_inline_asm_lowering(
packed_element=pack,
args=args,
).result
def debug_barrier() -> None:
  """Inserts a synchronization barrier into the kernel.

  Binds ``debug_barrier_p``; the Triton lowering registered for that
  primitive in this file emits an MLIR ``gpu.barrier`` op.

  NOTE(review): ``gpu.barrier`` is documented as synchronizing the work items
  of a single workgroup (thread block), not every program instance in the
  grid -- confirm the intended scope before relying on grid-wide ordering.

  Returns:
    None. The primitive produces no values.
  """
  # The primitive is declared with `multiple_results = True` and has no
  # outputs, so `bind` hands back an empty sequence rather than `None`.
  # Discard it so the function honours its `-> None` annotation.
  debug_barrier_p.bind()
# Primitive bound by `debug_barrier`. It takes no operands; the abstract-eval
# and lowering rules registered for it below both return empty sequences.
debug_barrier_p = jax_core.Primitive("debug_barrier_p")
# Declared multi-result so that "no outputs" can be expressed as an empty
# sequence of results.
debug_barrier_p.multiple_results = True
@debug_barrier_p.def_abstract_eval
def _debug_barrier_abstract_eval() -> Sequence[jax_core.ShapedArray]:
  """Abstract evaluation rule for ``debug_barrier_p``.

  The primitive takes no operands and yields no values, so the sequence of
  output avals is empty.
  """
  return tuple()
@lowering.register_lowering(debug_barrier_p)
def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext):
  """Lowering rule for ``debug_barrier_p``.

  Emits a single barrier op from the MLIR ``gpu`` dialect at the current
  insertion point.

  Args:
    ctx: The lowering rule context; not consulted by this rule.

  Returns:
    An empty list, because the primitive has no results.
  """
  del ctx  # Nothing from the context is needed to emit the barrier.
  gpu_dialect.barrier()
  no_results = []
  return no_results

View File

@ -15,4 +15,5 @@
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm

View File

@ -972,6 +972,22 @@ class OpsExtraTest(PallasBaseTest):
x = jnp.arange(256).astype(jnp.float16)
np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3)
def test_debug_barrier(self):
  """Checks that a kernel calling ``plgpu.debug_barrier()`` runs correctly.

  The kernel copies its input into its output and then issues the barrier;
  the test passes when the output equals the input. Skipped when the test
  class runs in interpret mode.
  """
  if self.INTERPRET:
    self.skipTest("debug_barrier is not supported in interpret mode")

  def kernel(x_ref, o_ref):
    # Copy first, then issue the barrier as the final kernel statement.
    o_ref[...] = x_ref[...]
    plgpu.debug_barrier()

  run = self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
      grid=1,
  )

  x = jnp.array([4.2, 2.4]).astype(jnp.float32)
  np.testing.assert_array_equal(run(x), x)
def test_debug_print(self):
# TODO: this test flakes on gpu
if jtu.test_device_matches(["gpu"]):