Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00.

Commit de7beb91a7 (parent 0519db15ab):
[pallas:mosaic_gpu] Added layout_cast

PiperOrigin-RevId: 686917796
@@ -13,6 +13,7 @@ Classes
    GPUBlockSpec
    GPUCompilerParams
    GPUMemorySpace
+   Layout
    SwizzleTransform
    TilingTransform
    TransposeTransform
@@ -28,6 +29,7 @@ Functions
    barrier_wait
    copy_gmem_to_smem
    copy_smem_to_gmem
+   layout_cast
    set_max_registers
    wait_smem_to_gmem
    wgmma
|
@@ -16,7 +16,8 @@
 from __future__ import annotations

-from typing import Literal
+import enum
+from typing import Any, Literal

 from jax._src import core as jax_core
 from jax._src import effects
@@ -541,6 +542,33 @@ def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc):
   return acc.value

class Layout(enum.Enum):
|
||||
#: [m, n] matrix, where m % 64 == 0 == n % 8.
|
||||
WGMMA = mgpu.WGMMA_LAYOUT
|
||||
#: [m] matrix, where m % 64 == 0.
|
||||
WGMMA_ROW = mgpu.WGMMA_ROW_LAYOUT
|
||||
|
||||
|
||||
layout_cast_p = jax_core.Primitive("layout_cast")
|
||||
|
||||
|
||||
@layout_cast_p.def_abstract_eval
|
||||
def _layout_cast_abstract_eval(x, new_layout):
|
||||
del new_layout # Unused.
|
||||
return x
|
||||
|
||||
|
||||
@lowering.register_lowering_rule(layout_cast_p)
|
||||
def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout):
|
||||
del ctx # Unused.
|
||||
return x.to_layout(new_layout.value)
|
||||
|
||||
|
||||
def layout_cast(x: Any, new_layout: Layout):
|
||||
"""Casts the layout of the given array."""
|
||||
return layout_cast_p.bind(x, new_layout=new_layout)
|
||||
|
||||
|
||||
set_max_registers_p = jax_core.Primitive("set_max_registers_p")
|
||||
set_max_registers_p.multiple_results = True
|
||||
|
||||
|
@@ -25,13 +25,15 @@ from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
 from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
-from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC  # noqa: F401
+from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
 from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
 from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
 from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
 from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
 from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
+from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
+from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast
 from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
 from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
 from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma
|
@@ -760,6 +760,17 @@ class PallasCallTest(PallasTest):
      rotate(x, expected)
      np.testing.assert_array_equal(f(x), expected)

||||
def test_layout_cast(self, shape=(256, 64)):
|
||||
@functools.partial(
|
||||
pl.pallas_call,
|
||||
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
|
||||
)
|
||||
def kernel(o_ref):
|
||||
o_ref[...] = plgpu.layout_cast(jnp.full(shape, 42.0), plgpu.Layout.WGMMA)
|
||||
|
||||
x = jnp.full(shape, 42.0)
|
||||
np.testing.assert_array_equal(kernel(), x)
|

class PipelineTest(PallasTest):

Loading…
x
Reference in New Issue
Block a user