Added pl.CompilerParams subclass for Mosaic GPU

PiperOrigin-RevId: 671066741
Sergei Lebedev 2024-09-04 12:47:57 -07:00 committed by jax authors
parent 3672b633c3
commit a8a55e0f2e
8 changed files with 44 additions and 23 deletions

View File: jax/_src/pallas/core.py

@@ -23,7 +23,7 @@ import enum
 import functools
 import itertools
 import threading
-from typing import Any, ClassVar, Hashable, Union
+from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable
 import warnings

 import jax
@@ -66,10 +66,14 @@ OriginStr = str  # The origin of a block spec, e.g. input[2]["field"]

 SEMAPHORE_INTERPRET_DTYPE = jnp.int16

-@dataclasses.dataclass(frozen=True)
-class CompilerParams:
+@runtime_checkable
+class CompilerParams(Protocol):
   """Base class for compiler parameters."""
-  PLATFORM: ClassVar[str] = "unspecified"
+  PLATFORM: ClassVar[str]
+
+  # Subclasses must be dataclasses.
+  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]

 @dataclasses.dataclass(frozen=True)
 class NameAndSrcInfo:
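The switch from a dataclass base to a @runtime_checkable Protocol lets pallas_call keep its isinstance check while each backend's parameter class remains an independent frozen dataclass. A minimal sketch of the resulting contract, using a hypothetical DemoParams subclass that is not part of this commit:

import dataclasses
from typing import ClassVar

from jax._src.pallas import core as pallas_core


# Hypothetical subclass, purely illustrative.
@dataclasses.dataclass(frozen=True)
class DemoParams(pallas_core.CompilerParams):
  PLATFORM: ClassVar[str] = "demo"
  unroll: bool = False


params = DemoParams(unroll=True)
# runtime_checkable makes isinstance verify that the Protocol's attributes
# (PLATFORM and __dataclass_fields__) are present on the instance.
assert isinstance(params, pallas_core.CompilerParams)
assert dataclasses.asdict(params) == {"unroll": True}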

View File: jax/_src/pallas/mosaic/core.py

@@ -19,14 +19,14 @@ from collections.abc import Sequence
 import dataclasses
 import enum
 import functools
-from typing import Any, ClassVar, Hashable
+from typing import Any, ClassVar, Hashable, Literal

 import jax
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import util
-import jax.numpy as jnp
 from jax._src.pallas import core as pallas_core
+import jax.numpy as jnp
 import numpy as np

 map, unsafe_map = util.safe_map, map
@@ -68,7 +68,7 @@ class TPUCompilerParams(pallas_core.CompilerParams):
     device_type: The device type to compile for.
   """
   PLATFORM: ClassVar[str] = "mosaic"
-  dimension_semantics: Sequence[str] | None = None
+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None
   allow_input_fusion: Sequence[bool] | None = None
   vmem_limit_bytes: int | None = None
   collective_id: int | None = None
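With the Literal annotation, a type checker now rejects any entry other than "parallel" or "arbitrary". A small usage sketch (values are illustrative):

from jax.experimental.pallas import tpu as pltpu

# "parallel" dimensions may be executed in any order; "arbitrary" ones
# must run in sequence. Any other string is now a type error.
params = pltpu.TPUCompilerParams(
    dimension_semantics=("parallel", "arbitrary"),
)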

View File: jax/_src/pallas/mosaic_gpu/core.py

@@ -14,8 +14,10 @@

 """Contains GPU-specific Pallas abstractions."""

+from collections.abc import Sequence
 import dataclasses
 import enum
+from typing import ClassVar, Literal

 from jax import core as jax_core
 from jax._src.pallas import core as pallas_core
 import jax.numpy as jnp
@@ -23,6 +25,23 @@ import jax.numpy as jnp

 AbstractMemoryRef = pallas_core.AbstractMemoryRef

+@dataclasses.dataclass(frozen=True)
+class GPUCompilerParams(pallas_core.CompilerParams):
+  """Mosaic GPU compiler parameters.
+
+  Attributes:
+    dimension_semantics: A list of dimension semantics for each grid
+      dimension of the kernel. Either "parallel" for dimensions that can
+      execute in any order, or "sequential" for dimensions that must be
+      executed sequentially.
+    num_stages: The number of pipeline stages in the kernel. Defaults to 1,
+      meaning no pipelining is done.
+  """
+  PLATFORM: ClassVar[str] = "mosaic_gpu"
+  dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None
+  num_stages: int = 1
+
+
 class GPUMemorySpace(enum.Enum):
   GMEM = "gmem"
   SMEM = "smem"
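A hedged end-to-end sketch of the new dataclass in use, mirroring the updated test at the bottom of this commit; block shapes and grid are illustrative, and num_stages=2 is an arbitrary choice:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
    out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
    compiler_params=plgpu.GPUCompilerParams(
        dimension_semantics=["parallel", "sequential"],
        num_stages=2,  # pipeline two stages over the sequential dimension
    ),
    grid=(2, 1),
)
def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0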

View File: jax/_src/pallas/mosaic_gpu/lowering.py

@@ -20,7 +20,7 @@ from collections.abc import Sequence
 import dataclasses
 import functools
 import math
-from typing import Any, Literal, TypedDict, cast
+from typing import Any, cast

 import jax
 from jax._src import core as jax_core
@@ -152,11 +152,6 @@ def _eval_index_map(
   return tuple(result)

-class Params(TypedDict, total=False):
-  num_stages: int
-  dimension_semantics: Sequence[Literal["sequential", "parallel"]]
-
-
 def lower_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
     jaxpr: jax_core.Jaxpr,
@@ -199,7 +194,7 @@ def lower_jaxpr_to_module(
     grid += (1,) * (3 - len(grid))
   block = (128,) + (1,) * (len(grid) - 1)

-  params = Params(**compiler_params.get("mosaic_gpu", {}))
+  params = compiler_params.get("mosaic_gpu", {})
   num_stages = params.get("num_stages", 1)
   dimension_semantics = params.get(
       "dimension_semantics", ["parallel"] * len(grid_mapping.grid)

View File: jax/_src/pallas/pallas_call.py

@@ -1291,7 +1291,7 @@ def pallas_call(
   if compiler_params is None:
     compiler_params = {}
   if isinstance(compiler_params, pallas_core.CompilerParams):
-    if compiler_params.PLATFORM not in ["mosaic", "triton"]:
+    if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]:
       raise ValueError(
           f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
     compiler_params = {
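The truncated assignment above converts the dataclass into the dict form the lowerings consume. A plausible reconstruction of that normalization, an assumption based on this hunk rather than the verbatim source:

import dataclasses

from jax.experimental.pallas import mosaic_gpu as plgpu

# Assumed shape of the conversion: key the instance's field dict by its
# PLATFORM class variable, matching the lookup in lowering.py above.
params = plgpu.GPUCompilerParams(num_stages=2)
compiler_params = {params.PLATFORM: dataclasses.asdict(params)}
# -> {"mosaic_gpu": {"dimension_semantics": None, "num_stages": 2}}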

View File: jax/experimental/pallas/__init__.py

@@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.

 from jax._src.deprecations import register as _register_deprecation
 from jax._src.pallas.core import Blocked
 from jax._src.pallas.core import BlockSpec
+from jax._src.pallas.core import CompilerParams
 from jax._src.pallas.core import CostEstimate
 from jax._src.pallas.core import IndexingMode
 from jax._src.pallas.core import no_block_spec
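With CompilerParams re-exported, user code can reference the base class without reaching into jax._src. A quick illustrative check (constructing GPUCompilerParams bare works because every field has a default):

from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# The backend-specific dataclass satisfies the public Protocol.
assert isinstance(plgpu.GPUCompilerParams(), pl.CompilerParams)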

View File: jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

@@ -14,7 +14,9 @@

 """PagedAttention TPU kernel."""

 from collections.abc import Sequence
 import functools
+from typing import Literal
+
 import jax
 from jax import lax
@@ -516,6 +518,7 @@ def paged_attention(
   )
   q_dtype_for_kernel_launch = q.dtype

+  dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
   if inline_seq_dim:
     kernel = paged_flash_attention_kernel_inline_seq_dim
     grid = (
@@ -525,7 +528,7 @@ def paged_attention(
         if megacore_mode == "kv_head"
         else num_kv_heads,
     )
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary")
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary")
   else:
     kernel = paged_flash_attention_kernel
     grid = (
@@ -536,7 +539,7 @@ def paged_attention(
         else num_kv_heads,
         pages_per_sequence // pages_per_compute_block,
     )  # type: ignore
-    dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary")  # type: ignore
+    dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")

   if k_scales_pages is not None and v_scales_pages is not None:
     in_specs = [
@@ -641,7 +644,7 @@ def paged_attention(
           scratch_shapes=scratch_shapes,
       ),
       compiler_params=pltpu.TPUCompilerParams(
-          dimension_semantics=dimension_sematics),
+          dimension_semantics=dimension_semantics),
       out_shape=[
           jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
           jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),

View File: tests/pallas/mosaic_gpu_test.py

@@ -80,16 +80,15 @@ class PallasCallTest(PallasTest):

   @parameterized.product(num_stages=[1, 2, 3])
   def test_add_one_grid_pipelined(self, num_stages):

     @functools.partial(
         pl.pallas_call,
         in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))],
         out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)),
         out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32),
-        compiler_params=dict(
-            mosaic_gpu=dict(
-                dimension_semantics=["parallel", "sequential"],
-                num_stages=num_stages,
-            ),
+        compiler_params=plgpu.GPUCompilerParams(
+            dimension_semantics=["parallel", "sequential"],
+            num_stages=num_stages,
         ),
         grid=(2, 1),
     )