Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Added pl.CompilerParams subclass for Mosaic GPU
PiperOrigin-RevId: 671066741
This commit is contained in:
  parent 3672b633c3
  commit a8a55e0f2e
@@ -23,7 +23,7 @@ import enum
 import functools
 import itertools
 import threading
-from typing import Any, ClassVar, Hashable, Union
+from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
 import warnings

 import jax
@@ -66,10 +66,14 @@ OriginStr = str  # The origin of a block spec, e.g. input[2]["field"]
 SEMAPHORE_INTERPRET_DTYPE = jnp.int16


-@dataclasses.dataclass(frozen=True)
-class CompilerParams:
+@runtime_checkable
+class CompilerParams(Protocol):
   """Base class for compiler parameters."""
-  PLATFORM: ClassVar[str] = "unspecified"
+  PLATFORM: ClassVar[str]
+
+  # Subclasses must be dataclasses.
+  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


 @dataclasses.dataclass(frozen=True)
 class NameAndSrcInfo:
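Note: the hunk above swaps a concrete base dataclass for a runtime-checkable structural Protocol. Declaring __dataclass_fields__ as a member means any class decorated with @dataclasses.dataclass satisfies the interface without inheriting from it. A minimal self-contained sketch of the mechanism; FakeParams is a hypothetical illustration, not part of this commit:

import dataclasses
from typing import Any, ClassVar, Protocol, runtime_checkable

@runtime_checkable
class CompilerParams(Protocol):
  PLATFORM: ClassVar[str]
  # Every @dataclasses.dataclass-decorated class defines this attribute.
  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]

@dataclasses.dataclass(frozen=True)
class FakeParams:  # hypothetical: does not subclass CompilerParams
  PLATFORM: ClassVar[str] = "fake"
  num_stages: int = 1

assert isinstance(FakeParams(), CompilerParams)  # structural match, no subclassing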
@@ -19,14 +19,14 @@ from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Hashable
+from typing import Any, ClassVar, Hashable, Literal

 import jax
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import util
-import jax.numpy as jnp
 from jax._src.pallas import core as pallas_core
+import jax.numpy as jnp
 import numpy as np

 map, unsafe_map = util.safe_map, map
@@ -68,7 +68,7 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: Sequence[str] | None = None
+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None
   allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
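Narrowing dimension_semantics from Sequence[str] to a Literal union moves validation to type-check time: a misspelled entry becomes a static error instead of a silently accepted string. A small illustration (the alias name here is illustrative only):

from collections.abc import Sequence
from typing import Literal

DimSemantics = Sequence[Literal["parallel", "arbitrary"]] | None

ok: DimSemantics = ("parallel", "arbitrary")
bad: DimSemantics = ("paralell", "arbitrary")  # a type checker now flags this typo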
@@ -14,8 +14,10 @@

 """Contains GPU-specific Pallas abstractions."""

+from collections.abc import Sequence
 import dataclasses
 import enum
+from typing import ClassVar, Literal
 from jax import core as jax_core
 from jax._src.pallas import core as pallas_core
 import jax.numpy as jnp
@@ -23,6 +25,23 @@ import jax.numpy as jnp
 AbstractMemoryRef = pallas_core.AbstractMemoryRef


+@dataclasses.dataclass(frozen=True)
+class GPUCompilerParams(pallas_core.CompilerParams):
+  """Mosaic GPU compiler parameters.
+
+  Attributes:
+    dimension_semantics: A list of dimension semantics for each grid
+      dimension of the kernel. Either "parallel" for dimensions that can
+      execute in any order, or "sequential" for dimensions that must be
+      executed sequentially.
+    num_stages: The number of pipeline stages in the kernel. Defaults to 1,
+      meaning no pipelining is done.
+  """
+  PLATFORM: ClassVar[str] = "mosaic_gpu"
+  dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
+  num_stages: int = 1
+
+
 class GPUMemorySpace(enum.Enum):
   GMEM = "gmem"
   SMEM = "smem"
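Because GPUCompilerParams is a frozen dataclass with PLATFORM = "mosaic_gpu", an instance satisfies the CompilerParams protocol above and flattens to the same payload as the old dict(mosaic_gpu=dict(...)) style. A usage sketch, assuming the plgpu alias the test at the bottom of this commit uses for the Mosaic GPU Pallas module:

import dataclasses
from jax.experimental.pallas import mosaic_gpu as plgpu  # alias assumed from the test below

params = plgpu.GPUCompilerParams(
    dimension_semantics=("parallel", "sequential"),
    num_stages=2,
)
assert params.PLATFORM == "mosaic_gpu"
# PLATFORM is a ClassVar, not a field, so it is excluded from the payload:
print(dataclasses.asdict(params))
# {'dimension_semantics': ('parallel', 'sequential'), 'num_stages': 2}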
@@ -20,7 +20,7 @@ from collections.abc import Sequence
 import dataclasses
 import functools
 import math
-from typing import Any, Literal, TypedDict, cast
+from typing import Any, cast

 import jax
 from jax._src import core as jax_core
@@ -152,11 +152,6 @@ def _eval_index_map(
   return tuple(result)


-class Params(TypedDict, total=False):
-  num_stages: int
-  dimension_semantics: Sequence[Literal["sequential", "parallel"]]
-
-
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -199,7 +194,7 @@ def lower_jaxpr_to_module(
   grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)

-  params = Params(**compiler_params.get("mosaic_gpu", {}))
+  params = compiler_params.get("mosaic_gpu", {})
   num_stages = params.get("num_stages", 1)
   dimension_semantics = params.get(
       "dimension_semantics", ["parallel"] * len(grid_mapping.grid)
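With the Params TypedDict removed, the lowering reads the "mosaic_gpu" entry as a plain dict and applies defaults via .get. Isolated from the lowering, the defaulting behaves like this (a standalone sketch, not the JAX source; grid stands in for grid_mapping.grid):

# Standalone sketch of the defaulting logic above.
grid = (2, 1)
compiler_params = {"mosaic_gpu": {"num_stages": 3}}

params = compiler_params.get("mosaic_gpu", {})
num_stages = params.get("num_stages", 1)  # default 1: no pipelining
dimension_semantics = params.get(
    "dimension_semantics", ["parallel"] * len(grid)  # default: all-parallel grid
)
print(num_stages, dimension_semantics)  # -> 3 ['parallel', 'parallel']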
@@ -1291,7 +1291,7 @@ def pallas_call(
   if compiler_params is None:
     compiler_params = {}
   if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "triton"]:
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
      raise ValueError(
          f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
    compiler_params = {
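The whitelist in pallas_call now recognizes "mosaic_gpu", so a GPUCompilerParams instance passes both the structural isinstance check and the platform gate; anything else still raises. A condensed sketch of the gate with a hypothetical stand-in class (the dict re-keying that follows is truncated in the hunk above, so it is omitted here):

import dataclasses
from typing import ClassVar

@dataclasses.dataclass(frozen=True)
class OtherParams:  # hypothetical params for an unsupported backend
  PLATFORM: ClassVar[str] = "opencl"

cp = OtherParams()
if cp.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
  # Raises here: "opencl" is not a supported platform.
  raise ValueError(f"Unknown platform in compiler params: {cp.PLATFORM}")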
@@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.
 from jax._src.deprecations import register as _register_deprecation
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import CompilerParams
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import IndexingMode
 from jax._src.pallas.core import no_block_spec
@@ -14,7 +14,9 @@

 """PagedAttention TPU kernel."""

+from collections.abc import Sequence
 import functools
+from typing import Literal

 import jax
 from jax import lax
@@ -516,6 +518,7 @@ def paged_attention(
   )
   q_dtype_for_kernel_launch = q.dtype

+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
     grid = (
@@ -525,7 +528,7 @@ def paged_attention(
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
@@ -536,7 +539,7 @@ def paged_attention(
         else num_kv_heads,
         pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")

   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
@@ -641,7 +644,7 @@ def paged_attention(
           scratch_shapes=scratch_shapes,
       ),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=dimension_sematics),
+          dimension_semantics=dimension_semantics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
@@ -80,16 +80,15 @@ class PallasCallTest(PallasTest):
   @parameterized.product(num_stages=[1, 2, 3])
   def test_add_one_grid_pipelined(self, num_stages):

     @functools.partial(
         pl.pallas_call,
         in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
         out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
         out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
-        compiler_params=dict(
-            mosaic_gpu=dict(
-                dimension_semantics=["parallel", "sequential"],
-                num_stages=num_stages,
-            ),
-        ),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            num_stages=num_stages,
+        ),
         grid=(2, 1),
     )