diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff4493748..35d93eeae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.27
 
+* Deprecations & Removals
+  * Pallas now exclusively uses XLA for compiling kernels on GPU. The old
+    lowering pass via Triton Python APIs has been removed and the
+    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
+
+
 ## jaxlib 0.4.27
 
 ## jax 0.4.26 (April 3, 2024)
diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD
index 4d3501022..bd14339ae 100644
--- a/jax/_src/pallas/triton/BUILD
+++ b/jax/_src/pallas/triton/BUILD
@@ -67,7 +67,5 @@ pytype_strict_library(
         "//jax:util",
         "//jax/_src/lib",
         "//jax/_src/pallas",
-        # Users are expected to add a jax_triton dependency to use the legacy
-        # lowering path.
     ],
 )
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 01c921df1..4750da635 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -19,74 +19,17 @@
 
 from __future__ import annotations
 
-import dataclasses
 import io
 from typing import Any
-import zlib
 
 import jax
 from jax import core as jax_core
-from jax._src import config
 from jax._src.interpreters import mlir
 from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.pallas_call import pallas_call_p
 from jax._src.pallas.triton import lowering
-from jax._src import util
-
-
-@dataclasses.dataclass
-class CompilationResult:
-  kernel_name: str
-  ttir: str
-  ptx: str
-  shared_mem_bytes: int
-  compute_capability: int
-  lowering_result: lowering.LoweringResult
-
-
-@util.weakref_lru_cache
-def compile_jaxpr(
-    jaxpr: jax_core.Jaxpr,
-    in_shapes,
-    grid_mapping: pallas_core.GridMapping,
-    name: str,
-    num_warps: int,
-    num_stages: int,
-    debug: bool,
-) -> CompilationResult:
-  from jax_triton.triton_lib import compile_ttir_to_ptx_inplace  # type: ignore
-  import triton.backends.nvidia.compiler as cb  # type: ignore
-
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
-  target = ("cuda", compute_capability)
-  cuda_backend = cb.CUDABackend(target)
-  cuda_options = cuda_backend.parse_options(
-      dict(
-          num_warps=num_warps,
-          num_stages=num_stages,
-          debug=debug,
-      )
-  )
-  lowering_result = lowering.lower_jaxpr_to_triton_module(
-      jaxpr, in_shapes, grid_mapping, name, cuda_options
-  )
-
-  ttir = str(lowering_result.module)
-  ptx, name, shared_mem_bytes, _ = compile_ttir_to_ptx_inplace(
-      lowering_result.module,
-      cuda_backend,
-      cuda_options,
-      compute_capability,
-  )
-  return CompilationResult(
-      name, ttir, ptx, shared_mem_bytes, compute_capability, lowering_result
-  )
 
 
 def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
@@ -101,81 +44,6 @@ def avals_to_layouts(avals):
   return [list(reversed(range(aval.ndim))) for aval in avals]
 
 
-def _pallas_call_ptx_lowering(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    name: str,
-    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    debug: bool,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    grid_mapping: pallas_core.GridMapping,
-    triton_params: dict[str, Any],
-    num_warps: int,
-    num_stages: int,
-):
-  compilation_result = compile_jaxpr(
-      jaxpr,
-      (*in_shapes, *out_shapes),
-      grid_mapping,
-      name,
-      num_warps,
-      num_stages,
-      debug=debug,
-  )
-  # Triton returns a tuple for ROCm. We just want file path to be passed
-  if ctx.module_context.platforms[0] == 'rocm':
-    compilation_result.ptx = compilation_result.ptx[1]
-
-  if debug:
-    compilation_result.lowering_result.module.dump()
-
-  kernel = triton_kernel_call_lib.TritonKernel(
-      compilation_result.kernel_name,
-      num_warps,
-      compilation_result.shared_mem_bytes,
-      compilation_result.ptx,
-      compilation_result.ttir,
-      compilation_result.compute_capability,
-      1,
-      1,
-      1,  # TODO(giorgioa): Add support for clustering on H100s on Pallas.
-  )
-
-  grid = normalize_grid(compilation_result.lowering_result.grid)
-
-  kernel_params = []
-  for _ in range(len(in_shapes) + len(out_shapes)):
-    kernel_params.append(
-        triton_kernel_call_lib.create_array_parameter(
-            0,  # bytes to zero  # TODO(cjfj): Expose through user API.
-            16,  # divisible by 16
-        )
-    )
-
-  kernel_call = triton_kernel_call_lib.TritonKernelCall(
-      kernel, grid[0], grid[1], grid[2], kernel_params
-  )
-
-  out_types = [
-      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
-      for shape in out_shapes
-  ]
-
-  serialized_metadata = triton_params.get("serialized_metadata", b"")
-  kernel_call_proto = kernel_call.to_proto(name, serialized_metadata)
-  return mlir.custom_call(
-      call_target_name="triton_kernel_call",
-      result_types=out_types,
-      operands=in_nodes,
-      backend_config=zlib.compress(kernel_call_proto),
-      operand_layouts=avals_to_layouts(ctx.avals_in),
-      result_layouts=avals_to_layouts(ctx.avals_out),
-      operand_output_aliases=dict(input_output_aliases),
-  ).results
-
-
 def _pallas_call_ttir_lowering(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
@@ -243,13 +111,6 @@ def _pallas_call_ttir_lowering(
   ).results
 
 
-_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
-    "jax_triton_compile_via_xla",
-    default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
-    help="If True, Pallas delegates Triton kernel compilation to XLA.",
-)
-
-
 def pallas_call_lowering(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
@@ -298,12 +159,7 @@ def pallas_call_lowering(
     print(jaxpr)
     print(grid_mapping)
 
-  if _TRITON_COMPILE_VIA_XLA.value:
-    lowering_fn = _pallas_call_ttir_lowering
-  else:
-    lowering_fn = _pallas_call_ptx_lowering
-
-  return lowering_fn(
+  return _pallas_call_ttir_lowering(
       ctx,
      *in_nodes,
      jaxpr=jaxpr,
diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py
index 1a6c5ece6..e58bd498b 100644
--- a/jax/experimental/pallas/gpu.py
+++ b/jax/experimental/pallas/gpu.py
@@ -13,10 +13,6 @@
 # limitations under the License.
 
 """Contains Triton specific Pallas functions."""
-try:
-  from jax._src.pallas import triton
-  get_compute_capability = triton.get_compute_capability
-  del triton
-except ImportError as e:
-  raise ImportError("Cannot import Pallas Triton backend. "
-                    "Make sure you've installed jax-triton.") from e
+from jax._src.pallas import triton
+get_compute_capability = triton.get_compute_capability
+del triton
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 0ed105bde..76fb64800 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -56,9 +56,6 @@ jax_test(
         "gpu_a100_x32",
         "gpu_h100_x32",
     ],
-    env = {
-        "JAX_TRITON_COMPILE_VIA_XLA": "0",
-    },
     shard_count = 4,
     deps = [
         "//jax:pallas_gpu",
@@ -101,41 +98,6 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
-jax_test(
-    name = "pallas_via_xla_test",
-    srcs = [
-        "pallas_test.py",
-    ],
-    backend_tags = {
-        "gpu": ["noasan"],  # https://github.com/openai/triton/issues/2918
-    },
-    config_tags_overrides = {
-        "gpu_a100_x32": {
-            "ondemand": False,  # Include in presubmit.
-        },
-    },
-    disable_backends = [
-        "cpu",
-        "tpu",
-    ],
-    disable_configs = [
-        "gpu",
-        "gpu_x32",
-        "gpu_p100",
-        "gpu_p100_x32",
-        "gpu_a100",
-        "gpu_h100",
-    ],
-    enable_configs = [
-        "gpu_a100_x32",
-        "gpu_h100_x32",
-    ],
-    shard_count = 4,
-    deps = [
-        "//jax:pallas_gpu",
-    ] + py_deps("absl/testing") + py_deps("numpy"),
-)
-
 jax_test(
     name = "ops_test",
     srcs = [
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 24679d8b0..565d130e5 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -24,7 +24,7 @@ from jax._src import test_util as jtu
 from jax.experimental import pallas as pl
 try:
   from jax.experimental.pallas import gpu as plgpu
-except (ModuleNotFoundError, ImportError):
+except ImportError:
   plgpu = None
 import jax.numpy as jnp
 import numpy as np
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 2a28e94fa..12d751343 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -33,25 +33,21 @@ from jax._src import state
 from jax._src.lax.control_flow.for_loop import for_loop
 from jax._src.lib import version as jaxlib_version
 from jax._src.pallas.pallas_call import _trace_to_jaxpr
+if jaxlib_version >= (0, 4, 24):
+  from jax._src.pallas.triton.lowering import LoweringError
+else:
+  LoweringError = Exception
 from jax.interpreters import partial_eval as pe
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+try:
+  from jax.experimental.pallas import gpu as plgpu
+except ImportError:
+  plgpu = None
 from jax.experimental.pallas.ops import attention
 from jax.experimental.pallas.ops import layer_norm
 from jax.experimental.pallas.ops import rms_norm
 from jax.experimental.pallas.ops import softmax
-try:
-  from jax._src.pallas.triton.lowering import LoweringError
-  from jax._src.pallas.triton.pallas_call_registration import (
-      compile_jaxpr,
-      _TRITON_COMPILE_VIA_XLA,
-  )
-  from jax.experimental.pallas import gpu as plgpu
-except ModuleNotFoundError:
-  LoweringError = Exception
-  compile_jaxpr = None
-  _TRITON_COMPILE_VIA_XLA = None
-  plgpu = None
 import numpy as np
 
 
@@ -143,17 +139,7 @@ class PallasTest(parameterized.TestCase):
         not self.check_gpu_capability_at_least(80)):
       self.skipTest("Only works on GPUs with capability >= sm80")
 
-    try:
-      import triton  # noqa: F401
-    except ImportError:
-      if (
-          _TRITON_COMPILE_VIA_XLA is not None
-          and not _TRITON_COMPILE_VIA_XLA.value
-      ):
-        self.skipTest("Triton is not installed.")
     super().setUp()
-    if compile_jaxpr:
-      compile_jaxpr.cache_clear()
     _trace_to_jaxpr.cache_clear()
 
   def pallas_call(self, *args, **kwargs):
@@ -761,29 +747,6 @@ class PallasCallTest(PallasTest):
     self.assertEqual(f(x), 2.)
     self.assertEqual(trace_count, 1)
 
-  def test_pallas_compilation_cache(self):
-    if not compile_jaxpr:
-      self.skipTest("No Triton GPU.")
-    if self.INTERPRET:
-      raise unittest.SkipTest("No Triton compilation in interpreter mode.")
-    if _TRITON_COMPILE_VIA_XLA.value:
-      raise unittest.SkipTest("Triton is compiled via XLA.")
-
-    @functools.partial(
-        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
-        grid=1)
-    def add_one(x_ref, o_ref):
-      o_ref[()] = x_ref[()] + 1.
-
-    @jax.jit
-    def f(x):
-      return add_one(add_one(x))
-
-    x = jnp.array(0., dtype=jnp.float32)
-    self.assertEqual(f(x), 2.)
-    num_misses = compile_jaxpr.cache_info().misses
-    self.assertEqual(num_misses, 1)
-
   @parameterized.parameters(*[
       (0, 0, 1),
       (0, 1, 1),
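Not part of the patch above: a minimal sketch of what a Pallas GPU kernel looks like after this change, assuming jax/jaxlib 0.4.27 with CUDA support. Every kernel now goes through the TTIR-to-XLA custom-call path kept in _pallas_call_ttir_lowering; there is no jax_triton dependency and no JAX_TRITON_COMPILE_VIA_XLA switch to consult. The kernel and function names below are illustrative only.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Refs hold the input/output blocks for this program instance.
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
  # pallas_call lowers the kernel jaxpr to Triton IR and hands it to XLA as a
  # custom call; no PTX is produced on the Python side anymore.
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)

x = jnp.zeros((8, 128), jnp.float32)
print(add_one(x))  # Expected: an array of ones (on a CUDA-capable GPU).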