[pallas:mosaic_gpu] Added layout_cast

PiperOrigin-RevId: 686917796
This commit is contained in:
Sergei Lebedev 2024-10-17 08:07:31 -07:00 committed by jax authors
parent 0519db15ab
commit de7beb91a7
4 changed files with 45 additions and 2 deletions

View File

@ -13,6 +13,7 @@ Classes
GPUBlockSpec
GPUCompilerParams
GPUMemorySpace
Layout
SwizzleTransform
TilingTransform
TransposeTransform
@ -28,6 +29,7 @@ Functions
barrier_wait
copy_gmem_to_smem
copy_smem_to_gmem
layout_cast
set_max_registers
wait_smem_to_gmem
wgmma

View File

@ -16,7 +16,8 @@
from __future__ import annotations
from typing import Literal
import enum
from typing import Any, Literal
from jax._src import core as jax_core
from jax._src import effects
@ -541,6 +542,33 @@ def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
return acc.value
class Layout(enum.Enum):
#: [m, n] matrix, where m % 64 == 0 == n % 8.
WGMMA = mgpu.WGMMA_LAYOUT
#: [m] matrix, where m % 64 == 0.
WGMMA_ROW = mgpu.WGMMA_ROW_LAYOUT
layout_cast_p = jax_core.Primitive("layout_cast")
@layout_cast_p.def_abstract_eval
def _layout_cast_abstract_eval(x, new_layout):
del new_layout # Unused.
return x
@lowering.register_lowering_rule(layout_cast_p)
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
del ctx # Unused.
return x.to_layout(new_layout.value)
def layout_cast(x: Any, new_layout: Layout):
"""Casts the layout of the given array."""
return layout_cast_p.bind(x, new_layout=new_layout)
set_max_registers_p = jax_core.Primitive("set_max_registers_p")
set_max_registers_p.multiple_results = True

View File

@ -25,13 +25,15 @@ from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma

View File

@ -760,6 +760,17 @@ class PallasCallTest(PallasTest):
rotate(x, expected)
np.testing.assert_array_equal(f(x), expected)
def test_layout_cast(self, shape=(256, 64)):
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
)
def kernel(o_ref):
o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0), plgpu.Layout.WGMMA)
x = jnp.full(shape, 42.0)
np.testing.assert_array_equal(kernel(), x)
class PipelineTest(PallasTest):