[pallas:mosaic_gpu] Added test for custom pretty-printing rules

PiperOrigin-RevId: 745145207
This commit is contained in:
Sergei Lebedev 2025-04-08 07:58:52 -07:00 committed by jax authors
parent b926fac66e
commit f5d73b89ca
2 changed files with 77 additions and 2 deletions

View File

@ -857,13 +857,14 @@ def _wgmma_ref_pp_eqn(
acc, a, b, *leaves = eqn.invars
a_transforms_treedef = eqn.params["a_transforms_tree"]
b_transforms_treedef = eqn.params["b_transforms_tree"]
split = getattr(a_transforms_treedef, "num_leaves", 0)
a_transforms = (
a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves])
a_transforms_treedef.unflatten(leaves[:split])
if a_transforms_treedef is not None
else []
)
b_transforms = (
b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :])
b_transforms_treedef.unflatten(leaves[split:])
if b_transforms_treedef is not None
else []
)

View File

@ -2634,6 +2634,80 @@ class CoreMapWGTest(
...
class PrettyPrintingTest(PallasTest):
  """Smoke tests: jaxprs traced from these kernels must convert to a string.

  Each test traces a kernel with ``jax.make_jaxpr`` and calls ``str`` on the
  result. The string itself is discarded; the tests only fail if tracing or
  stringification raises.
  """

  def test_load(self):
    """Stringifies the jaxpr of a kernel that uses ``plgpu.load``."""

    def kernel(x_ref, o_ref):
      for row in range(2):
        o_ref[row, ...] = plgpu.load(x_ref, (row,))

    pallas_kernel = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
    )

    x_aval = jax.ShapeDtypeStruct((2, 128), jnp.float32)
    jaxpr = jax.make_jaxpr(pallas_kernel)(x_aval)
    _ = str(jaxpr)

  def test_copy_primitives(self):
    """Stringifies the jaxpr of a kernel built on ``plgpu.emit_pipeline``."""
    num_steps = 4

    def kernel_body(_, x_smem, o_smem):
      o_smem[...] = x_smem[...] + 1.0

    def kernel(x_gmem, o_gmem):
      # ``plgpu.emit_pipeline`` is implemented in terms of async copy and
      # synchronization primitives.
      pipeline = plgpu.emit_pipeline(
          kernel_body,
          in_specs=[pl.BlockSpec((64, 64), lambda step: (0, step))],
          out_specs=[pl.BlockSpec((64, 64), lambda step: (0, step))],
          grid=(num_steps,),
          max_concurrent_steps=2,
      )
      pipeline(x_gmem, o_gmem)

    pallas_kernel = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32),
        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
    )

    x_aval = jax.ShapeDtypeStruct((64, 64), jnp.float32)
    jaxpr = jax.make_jaxpr(pallas_kernel)(x_aval)
    _ = str(jaxpr)

  def test_wgmma(self):
    """Stringifies the jaxpr of a kernel that calls ``plgpu.wgmma``."""
    # Tiling/swizzle transforms are only attached under Lane lowering
    # semantics; otherwise the operands carry no transforms.
    use_lane = self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane
    transforms = (
        (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
        if use_lane
        else ()
    )

    def kernel(a_ref, b_ref, o_ref):
      def accumulate(acc_ref):
        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
        return acc_ref[...]

      accumulator = plgpu.ACC((64, 192), jnp.float32)
      o_ref[...] = pl.run_scoped(accumulate, accumulator)

    pallas_kernel = self.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
        in_specs=[
            plgpu.GPUBlockSpec(transforms=transforms),
            plgpu.GPUBlockSpec(transforms=transforms),
        ],
    )

    a_aval = jax.ShapeDtypeStruct((64, 128), jnp.float16)
    b_aval = jax.ShapeDtypeStruct((128, 192), jnp.float16)
    jaxpr = jax.make_jaxpr(pallas_kernel)(a_aval, b_aval)
    _ = str(jaxpr)
class ExamplesTest(PallasTest):
# Basic