Add TritonCompilerParams for specifying compiler arguments instead of a dict.

PiperOrigin-RevId: 671081069
Justin Fu 2024-09-04 13:31:35 -07:00 committed by jax authors
parent a8a55e0f2e
commit 2d74c6aa05
12 changed files with 73 additions and 18 deletions

View File

@ -650,9 +650,9 @@ pytype_strict_library(
":pallas_gpu_users",
],
deps = [
":pallas",
"//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:core",
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],

View File

@ -1277,8 +1277,10 @@ def pallas_call(
If missing, then we use `{kernel_name} at {file}:{line}`.
compiler_params: Optional compiler parameters. If a dict is provided, it
should be of the form {platform: {param_name: param_value}}, where
platform is either 'mosaic' or 'triton'. For TPUs, it is also possible
to pass in a pallas.tpu.TPUCompilerParams struct.
platform is either 'mosaic' or 'triton'. It is also possible
to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
`jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
Returns:
A function that can be called on a number of positional array arguments to
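
To illustrate the new call-site style, here is a minimal before/after sketch. The copy kernel and shapes are made up for illustration and are not part of this change; running it requires a Triton-capable GPU backend.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu

def copy_kernel(x_ref, o_ref):
  # Copy the whole input block into the output block.
  o_ref[...] = x_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

# Before this change: compiler arguments as a nested dict keyed by platform.
y_old = pl.pallas_call(
    copy_kernel,
    out_shape=out_shape,
    compiler_params=dict(triton=dict(num_warps=1, num_stages=1)),
)(x)

# After this change: the same arguments as a typed dataclass.
y_new = pl.pallas_call(
    copy_kernel,
    out_shape=out_shape,
    compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1),
)(x)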

View File

@ -27,6 +27,12 @@ package(
],
)
pytype_strict_library(
name = "core",
srcs = ["core.py"],
deps = ["//jax/_src/pallas"],
)
pytype_strict_library(
name = "primitives",
srcs = ["primitives.py"],

View File

@ -0,0 +1,38 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Triton-specific Pallas abstractions."""
from __future__ import annotations
import dataclasses
from typing import ClassVar
from jax._src.pallas import core as pallas_core
@dataclasses.dataclass(frozen=True)
class TritonCompilerParams(pallas_core.CompilerParams):
"""Compiler parameters for Triton.
Attributes:
num_warps: The number of warps to use for the kernel. Each warp consists of
32 threads.
num_stages: The number of stages the compiler should use for software
pipelining loops.
serialized_metadata: Additional compiler metadata. This field is unstable
and may be removed in the future.
"""
PLATFORM: ClassVar[str] = "triton"
num_warps: int | None = None
num_stages: int | None = None
serialized_metadata: str | None = None
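
A rough sketch of how a frozen dataclass like this can interoperate with the legacy dict form. The helper below is hypothetical and not part of this commit; it relies only on the PLATFORM class variable and dataclasses.asdict.

import dataclasses

def as_platform_dict(params) -> dict:
  # Hypothetical helper: fold a CompilerParams dataclass back into the legacy
  # {platform: {param_name: param_value}} form, dropping unset (None) fields
  # so the platform defaults still apply during lowering.
  fields = {k: v for k, v in dataclasses.asdict(params).items() if v is not None}
  return {params.PLATFORM: fields}

# as_platform_dict(TritonCompilerParams(num_warps=4))
# -> {"triton": {"num_warps": 4}}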

View File

@ -61,11 +61,14 @@ def pallas_call_lowering(
)
  triton_params = compiler_params.get("triton", compiler_params)
  num_warps = triton_params.pop("num_warps", 4)
  num_warps = 4 if num_warps is None else num_warps
  [lowering_platform] = ctx.platforms or ctx.module_context.platforms
  if lowering_platform == "rocm":
    num_stages = triton_params.pop("num_stages", 1)
    num_stages = 1 if num_stages is None else num_stages
  else:
    num_stages = triton_params.pop("num_stages", 3)
    num_stages = 3 if num_stages is None else num_stages

  if debug:
    print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
@ -101,9 +104,10 @@ def pallas_call_lowering(
)
if "serialized_metadata" in (triton_params or {}):
# This field is unstable and may be removed in the future.
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
if triton_params["serialized_metadata"] is not None:
backend_config["serialized_metadata"] = ir.StringAttr.get(
triton_params["serialized_metadata"]
)
return mlir.custom_call(
call_target_name="__gpu$xla.gpu.triton",
result_types=out_types,
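
The None checks above keep the dataclass defaults (every field defaults to None) from overriding the platform defaults chosen at lowering time. A standalone sketch of that resolution logic, written as a hypothetical helper using the same defaults as the code above:

def resolve_triton_defaults(num_warps, num_stages, lowering_platform):
  # Unset values fall back to the Triton lowering defaults:
  # 4 warps, and 3 pipeline stages on CUDA vs. 1 on ROCm.
  num_warps = 4 if num_warps is None else num_warps
  if num_stages is None:
    num_stages = 1 if lowering_platform == "rocm" else 3
  return num_warps, num_stages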

View File

@ -14,6 +14,7 @@
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton.core import TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm

View File

@ -21,6 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
import jax.numpy as jnp
import numpy as np
@ -216,9 +217,8 @@ def mha(
out_specs=pl.BlockSpec(
(None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
),
compiler_params=plgpu.TritonCompilerParams(
num_warps=num_warps_, num_stages=num_stages),
out_shape=out_shape,
debug=debug,
interpret=interpret,

View File

@ -21,6 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
import jax.numpy as jnp
@ -153,8 +154,8 @@ def attn_unbatched(
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m
],
compiler_params=dict(
triton=dict(num_warps=num_warps_, num_stages=num_stages)
compiler_params=plgpu.TritonCompilerParams(
num_warps=num_warps_, num_stages=num_stages
),
out_shape=[
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o

View File

@ -24,6 +24,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
def layer_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
@ -282,9 +283,8 @@ def layer_norm(
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
compiler_params=plgpu.TritonCompilerParams(
num_warps=num_warps, num_stages=num_stages),
grid=(),
out_shape=out_shape,
debug=False,

View File

@ -26,6 +26,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
def rms_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
@ -83,7 +84,7 @@ def rms_norm_forward(
]
method = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps)),
compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape,
debug=False,

View File

@ -18,6 +18,7 @@ import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
def _vmappable_softmax_kernel(
@ -79,7 +80,8 @@ def softmax(
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
f = pl.pallas_call(
kernel,
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
compiler_params=plgpu.TritonCompilerParams(
num_warps=num_warps, num_stages=1),
grid=(),
out_shape=out_shape,
debug=debug,

View File

@ -996,7 +996,7 @@ class OpsExtraTest(PallasBaseTest):
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
grid=1,
compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
)
def kernel(x_ref, o_ref):
pl.debug_print("It works!")
@ -1016,7 +1016,7 @@ class OpsExtraTest(PallasBaseTest):
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
grid=1,
compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
)
def kernel(x_ref, o_ref):
pl.debug_print("x[0] =", x_ref[0])