mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Pallas] Add explicit TPU compiler params interface with docstrings.
PiperOrigin-RevId: 665557475
This commit is contained in:
parent
9b3c19c5df
commit
2bd8c3f691
@ -23,7 +23,7 @@ import enum
|
||||
import functools
|
||||
import itertools
|
||||
import threading
|
||||
from typing import Any, Hashable, Union
|
||||
from typing import Any, ClassVar, Hashable, Union
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
@ -60,6 +60,11 @@ GridMappingGrid = tuple[int | DynamicGridDim, ...]
|
||||
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CompilerParams:
  """Common parent of the backend-specific compiler parameter structs.

  The class declares no instance fields of its own. Subclasses are expected
  to add their parameters as dataclass fields and to override ``PLATFORM``
  with the name of the backend they configure.

  Attributes:
    PLATFORM: Class-level name of the target backend. The base class uses the
      placeholder value "unspecified".
  """
  # ClassVar keeps PLATFORM out of the generated __init__/fields, so it never
  # shows up as an instance parameter.
  PLATFORM: ClassVar[str] = "unspecified"
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class NameAndSrcInfo:
|
||||
#: The name of the pallas_call or the name of the kernel function.
|
||||
|
@ -19,7 +19,7 @@ from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from typing import Any, Hashable
|
||||
from typing import Any, ClassVar, Hashable
|
||||
|
||||
import jax
|
||||
from jax._src import core as jax_core
|
||||
@ -44,6 +44,38 @@ no_block_spec = pallas_core.no_block_spec
|
||||
_convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping
|
||||
split_list = util.split_list
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class TPUCompilerParams(pallas_core.CompilerParams):
  """Compiler parameters for the Mosaic TPU backend.

  All fields have defaults, so only the parameters that need overriding have
  to be passed. Attributes are listed in declaration order.

  Attributes:
    dimension_semantics: One entry per grid dimension of the kernel. Use
      "parallel" for a dimension that can execute in any order, or
      "arbitrary" for a dimension that must be executed sequentially.
    allow_input_fusion: One boolean per argument, indicating whether input
      fusion is allowed for that argument.
    vmem_limit_bytes: Overrides the default VMEM limit for a kernel. This must
      be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag
      with N*1kib > vmem_limit_bytes.
    collective_id: Indicates which barrier semaphore to use for the kernel.
      Using the same collective_id does not guarantee that the same barrier
      semaphore will be allocated between kernels.
    flags: A dictionary of command line flags for the kernel.
    internal_scratch_in_bytes: The size of the internal scratch space used by
      Mosaic.
    serialization_format: The serialization format for the kernel body.
    device_type: The device type to compile for.
  """
  # Backend name for this parameter struct.
  PLATFORM: ClassVar[str] = "mosaic"

  dimension_semantics: list[str] | None = None
  allow_input_fusion: list[bool] | None = None
  vmem_limit_bytes: int | None = None
  collective_id: int | None = None
  flags: dict[str, Any] | None = None
  internal_scratch_in_bytes: int | None = None
  serialization_format: int = 1
  device_type: str | None = None
|
||||
|
||||
class TPUMemorySpace(enum.Enum):
|
||||
ANY = "any"
|
||||
|
@ -80,16 +80,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
if debug:
|
||||
print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
|
||||
print(jaxpr)
|
||||
if "mosaic_params" in compiler_params:
|
||||
# TODO(slebedev): Remove this branch after July 12th 2024.
|
||||
warnings.warn(
|
||||
"Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)"
|
||||
" is deprecated. Use compiler_params=dict(mosaic=...) instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert "mosaic" not in compiler_params
|
||||
mosaic_params = compiler_params["mosaic_params"]
|
||||
elif "mosaic" in compiler_params:
|
||||
if "mosaic" in compiler_params:
|
||||
mosaic_params = compiler_params["mosaic"]
|
||||
else:
|
||||
mosaic_params = {}
|
||||
|
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
import dataclasses
|
||||
from functools import partial, reduce
|
||||
import itertools
|
||||
from typing import Any
|
||||
@ -1232,7 +1233,7 @@ def pallas_call(
|
||||
debug: bool = False,
|
||||
interpret: bool = False,
|
||||
name: str | None = None,
|
||||
compiler_params: dict[str, Any] | None = None,
|
||||
compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
|
||||
cost_estimate: CostEstimate | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Invokes a Pallas kernel on some inputs.
|
||||
@ -1274,7 +1275,10 @@ def pallas_call(
|
||||
where the kernel function is defined, e.g.:
|
||||
`{name} for kernel function {kernel_name} at {file}:{line}`.
|
||||
If missing, then we use `{kernel_name} at {file}:{line}`.
|
||||
compiler_params: TO BE DOCUMENTED.
|
||||
compiler_params: Optional compiler parameters. If a dict is provided, it
|
||||
should be of the form {platform: {param_name: param_value}}, where
|
||||
platform is either 'mosaic' or 'triton'. For TPUs, it is also possible
|
||||
to pass in a pallas.tpu.TPUCompilerParams struct.
|
||||
|
||||
Returns:
|
||||
A function that can be called on a number of positional array arguments to
|
||||
@ -1286,6 +1290,13 @@ def pallas_call(
|
||||
name, kernel_src_info)
|
||||
if compiler_params is None:
|
||||
compiler_params = {}
|
||||
if isinstance(compiler_params, pallas_core.CompilerParams):
|
||||
if compiler_params.PLATFORM not in ["mosaic", "triton"]:
|
||||
raise ValueError(
|
||||
f"Unknown platform in compiler params: {compiler_params.PLATFORM}")
|
||||
compiler_params = {
|
||||
compiler_params.PLATFORM: dataclasses.asdict(compiler_params)
|
||||
}
|
||||
|
||||
if grid_spec is None:
|
||||
grid_spec = GridSpec(grid, in_specs, out_specs)
|
||||
|
@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic.core import semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.core import TPUCompilerParams
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import ARBITRARY
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
||||
|
@ -413,7 +413,9 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
|
||||
),
|
||||
grid=8,
|
||||
),
|
||||
compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])),
|
||||
compiler_params=pltpu.TPUCompilerParams(
|
||||
allow_input_fusion=[False, True]
|
||||
),
|
||||
)(s, x)
|
||||
|
||||
first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])
|
||||
@ -1556,12 +1558,12 @@ class PallasCallTest(PallasBaseTest):
|
||||
self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)),
|
||||
compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256),
|
||||
)(x)
|
||||
self.pallas_call(
|
||||
kernel,
|
||||
out_shape=x,
|
||||
compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
|
||||
compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)),
|
||||
)(x)
|
||||
|
||||
def test_allow_input_fusion(self):
|
||||
@ -1578,7 +1580,7 @@ class PallasCallTest(PallasBaseTest):
|
||||
in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))],
|
||||
out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)),
|
||||
out_shape=x,
|
||||
compiler_params=dict(mosaic=dict(allow_input_fusion=[True])),
|
||||
compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]),
|
||||
)(z)
|
||||
|
||||
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -1606,8 +1608,8 @@ class PallasCallTest(PallasBaseTest):
|
||||
self.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
|
||||
compiler_params=dict(
|
||||
mosaic=dict(internal_scratch_in_bytes=requested_bytes)
|
||||
compiler_params=pltpu.TPUCompilerParams(
|
||||
internal_scratch_in_bytes=requested_bytes,
|
||||
),
|
||||
)(x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user