[Pallas] Add explicit TPU compiler params interface with docstrings.

PiperOrigin-RevId: 665557475
This commit is contained in:
Justin Fu 2024-08-20 15:38:03 -07:00 committed by jax authors
parent 9b3c19c5df
commit 2bd8c3f691
6 changed files with 62 additions and 20 deletions

View File

@ -23,7 +23,7 @@ import enum
import functools
import itertools
import threading
from typing import Any, Hashable, Union
from typing import Any, ClassVar, Hashable, Union
import warnings
import jax
@ -60,6 +60,11 @@ GridMappingGrid = tuple[int | DynamicGridDim, ...]
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
@dataclasses.dataclass(frozen=True)
class CompilerParams:
  """Base class for compiler parameters.

  Subclasses describe the parameters of one specific backend and override
  ``PLATFORM`` with that backend's name.
  """
  # Class-level tag, not a dataclass field: ClassVar keeps it out of the
  # generated __init__ and out of dataclasses.asdict().
  PLATFORM: ClassVar[str] = "unspecified"
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.

View File

@ -19,7 +19,7 @@ from collections.abc import Sequence
import dataclasses
import enum
import functools
from typing import Any, Hashable
from typing import Any, ClassVar, Hashable
import jax
from jax._src import core as jax_core
@ -44,6 +44,38 @@ no_block_spec = pallas_core.no_block_spec
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
split_list = util.split_list
@dataclasses.dataclass(frozen=True)
class TPUCompilerParams(pallas_core.CompilerParams):
  """Mosaic TPU compiler parameters.

  Attributes:
    dimension_semantics: A list of dimension semantics for each grid
      dimension of the kernel. Either "parallel" for dimensions that can
      execute in any order, or "arbitrary" for dimensions that must be
      executed sequentially.
    allow_input_fusion: A list of booleans indicating whether input fusion is
      allowed for each argument.
    vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note
      that this must be used in conjunction with the
      --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes.
    collective_id: Indicates which barrier semaphore to use for the kernel.
      Note that using the same collective_id does not guarantee that
      the same barrier semaphore will be allocated between kernels.
    flags: A dictionary of command line flags for the kernel.
    internal_scratch_in_bytes: The size of the internal scratch space used by
      Mosaic.
    serialization_format: The serialization format for the kernel body.
    device_type: The device type to compile for.
  """
  # Class-level platform tag (not a dataclass field); overrides the base
  # class value.
  PLATFORM: ClassVar[str] = "mosaic"

  # Field order is part of the positional-constructor interface; keep it.
  dimension_semantics: list[str] | None = None
  allow_input_fusion: list[bool] | None = None
  vmem_limit_bytes: int | None = None
  collective_id: int | None = None
  flags: dict[str, Any] | None = None
  internal_scratch_in_bytes: int | None = None
  serialization_format: int = 1
  device_type: str | None = None
class TPUMemorySpace(enum.Enum):
ANY = "any"

View File

@ -80,16 +80,7 @@ def pallas_call_tpu_lowering_rule(
if debug:
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
print(jaxpr)
if "mosaic_params" in compiler_params:
# TODO(slebedev): Remove this branch after July 12th 2024.
warnings.warn(
"Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)"
" is deprecated. Use compiler_params=dict(mosaic=...) instead.",
DeprecationWarning,
)
assert "mosaic" not in compiler_params
mosaic_params = compiler_params["mosaic_params"]
elif "mosaic" in compiler_params:
if "mosaic" in compiler_params:
mosaic_params = compiler_params["mosaic"]
else:
mosaic_params = {}

View File

@ -16,6 +16,7 @@
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import dataclasses
from functools import partial, reduce
import itertools
from typing import Any
@ -1232,7 +1233,7 @@ def pallas_call(
debug: bool = False,
interpret: bool = False,
name: str | None = None,
compiler_params: dict[str, Any] | None = None,
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
cost_estimate: CostEstimate | None = None,
) -> Callable[..., Any]:
"""Invokes a Pallas kernel on some inputs.
@ -1274,7 +1275,10 @@ def pallas_call(
where the kernel function is defined, e.g.:
`{name} for kernel function {kernel_name} at {file}:{line}`.
If missing, then we use `{kernel_name} at {file}:{line}`.
compiler_params: TO BE DOCUMENTED.
compiler_params: Optional compiler parameters. If a dict is provided, it
should be of the form {platform: {param_name: param_value}}, where
platform is either 'mosaic' or 'triton'. For TPUs, it is also possible
to pass in a pallas.tpu.TPUCompilerParams struct.
Returns:
A function that can be called on a number of positional array arguments to
@ -1286,6 +1290,13 @@ def pallas_call(
name, kernel_src_info)
if compiler_params is None:
compiler_params = {}
if isinstance(compiler_params, pallas_core.CompilerParams):
if compiler_params.PLATFORM not in ["mosaic", "triton"]:
raise ValueError(
f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
compiler_params = {
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
}
if grid_spec is None:
grid_spec = GridSpec(grid, in_specs, out_specs)

View File

@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.core import TPUCompilerParams
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef

View File

@ -413,7 +413,9 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
),
grid=8,
),
compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])),
compiler_params=pltpu.TPUCompilerParams(
allow_input_fusion=[False, True]
),
)(s, x)
first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
@ -1556,12 +1558,12 @@ class PallasCallTest(PallasBaseTest):
self.pallas_call(
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256),
)(x)
self.pallas_call(
kernel,
out_shape=x,
compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)),
)(x)
def test_allow_input_fusion(self):
@ -1578,7 +1580,7 @@ class PallasCallTest(PallasBaseTest):
in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)),
out_shape=x,
compiler_params=dict(mosaic=dict(allow_input_fusion=[True])),
compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]),
)(z)
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
@ -1606,8 +1608,8 @@ class PallasCallTest(PallasBaseTest):
self.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
compiler_params=dict(
mosaic=dict(internal_scratch_in_bytes=requested_bytes)
compiler_params=pltpu.TPUCompilerParams(
internal_scratch_in_bytes=requested_bytes,
),
)(x)