Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Add TritonCompilerParams for specifying compiler arguments instead of a dict.
PiperOrigin-RevId: 671081069
parent a8a55e0f2e
commit 2d74c6aa05
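For orientation, below is a minimal before/after sketch of the call-site change this commit applies across the Pallas GPU kernels and tests: the nested `{platform: {param_name: param_value}}` dict passed to `pl.pallas_call` is replaced by the typed `plgpu.TritonCompilerParams` dataclass. The kernel body and shapes are illustrative placeholders, not taken from this diff.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu


def add_one_kernel(x_ref, o_ref):
  # Trivial placeholder kernel: add 1 to every element of the block.
  o_ref[...] = x_ref[...] + 1.0


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)

# Before this commit: compiler options as a nested dict keyed by platform.
add_one_dict = pl.pallas_call(
    add_one_kernel,
    out_shape=out_shape,
    compiler_params=dict(triton=dict(num_warps=4, num_stages=3)),
)

# After this commit: the same options as a typed, frozen dataclass.
add_one_typed = pl.pallas_call(
    add_one_kernel,
    out_shape=out_shape,
    compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3),
)

x = jnp.zeros((8, 128), jnp.float32)
print(add_one_typed(x).sum())  # 1024.0 on a GPU backend
```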
@@ -650,9 +650,9 @@ pytype_strict_library(
        ":pallas_gpu_users",
    ],
    deps = [
        ":pallas",
        "//jax/_src/pallas/mosaic_gpu:core",  # build_cleaner: keep
        "//jax/_src/pallas/mosaic_gpu:pallas_call_registration",  # build_cleaner: keep
        "//jax/_src/pallas/triton:core",
        "//jax/_src/pallas/triton:pallas_call_registration",  # build_cleaner: keep
        "//jax/_src/pallas/triton:primitives",
    ],
@@ -1277,8 +1277,10 @@ def pallas_call(
      If missing, then we use `{kernel_name} at {file}:{line}`.
    compiler_params: Optional compiler parameters. If a dict is provided, it
      should be of the form {platform: {param_name: param_value}}, where
-     platform is either 'mosaic' or 'triton'. For TPUs, it is also possible
-     to pass in a pallas.tpu.TPUCompilerParams struct.
+     platform is either 'mosaic' or 'triton'. It is also possible
+     to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
+     `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.

  Returns:
    A function that can be called on a number of positional array arguments to
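As a hedged illustration of the second option the updated docstring mentions: a TPU kernel would pass the analogous `jax.experimental.pallas.tpu.TPUCompilerParams` object instead. The kernel, grid, block specs, and the `dimension_semantics` field below are illustrative assumptions, not taken from this diff.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def copy_kernel(x_ref, o_ref):
  # Placeholder kernel: copy the current block unchanged.
  o_ref[...] = x_ref[...]


copy = pl.pallas_call(
    copy_kernel,
    grid=(8,),
    in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((128, 128), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct((1024, 128), jnp.float32),
    # Assumed field: marks the single grid axis as parallelizable.
    compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)),
)
```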
@@ -27,6 +27,12 @@ package(
    ],
)

+pytype_strict_library(
+    name = "core",
+    srcs = ["core.py"],
+    deps = ["//jax/_src/pallas"],
+)
+
pytype_strict_library(
    name = "primitives",
    srcs = ["primitives.py"],
jax/_src/pallas/triton/core.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Triton-specific Pallas abstractions."""
from __future__ import annotations

import dataclasses
from typing import ClassVar

from jax._src.pallas import core as pallas_core

@dataclasses.dataclass(frozen=True)
class TritonCompilerParams(pallas_core.CompilerParams):
  """Compiler parameters for Triton.

  Attributes:
    num_warps: The number of warps to use for the kernel. Each warp consists of
      32 threads.
    num_stages: The number of stages the compiler should use for software
      pipelining loops.
    serialized_metadata: Additional compiler metadata. This field is unstable
      and may be removed in the future.
  """
  PLATFORM: ClassVar[str] = "triton"
  num_warps: int | None = None
  num_stages: int | None = None
  serialized_metadata: str | None = None
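A short sketch of how a typed params object like this can be normalized back into the older `{platform: {param_name: param_value}}` dict form; the stand-in classes and the `to_platform_dict` helper are illustrative assumptions, not the code `pallas_call` actually uses.

```python
from __future__ import annotations

import dataclasses
from typing import ClassVar


# Stand-ins for pallas_core.CompilerParams and TritonCompilerParams so this
# sketch is self-contained; the real classes live under jax._src.pallas.
@dataclasses.dataclass(frozen=True)
class CompilerParams:
  pass


@dataclasses.dataclass(frozen=True)
class TritonCompilerParams(CompilerParams):
  PLATFORM: ClassVar[str] = "triton"
  num_warps: int | None = None
  num_stages: int | None = None
  serialized_metadata: str | None = None


def to_platform_dict(params: TritonCompilerParams) -> dict:
  # Hypothetical helper: ClassVars are excluded by dataclasses.asdict, so only
  # the per-kernel fields end up under the platform key.
  return {params.PLATFORM: dataclasses.asdict(params)}


print(to_platform_dict(TritonCompilerParams(num_warps=4, num_stages=3)))
# {'triton': {'num_warps': 4, 'num_stages': 3, 'serialized_metadata': None}}
```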
@@ -61,11 +61,14 @@ def pallas_call_lowering(
  )
  triton_params = compiler_params.get("triton", compiler_params)
  num_warps = triton_params.pop("num_warps", 4)
+ num_warps = 4 if num_warps is None else num_warps
  [lowering_platform] = ctx.platforms or ctx.module_context.platforms
  if lowering_platform == "rocm":
    num_stages = triton_params.pop("num_stages", 1)
+   num_stages = 1 if num_stages is None else num_stages
  else:
    num_stages = triton_params.pop("num_stages", 3)
+   num_stages = 3 if num_stages is None else num_stages

  if debug:
    print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
@@ -101,9 +104,10 @@ pallas_call_lowering(
  )
- if "serialized_metadata" in (triton_params or {}):
    # This field is unstable and may be removed in the future.
-   backend_config["serialized_metadata"] = ir.StringAttr.get(
-       triton_params["serialized_metadata"]
-   )
+ if triton_params["serialized_metadata"] is not None:
+   backend_config["serialized_metadata"] = ir.StringAttr.get(
+       triton_params["serialized_metadata"]
+   )
  return mlir.custom_call(
      call_target_name="__gpu$xla.gpu.triton",
      result_types=out_types,
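Because the TritonCompilerParams fields default to None, the lowering in pallas_call_lowering above has to fill in concrete defaults itself, with num_stages depending on the lowering platform (1 on ROCm, 3 otherwise). A minimal standalone sketch of that resolution; the function name and dict shape are illustrative, not the module's API.

```python
def resolve_triton_defaults(triton_params: dict, platform: str) -> tuple[int, int]:
  """Fills in default num_warps/num_stages when the params left them unset."""
  num_warps = triton_params.get("num_warps")
  num_warps = 4 if num_warps is None else num_warps
  num_stages = triton_params.get("num_stages")
  if num_stages is None:
    # Mirrors the hunks above: ROCm defaults to 1 stage, other GPUs to 3.
    num_stages = 1 if platform == "rocm" else 3
  return num_warps, num_stages


print(resolve_triton_defaults({}, "rocm"))                # (4, 1)
print(resolve_triton_defaults({"num_warps": 8}, "cuda"))  # (8, 3)
```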
@@ -14,6 +14,7 @@

"""Triton-specific Pallas APIs."""

+from jax._src.pallas.triton.core import TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm
@@ -21,6 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
import jax.numpy as jnp
import numpy as np
@@ -216,9 +217,8 @@ def mha(
      out_specs=pl.BlockSpec(
          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
      ),
-     compiler_params=dict(
-         triton=dict(num_warps=num_warps_, num_stages=num_stages)
-     ),
+     compiler_params=plgpu.TritonCompilerParams(
+         num_warps=num_warps_, num_stages=num_stages),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
@@ -21,6 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
import jax.numpy as jnp
@@ -153,8 +154,8 @@ def attn_unbatched(
          pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # l
          pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # m
      ],
-     compiler_params=dict(
-         triton=dict(num_warps=num_warps_, num_stages=num_stages)
+     compiler_params=plgpu.TritonCompilerParams(
+         num_warps=num_warps_, num_stages=num_stages
      ),
      out_shape=[
          jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype),  # o
@@ -24,6 +24,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu

def layer_norm_forward_kernel(
    x_ref, weight_ref, bias_ref,  # Input arrays
@@ -282,9 +283,8 @@ def layer_norm(
  out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
  method = pl.pallas_call(
      kernel,
-     compiler_params=dict(
-         triton=dict(num_warps=num_warps, num_stages=num_stages)
-     ),
+     compiler_params=plgpu.TritonCompilerParams(
+         num_warps=num_warps, num_stages=num_stages),
      grid=(),
      out_shape=out_shape,
      debug=False,
@@ -26,6 +26,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop

from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu

def rms_norm_forward_kernel(
    x_ref, weight_ref, bias_ref,  # Input arrays
@@ -83,7 +84,7 @@ def rms_norm_forward(
  ]
  method = pl.pallas_call(
      kernel,
-     compiler_params=dict(triton=dict(num_warps=num_warps)),
+     compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
      grid=(),
      out_shape=out_shape,
      debug=False,
@@ -18,6 +18,7 @@ import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu


def _vmappable_softmax_kernel(
@@ -79,7 +80,8 @@ def softmax(
  kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
  f = pl.pallas_call(
      kernel,
-     compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
+     compiler_params=plgpu.TritonCompilerParams(
+         num_warps=num_warps, num_stages=1),
      grid=(),
      out_shape=out_shape,
      debug=debug,
@@ -996,7 +996,7 @@ class OpsExtraTest(PallasBaseTest):
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
        grid=1,
-       compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
+       compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("It works!")
@@ -1016,7 +1016,7 @@ class OpsExtraTest(PallasBaseTest):
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
        grid=1,
-       compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
+       compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("x[0] =", x_ref[0])