diff --git a/jax/BUILD b/jax/BUILD
index b761722a5..c4a421362 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -650,9 +650,9 @@ pytype_strict_library(
         ":pallas_gpu_users",
     ],
     deps = [
-        ":pallas",
         "//jax/_src/pallas/mosaic_gpu:core",  # build_cleaner: keep
         "//jax/_src/pallas/mosaic_gpu:pallas_call_registration",  # build_cleaner: keep
+        "//jax/_src/pallas/triton:core",
         "//jax/_src/pallas/triton:pallas_call_registration",  # build_cleaner: keep
         "//jax/_src/pallas/triton:primitives",
     ],
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index d3b0ea8ca..301504c8f 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -1277,8 +1277,10 @@ def pallas_call(
       If missing, then we use `{kernel_name} at {file}:{line}`.
     compiler_params: Optional compiler parameters. If a dict is provided, it
       should be of the form {platform: {param_name: param_value}}, where
-      platform is either 'mosaic' or 'triton'. For TPUs, it is also possible
-      to pass in a pallas.tpu.TPUCompilerParams struct.
+      platform is either 'mosaic' or 'triton'. It is also possible
+      to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
+      `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
+
 
   Returns:
     A function that can be called on a number of positional array arguments to
diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD
index c40fb19ec..a9babcba0 100644
--- a/jax/_src/pallas/triton/BUILD
+++ b/jax/_src/pallas/triton/BUILD
@@ -27,6 +27,12 @@ package(
     ],
 )
 
+pytype_strict_library(
+    name = "core",
+    srcs = ["core.py"],
+    deps = ["//jax/_src/pallas"],
+)
+
 pytype_strict_library(
     name = "primitives",
     srcs = ["primitives.py"],
diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py
new file mode 100644
index 000000000..a61dfd61b
--- /dev/null
+++ b/jax/_src/pallas/triton/core.py
@@ -0,0 +1,38 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains Triton-specific Pallas abstractions."""
+from __future__ import annotations
+
+import dataclasses
+from typing import ClassVar
+
+from jax._src.pallas import core as pallas_core
+
+@dataclasses.dataclass(frozen=True)
+class TritonCompilerParams(pallas_core.CompilerParams):
+  """Compiler parameters for Triton.
+
+  Attributes:
+    num_warps: The number of warps to use for the kernel. Each warp consists of
+      32 threads.
+    num_stages: The number of stages the compiler should use for software
+      pipelining loops.
+    serialized_metadata: Additional compiler metadata. This field is unstable
+      and may be removed in the future.
+  """
+  PLATFORM: ClassVar[str] = "triton"
+  num_warps: int | None = None
+  num_stages: int | None = None
+  serialized_metadata: str | None = None
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index b94adfb8f..5ee7077dc 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -61,11 +61,14 @@ def pallas_call_lowering(
     )
   triton_params = compiler_params.get("triton", compiler_params)
   num_warps = triton_params.pop("num_warps", 4)
+  num_warps = 4 if num_warps is None else num_warps
   [lowering_platform] = ctx.platforms or ctx.module_context.platforms
   if lowering_platform == "rocm":
     num_stages = triton_params.pop("num_stages", 1)
+    num_stages = 1 if num_stages is None else num_stages
   else:
     num_stages = triton_params.pop("num_stages", 3)
+    num_stages = 3 if num_stages is None else num_stages
 
   if debug:
     print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:")
@@ -101,9 +104,10 @@ def pallas_call_lowering(
   )
   if "serialized_metadata" in (triton_params or {}):
     # This field is unstable and may be removed in the future.
-    backend_config["serialized_metadata"] = ir.StringAttr.get(
-        triton_params["serialized_metadata"]
-    )
+    if triton_params["serialized_metadata"] is not None:
+      backend_config["serialized_metadata"] = ir.StringAttr.get(
+          triton_params["serialized_metadata"]
+      )
   return mlir.custom_call(
       call_target_name="__gpu$xla.gpu.triton",
       result_types=out_types,
diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py
index a24bfe415..4f38192e3 100644
--- a/jax/experimental/pallas/gpu.py
+++ b/jax/experimental/pallas/gpu.py
@@ -14,6 +14,7 @@
 
 """Triton-specific Pallas APIs."""
 
+from jax._src.pallas.triton.core import TritonCompilerParams
 from jax._src.pallas.triton.primitives import approx_tanh
 from jax._src.pallas.triton.primitives import debug_barrier
 from jax._src.pallas.triton.primitives import elementwise_inline_asm
diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py
index 63541e8cb..8e28be840 100644
--- a/jax/experimental/pallas/ops/gpu/attention.py
+++ b/jax/experimental/pallas/ops/gpu/attention.py
@@ -21,6 +21,7 @@ from typing import Any
 import jax
 from jax import lax
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 import jax.numpy as jnp
 import numpy as np
 
@@ -216,9 +217,8 @@ def mha(
       out_specs=pl.BlockSpec(
          (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
-      compiler_params=dict(
-          triton=dict(num_warps=num_warps_, num_stages=num_stages)
-      ),
+      compiler_params=plgpu.TritonCompilerParams(
+          num_warps=num_warps_, num_stages=num_stages),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py
index 9be724a1f..dde80d460 100644
--- a/jax/experimental/pallas/ops/gpu/decode_attention.py
+++ b/jax/experimental/pallas/ops/gpu/decode_attention.py
@@ -21,6 +21,7 @@ from typing import Any
 import jax
 from jax import lax
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 import jax.numpy as jnp
 
 
@@ -153,8 +154,8 @@ def attn_unbatched(
           pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # l
           pl.BlockSpec((None, block_h), lambda i, j: (j, i)),  # m
       ],
-      compiler_params=dict(
-          triton=dict(num_warps=num_warps_, num_stages=num_stages)
+      compiler_params=plgpu.TritonCompilerParams(
+          num_warps=num_warps_, num_stages=num_stages
       ),
       out_shape=[
           jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype),  # o
diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py
index 0c39a9bf6..e53139507 100644
--- a/jax/experimental/pallas/ops/gpu/layer_norm.py
+++ b/jax/experimental/pallas/ops/gpu/layer_norm.py
@@ -24,6 +24,7 @@ import jax.numpy as jnp
 from jax._src.lax.control_flow.for_loop import for_loop
 
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 def layer_norm_forward_kernel(
     x_ref, weight_ref, bias_ref,  # Input arrays
@@ -282,9 +283,8 @@ def layer_norm(
   out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(
-          triton=dict(num_warps=num_warps, num_stages=num_stages)
-      ),
+      compiler_params=plgpu.TritonCompilerParams(
+          num_warps=num_warps, num_stages=num_stages),
       grid=(),
       out_shape=out_shape,
       debug=False,
diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py
index e1dfa3c5b..3e373b895 100644
--- a/jax/experimental/pallas/ops/gpu/rms_norm.py
+++ b/jax/experimental/pallas/ops/gpu/rms_norm.py
@@ -26,6 +26,7 @@ import jax.numpy as jnp
 from jax._src.lax.control_flow.for_loop import for_loop
 
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 def rms_norm_forward_kernel(
     x_ref, weight_ref, bias_ref,  # Input arrays
@@ -83,7 +84,7 @@ def rms_norm_forward(
   ]
   method = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps)),
+      compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps),
       grid=(),
       out_shape=out_shape,
       debug=False,
diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py
index 3671331b8..33b416d16 100644
--- a/jax/experimental/pallas/ops/gpu/softmax.py
+++ b/jax/experimental/pallas/ops/gpu/softmax.py
@@ -18,6 +18,7 @@ import functools
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 
 def _vmappable_softmax_kernel(
@@ -79,7 +80,8 @@ def softmax(
   kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
   f = pl.pallas_call(
       kernel,
-      compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
+      compiler_params=plgpu.TritonCompilerParams(
+          num_warps=num_warps, num_stages=1),
       grid=(),
       out_shape=out_shape,
       debug=debug,
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 85bda21ec..564a59ec2 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -996,7 +996,7 @@ class OpsExtraTest(PallasBaseTest):
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
         grid=1,
-        compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
+        compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
     )
     def kernel(x_ref, o_ref):
       pl.debug_print("It works!")
@@ -1016,7 +1016,7 @@ class OpsExtraTest(PallasBaseTest):
         self.pallas_call,
         out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
         grid=1,
-        compiler_params=dict(triton=dict(num_warps=1, num_stages=1))
+        compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1)
    )
    def kernel(x_ref, o_ref):
      pl.debug_print("x[0] =", x_ref[0])
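
Note for reviewers: a minimal usage sketch of the new typed parameters, in the style of the call sites migrated above. The kernel name and shapes here are illustrative only (not part of this change), and the example assumes a CUDA/ROCm backend where the Triton lowering applies.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import gpu as plgpu

    def add_one_kernel(x_ref, o_ref):
      # Trivial elementwise kernel: read the whole block, add one, write it back.
      o_ref[...] = x_ref[...] + 1.0

    @jax.jit
    def add_one(x):
      return pl.pallas_call(
          add_one_kernel,
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
          # Typed replacement for compiler_params=dict(triton=dict(num_warps=4, num_stages=3)).
          compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3),
      )(x)

    print(add_one(jnp.arange(8, dtype=jnp.float32)))

Fields left as None fall back to the backend defaults (num_warps=4; num_stages=3 on CUDA, 1 on ROCm), per the None handling added to pallas_call_registration.py above; the plain dict form remains accepted.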