Compile Triton kernels via XLA by default

PiperOrigin-RevId: 609299269
This commit is contained in:
Sergei Lebedev 2024-02-22 02:31:33 -08:00 committed by jax authors
parent f314f26283
commit 0bf8dddace
3 changed files with 10 additions and 5 deletions

View File

@ -18,6 +18,11 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.
* Changes
* Pallas now uses XLA instead of the Triton Python APIs to compile Triton
kernels. You can revert to the old behavior by setting the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D

View File

@ -2609,8 +2609,8 @@ def _pallas_call_ttir_lowering(
_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
"triton_compile_via_xla",
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", False),
"jax_triton_compile_via_xla",
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
help="If True, Pallas delegates Triton kernel compilation to XLA.",
)

View File

@ -53,6 +53,9 @@ jax_test(
"gpu_x32",
"gpu_a100_x32",
],
env = {
"JAX_TRITON_COMPILE_VIA_XLA": "0",
},
shard_count = 4,
deps = [
"//jax:pallas_gpu",
@ -118,9 +121,6 @@ jax_test(
"gpu_x32",
"gpu_a100_x32",
],
env = {
"JAX_TRITON_COMPILE_VIA_XLA": "1",
},
shard_count = 4,
deps = [
"//jax:pallas_gpu",