Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
Pallas now exclusively uses XLA for compiling kernels on GPU

The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046

parent 722708052c
commit 498e81ab10
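For context, here is a minimal sketch (a hypothetical example, not part of this commit) of the kind of Pallas GPU kernel this change affects; the names add_one_kernel and add_one are illustrative only. After this commit such a kernel is always lowered to Triton IR (TTIR) and compiled by XLA, and setting JAX_TRITON_COMPILE_VIA_XLA=0 no longer selects the legacy lowering through the Triton Python APIs.

# Hypothetical illustration; `add_one_kernel` and `add_one` are not part of
# this commit. On GPU this is now always compiled via the TTIR/XLA path.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Whole-array block: read the input, add one, write the output.
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)

x = jnp.arange(8, dtype=jnp.float32)
print(add_one(x))  # [1. 2. 3. 4. 5. 6. 7. 8.]

Previously the same call could be routed through jax_triton and compiled to PTX directly when the flag was disabled; that fallback is the _pallas_call_ptx_lowering path deleted in the diff below.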
@@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.27

* Deprecations & Removals
  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
    lowering pass via Triton Python APIs has been removed and the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.


## jaxlib 0.4.27

## jax 0.4.26 (April 3, 2024)
@@ -67,7 +67,5 @@ pytype_strict_library(
        "//jax:util",
        "//jax/_src/lib",
        "//jax/_src/pallas",
        # Users are expected to add a jax_triton dependency to use the legacy
        # lowering path.
    ],
)
@@ -19,74 +19,17 @@

from __future__ import annotations

import dataclasses
import io
from typing import Any
import zlib

import jax
from jax import core as jax_core
from jax._src import config
from jax._src.interpreters import mlir
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.triton import lowering
from jax._src import util


@dataclasses.dataclass
class CompilationResult:
  kernel_name: str
  ttir: str
  ptx: str
  shared_mem_bytes: int
  compute_capability: int
  lowering_result: lowering.LoweringResult


@util.weakref_lru_cache
def compile_jaxpr(
    jaxpr: jax_core.Jaxpr,
    in_shapes,
    grid_mapping: pallas_core.GridMapping,
    name: str,
    num_warps: int,
    num_stages: int,
    debug: bool,
) -> CompilationResult:
  from jax_triton.triton_lib import compile_ttir_to_ptx_inplace  # type: ignore
  import triton.backends.nvidia.compiler as cb  # type: ignore

  # TODO(sharadmv): handle multiple devices, right now we assume device 0
  # which is fine when we have multiple of the same GPU but this won't work in
  # general.
  device = 0
  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
  target = ("cuda", compute_capability)
  cuda_backend = cb.CUDABackend(target)
  cuda_options = cuda_backend.parse_options(
      dict(
          num_warps=num_warps,
          num_stages=num_stages,
          debug=debug,
      )
  )
  lowering_result = lowering.lower_jaxpr_to_triton_module(
      jaxpr, in_shapes, grid_mapping, name, cuda_options
  )

  ttir = str(lowering_result.module)
  ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
      lowering_result.module,
      cuda_backend,
      cuda_options,
      compute_capability,
  )
  return CompilationResult(
      name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
  )


def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
@@ -101,81 +44,6 @@ def avals_to_layouts(avals):
  return [list(reversed(range(aval.ndim))) for aval in avals]


def _pallas_call_ptx_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: pallas_core.GridMapping,
    triton_params: dict[str, Any],
    num_warps: int,
    num_stages: int,
):
  compilation_result = compile_jaxpr(
      jaxpr,
      (*in_shapes, *out_shapes),
      grid_mapping,
      name,
      num_warps,
      num_stages,
      debug=debug,
  )
  # Triton returns a tuple for ROCm. We just want file path to be passed
  if ctx.module_context.platforms[0] == 'rocm':
    compilation_result.ptx = compilation_result.ptx[1]

  if debug:
    compilation_result.lowering_result.module.dump()

  kernel = triton_kernel_call_lib.TritonKernel(
      compilation_result.kernel_name,
      num_warps,
      compilation_result.shared_mem_bytes,
      compilation_result.ptx,
      compilation_result.ttir,
      compilation_result.compute_capability,
      1,
      1,
      1,  # TODO(giorgioa): Add support for clustering on H100s on Pallas.
  )

  grid = normalize_grid(compilation_result.lowering_result.grid)

  kernel_params = []
  for _ in range(len(in_shapes) + len(out_shapes)):
    kernel_params.append(
        triton_kernel_call_lib.create_array_parameter(
            0,  # bytes to zero  # TODO(cjfj): Expose through user API.
            16,  # divisible by 16
        )
    )

  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel, grid[0], grid[1], grid[2], kernel_params
  )

  out_types = [
      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
      for shape in out_shapes
  ]

  serialized_metadata = triton_params.get("serialized_metadata", b"")
  kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
  return mlir.custom_call(
      call_target_name="triton_kernel_call",
      result_types=out_types,
      operands=in_nodes,
      backend_config=zlib.compress(kernel_call_proto),
      operand_layouts=avals_to_layouts(ctx.avals_in),
      result_layouts=avals_to_layouts(ctx.avals_out),
      operand_output_aliases=dict(input_output_aliases),
  ).results


def _pallas_call_ttir_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
@@ -243,13 +111,6 @@ def _pallas_call_ttir_lowering(
  ).results


_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
    "jax_triton_compile_via_xla",
    default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
    help="If True, Pallas delegates Triton kernel compilation to XLA.",
)


def pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
@@ -298,12 +159,7 @@ def pallas_call_lowering(
    print(jaxpr)
    print(grid_mapping)

  if _TRITON_COMPILE_VIA_XLA.value:
    lowering_fn = _pallas_call_ttir_lowering
  else:
    lowering_fn = _pallas_call_ptx_lowering

  return lowering_fn(
  return _pallas_call_ttir_lowering(
      ctx,
      *in_nodes,
      jaxpr=jaxpr,
@@ -13,10 +13,6 @@
# limitations under the License.

"""Contains Triton specific Pallas functions."""
try:
  from jax._src.pallas import triton
  get_compute_capability = triton.get_compute_capability
  del triton
except ImportError as e:
  raise ImportError("Cannot import Pallas Triton backend. "
                    "Make sure you've installed jax-triton.") from e
from jax._src.pallas import triton
get_compute_capability = triton.get_compute_capability
del triton
@@ -56,9 +56,6 @@ jax_test(
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    env = {
        "JAX_TRITON_COMPILE_VIA_XLA": "0",
    },
    shard_count = 4,
    deps = [
        "//jax:pallas_gpu",
@@ -101,41 +98,6 @@ jax_test(
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_test(
    name = "pallas_via_xla_test",
    srcs = [
        "pallas_test.py",
    ],
    backend_tags = {
        "gpu": ["noasan"],  # https://github.com/openai/triton/issues/2918
    },
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    disable_backends = [
        "cpu",
        "tpu",
    ],
    disable_configs = [
        "gpu",
        "gpu_x32",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_a100",
        "gpu_h100",
    ],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 4,
    deps = [
        "//jax:pallas_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)

jax_test(
    name = "ops_test",
    srcs = [
@@ -24,7 +24,7 @@ from jax._src import test_util as jtu
from jax.experimental import pallas as pl
try:
  from jax.experimental.pallas import gpu as plgpu
except (ModuleNotFoundError, ImportError):
except ImportError:
  plgpu = None
import jax.numpy as jnp
import numpy as np
@@ -33,25 +33,21 @@ from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.pallas.pallas_call import _trace_to_jaxpr
if jaxlib_version >= (0, 4, 24):
  from jax._src.pallas.triton.lowering import LoweringError
else:
  LoweringError = Exception
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax.experimental import pallas as pl
try:
  from jax.experimental.pallas import gpu as plgpu
except ImportError:
  plgpu = None
from jax.experimental.pallas.ops import attention
from jax.experimental.pallas.ops import layer_norm
from jax.experimental.pallas.ops import rms_norm
from jax.experimental.pallas.ops import softmax
try:
  from jax._src.pallas.triton.lowering import LoweringError
  from jax._src.pallas.triton.pallas_call_registration import (
      compile_jaxpr,
      _TRITON_COMPILE_VIA_XLA,
  )
  from jax.experimental.pallas import gpu as plgpu
except ModuleNotFoundError:
  LoweringError = Exception
  compile_jaxpr = None
  _TRITON_COMPILE_VIA_XLA = None
  plgpu = None
import numpy as np
@@ -143,17 +139,7 @@ class PallasTest(parameterized.TestCase):
        not self.check_gpu_capability_at_least(80)):
      self.skipTest("Only works on GPUs with capability >= sm80")

    try:
      import triton  # noqa: F401
    except ImportError:
      if (
          _TRITON_COMPILE_VIA_XLA is not None
          and not _TRITON_COMPILE_VIA_XLA.value
      ):
        self.skipTest("Triton is not installed.")
    super().setUp()
    if compile_jaxpr:
      compile_jaxpr.cache_clear()
    _trace_to_jaxpr.cache_clear()

  def pallas_call(self, *args, **kwargs):
@@ -761,29 +747,6 @@ class PallasCallTest(PallasTest):
    self.assertEqual(f(x), 2.)
    self.assertEqual(trace_count, 1)

  def test_pallas_compilation_cache(self):
    if not compile_jaxpr:
      self.skipTest("No Triton GPU.")
    if self.INTERPRET:
      raise unittest.SkipTest("No Triton compilation in interpreter mode.")
    if _TRITON_COMPILE_VIA_XLA.value:
      raise unittest.SkipTest("Triton is compiled via XLA.")

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
        grid=1)
    def add_one(x_ref, o_ref):
      o_ref[()] = x_ref[()] + 1.

    @jax.jit
    def f(x):
      return add_one(add_one(x))

    x = jnp.array(0., dtype=jnp.float32)
    self.assertEqual(f(x), 2.)
    num_misses = compile_jaxpr.cache_info().misses
    self.assertEqual(num_misses, 1)

  @parameterized.parameters(*[
      (0, 0, 1),
      (0, 1, 1),