[pallas:mosaic_gpu] Added test for custom pretty-printing rules
PiperOrigin-RevId: 745145207
parent b926fac66e
commit f5d73b89ca
@@ -857,13 +857,14 @@ def _wgmma_ref_pp_eqn(
   acc, a, b, *leaves = eqn.invars
   a_transforms_treedef = eqn.params["a_transforms_tree"]
   b_transforms_treedef = eqn.params["b_transforms_tree"]
+  split = getattr(a_transforms_treedef, "num_leaves", 0)
   a_transforms = (
-      a_transforms_treedef.unflatten(leaves[: a_transforms_treedef.num_leaves])
+      a_transforms_treedef.unflatten(leaves[:split])
       if a_transforms_treedef is not None
       else []
   )
   b_transforms = (
-      b_transforms_treedef.unflatten(leaves[a_transforms_treedef.num_leaves :])
+      b_transforms_treedef.unflatten(leaves[split:])
       if b_transforms_treedef is not None
       else []
   )
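This hunk fixes an edge case in the pretty-printing rule: the old `b_transforms` slice read `a_transforms_treedef.num_leaves` unconditionally, so printing an equation where `a` had no transforms (`a_transforms_tree` is None) but `b` did raised an AttributeError. The new `getattr(..., "num_leaves", 0)` falls back to 0 in that case. A minimal standalone sketch of the repaired pattern, with hypothetical stand-ins for the treedefs and leaves (the real ones are transform pytrees taken from `eqn.params`):

import jax

# Hypothetical setup: `a` carries no transforms (its treedef is None),
# while `b` carries two leaves.
a_transforms_treedef = None
b_transforms_treedef = jax.tree_util.tree_structure({"tiling": 0, "swizzle": 0})
leaves = ["t", "s"]  # the flattened leaves, all belonging to `b`

# Previously the `b` slice dereferenced `a_transforms_treedef.num_leaves`
# and crashed on None; `getattr` now yields a split of 0 for `a` instead.
split = getattr(a_transforms_treedef, "num_leaves", 0)
a_transforms = (
    a_transforms_treedef.unflatten(leaves[:split])
    if a_transforms_treedef is not None
    else []
)
b_transforms = (
    b_transforms_treedef.unflatten(leaves[split:])
    if b_transforms_treedef is not None
    else []
)
assert a_transforms == []
assert b_transforms == {"swizzle": "t", "tiling": "s"}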
@@ -2634,6 +2634,80 @@ class CoreMapWGTest(
   ...
 
 
+class PrettyPrintingTest(PallasTest):
+
+  def test_load(self):
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32),
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.SMEM),
+    )
+    def kernel(x_ref, o_ref):
+      for i in range(2):
+        x = plgpu.load(x_ref, (i,))
+        o_ref[i, ...] = x
+
+    _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32)))
+
+  def test_copy_primitives(self):
+    num_steps = 4
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32),
+        in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)],
+        out_specs=pl.BlockSpec(memory_space=plgpu.GMEM),
+    )
+    def kernel(x_gmem, o_gmem):
+      # ``plgpu.emit_pipeline`` is implemented in terms of async copy and
+      # synchronization primitives.
+      plgpu.emit_pipeline(
+          kernel_body,
+          in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))],
+          out_specs=[
+              pl.BlockSpec(
+                  (64, 64),
+                  lambda i: (0, i),
+              )
+          ],
+          grid=(num_steps,),
+          max_concurrent_steps=2,
+      )(x_gmem, o_gmem)
+
+    def kernel_body(_, x_smem, o_smem):
+      o_smem[...] = x_smem[...] + 1.0
+
+    _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32)))
+
+  def test_wgmma(self):
+    transforms = ()
+    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
+      transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
+        in_specs=[
+            plgpu.GPUBlockSpec(transforms=transforms),
+            plgpu.GPUBlockSpec(transforms=transforms),
+        ],
+    )
+    def kernel(a_ref, b_ref, o_ref):
+      def scope(acc_ref):
+        plgpu.wgmma(acc_ref, a_ref[...], b_ref)
+        return acc_ref[...]
+
+      o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
+
+    _ = str(
+        jax.make_jaxpr(kernel)(
+            jax.ShapeDtypeStruct((64, 128), jnp.float16),
+            jax.ShapeDtypeStruct((128, 192), jnp.float16),
+        )
+    )
+
+
 class ExamplesTest(PallasTest):
 
   # Basic
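All three tests exercise the pretty-printer the same way: trace the kernel with `jax.make_jaxpr` on abstract `jax.ShapeDtypeStruct` inputs and call `str(...)` on the resulting jaxpr, which runs the custom pp rules without launching anything on a GPU. A device-free sketch of that pattern on a plain JAX function (not a Pallas kernel, so it runs anywhere):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) + 1.0

# make_jaxpr traces with abstract values, so a ShapeDtypeStruct is enough
# and nothing executes on a device; str() invokes the pretty-printer.
jaxpr = jax.make_jaxpr(f)(jax.ShapeDtypeStruct((2, 128), jnp.float32))
print(str(jaxpr))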