[mgpu] Debug print arrays.

PiperOrigin-RevId: 734576543
This commit is contained in:
Christos Perivolaropoulos 2025-03-07 08:57:46 -08:00 committed by jax authors
parent 1bef8b61af
commit eeccc67c0b
3 changed files with 12 additions and 9 deletions

View File

@ -1491,11 +1491,7 @@ def _debug_print_lowering_rule(
)
elif len(ctx.avals_in) == 1:
[arg] = args
@arg.foreach
def _(val, idx):
idx_fmt = ", ".join(["{}"] * len(idx))
fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}")
mgpu.debug_print(fmt_str, *idx, val, uniform=False)
arg.debug_print(fmt)
else:
raise NotImplementedError(
"debug_print only supports printing of scalar values, or a single array"

View File

@ -1555,6 +1555,13 @@ class FragmentedArray:
if create_array:
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
def debug_print(self, fmt: str):
idx_fmt = ", ".join(["{}"] * len(self.shape))
@self.foreach
def _(val, idx):
fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
utils.debug_print(fmt_str, *idx, val, uniform=False)
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
if not ir.MemRefType.isinstance(ref.type):
raise ValueError(ref)

View File

@ -714,7 +714,7 @@ class PallasCallTest(PallasTest):
shape = (128, 64)
size = math.prod(shape)
def kernel(x_ref, o_ref):
pl.debug_print("{}", x_ref[...])
pl.debug_print("prefix {}", x_ref[...])
spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)))
x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)
@ -723,8 +723,8 @@ class PallasCallTest(PallasTest):
jax.block_until_ready(f(x))
output = get_output()
results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
self.assertLen(results, size)
results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output)
self.assertLen(results, size, output)
for i, j, v in results:
i, j, v = map(int, (i, j, v))
self.assertEqual(v, i * shape[1] + j)
@ -774,7 +774,7 @@ class PallasCallTest(PallasTest):
with self.capture_stdout() as output:
jax.block_until_ready(kernel(x))
self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
self.assertIn("x: [1, 0, 43, 23]: 6871\n", output())
def test_load_scalar(self):
@functools.partial(