mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
This commit is contained in:
parent
1bef8b61af
commit
eeccc67c0b
@ -1491,11 +1491,7 @@ def _debug_print_lowering_rule(
|
||||
)
|
||||
elif len(ctx.avals_in) == 1:
|
||||
[arg] = args
|
||||
@arg.foreach
|
||||
def _(val, idx):
|
||||
idx_fmt = ", ".join(["{}"] * len(idx))
|
||||
fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}")
|
||||
mgpu.debug_print(fmt_str, *idx, val, uniform=False)
|
||||
arg.debug_print(fmt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"debug_print only supports printing of scalar values, or a single array"
|
||||
|
@ -1555,6 +1555,13 @@ class FragmentedArray:
|
||||
if create_array:
|
||||
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
|
||||
|
||||
def debug_print(self, fmt: str):
|
||||
idx_fmt = ", ".join(["{}"] * len(self.shape))
|
||||
@self.foreach
|
||||
def _(val, idx):
|
||||
fmt_str = fmt.format(f"[{idx_fmt}]: {{}}")
|
||||
utils.debug_print(fmt_str, *idx, val, uniform=False)
|
||||
|
||||
def store_untiled(self, ref: ir.Value, *, vector_store: bool = True):
|
||||
if not ir.MemRefType.isinstance(ref.type):
|
||||
raise ValueError(ref)
|
||||
|
@ -714,7 +714,7 @@ class PallasCallTest(PallasTest):
|
||||
shape = (128, 64)
|
||||
size = math.prod(shape)
|
||||
def kernel(x_ref, o_ref):
|
||||
pl.debug_print("{}", x_ref[...])
|
||||
pl.debug_print("prefix {}", x_ref[...])
|
||||
spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)))
|
||||
x = jnp.arange(size, dtype=jnp.float32).reshape(shape)
|
||||
f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec)
|
||||
@ -723,8 +723,8 @@ class PallasCallTest(PallasTest):
|
||||
jax.block_until_ready(f(x))
|
||||
|
||||
output = get_output()
|
||||
results = re.findall(r"\[(\d+), (\d+)\]/\[128, 64\]: (\d+)", output)
|
||||
self.assertLen(results, size)
|
||||
results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output)
|
||||
self.assertLen(results, size, output)
|
||||
for i, j, v in results:
|
||||
i, j, v = map(int, (i, j, v))
|
||||
self.assertEqual(v, i * shape[1] + j)
|
||||
@ -774,7 +774,7 @@ class PallasCallTest(PallasTest):
|
||||
with self.capture_stdout() as output:
|
||||
jax.block_until_ready(kernel(x))
|
||||
|
||||
self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output())
|
||||
self.assertIn("x: [1, 0, 43, 23]: 6871\n", output())
|
||||
|
||||
def test_load_scalar(self):
|
||||
@functools.partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user