mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Added a new Pallas Triton primitive -- `plgpu.debug_barrier`
Closes #23400. PiperOrigin-RevId: 670636723
This commit is contained in:
parent
ff702cb249
commit
9030aec097
@ -20,6 +20,7 @@ from collections.abc import Sequence
|
||||
|
||||
import jax
|
||||
from jax import core as jax_core
|
||||
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
|
||||
from jax._src.lib.triton import dialect as tt_dialect
|
||||
from jax._src.pallas.triton import lowering
|
||||
from jax.interpreters import mlir
|
||||
@ -120,3 +121,22 @@ def _elementwise_inline_asm_lowering(
|
||||
packed_element=pack,
|
||||
args=args,
|
||||
).result
|
||||
|
||||
|
||||
def debug_barrier() -> None:
  """Synchronizes all kernel executions in the grid.

  Binds ``debug_barrier_p``, a primitive that takes no operands and whose
  abstract evaluation rule reports no outputs; the call is made purely for
  its synchronization effect.

  NOTE(review): the Triton lowering registered for this primitive only calls
  ``gpu_dialect.barrier()``. Confirm that "grid" is the intended scope of the
  synchronization, as opposed to the threads of a single block/workgroup.

  Returns:
    None.
  """
  # ``debug_barrier_p`` is a multiple-results primitive, so ``bind`` hands
  # back a sequence of outputs rather than ``None``. Discard it so the
  # function actually honours its ``-> None`` annotation.
  debug_barrier_p.bind()
|
||||
|
||||
|
||||
def _debug_barrier_abstract_eval() -> Sequence[jax_core.ShapedArray]:
  """Abstract evaluation rule for ``debug_barrier_p``.

  Takes no operands and returns an empty tuple, i.e. the primitive has no
  outputs.
  """
  return ()


# Primitive backing ``debug_barrier``. It is flagged as multi-result and its
# abstract evaluation rule (registered on the last line) returns an empty
# tuple of avals.
debug_barrier_p = jax_core.Primitive("debug_barrier_p")
debug_barrier_p.multiple_results = True
debug_barrier_p.def_abstract_eval(_debug_barrier_abstract_eval)
|
||||
|
||||
@lowering.register_lowering(debug_barrier_p)
def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext):
  """Triton lowering rule for ``debug_barrier_p``.

  Calls ``gpu_dialect.barrier()`` from the MLIR GPU dialect bindings and
  returns an empty list, since the primitive has no outputs.

  Args:
    ctx: The lowering rule context. Unused.

  Returns:
    An empty list.
  """
  del ctx  # The barrier needs nothing from the lowering context.
  no_results = []
  gpu_dialect.barrier()
  return no_results
|
||||
|
@ -15,4 +15,5 @@
|
||||
"""Triton-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import debug_barrier
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
||||
|
@ -972,6 +972,22 @@ class OpsExtraTest(PallasBaseTest):
|
||||
x = jnp.arange(256).astype(jnp.float16)
|
||||
np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3)
|
||||
|
||||
def test_debug_barrier(self):
  """Checks that a kernel calling ``plgpu.debug_barrier`` still copies its input."""
  if self.INTERPRET:
    self.skipTest("debug_barrier is not supported in interpret mode")

  def kernel(x_ref, o_ref):
    # Copy the input to the output, then hit the barrier.
    o_ref[...] = x_ref[...]
    plgpu.debug_barrier()

  # Wrap the kernel directly instead of through a partial-application
  # decorator; the arguments passed to ``pallas_call`` are the same.
  barrier_copy = self.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
      grid=1,
  )

  x = jnp.array([4.2, 2.4]).astype(jnp.float32)
  np.testing.assert_array_equal(barrier_copy(x), x)
|
||||
|
||||
def test_debug_print(self):
|
||||
# TODO: this test flakes on gpu
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
|
Loading…
x
Reference in New Issue
Block a user