mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Compile Triton kernels via XLA by default
PiperOrigin-RevId: 609299269
This commit is contained in:
parent
f314f26283
commit
0bf8dddace
@ -18,6 +18,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts
|
||||
`inner_treedef=None`, in which case the inner treedef will be automatically inferred.
|
||||
|
||||
* Changes
|
||||
* Pallas now uses XLA instead of the Triton Python APIs to compile Triton
|
||||
kernels. You can revert to the old behavior by setting the
|
||||
`JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
|
||||
|
||||
* Deprecations & Removals
|
||||
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
|
||||
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
|
||||
|
@ -2609,8 +2609,8 @@ def _pallas_call_ttir_lowering(
|
||||
|
||||
|
||||
_TRITON_COMPILE_VIA_XLA = config.DEFINE_bool(
|
||||
"triton_compile_via_xla",
|
||||
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", False),
|
||||
"jax_triton_compile_via_xla",
|
||||
default=config.bool_env("JAX_TRITON_COMPILE_VIA_XLA", True),
|
||||
help="If True, Pallas delegates Triton kernel compilation to XLA.",
|
||||
)
|
||||
|
||||
|
@ -53,6 +53,9 @@ jax_test(
|
||||
"gpu_x32",
|
||||
"gpu_a100_x32",
|
||||
],
|
||||
env = {
|
||||
"JAX_TRITON_COMPILE_VIA_XLA": "0",
|
||||
},
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//jax:pallas_gpu",
|
||||
@ -118,9 +121,6 @@ jax_test(
|
||||
"gpu_x32",
|
||||
"gpu_a100_x32",
|
||||
],
|
||||
env = {
|
||||
"JAX_TRITON_COMPILE_VIA_XLA": "1",
|
||||
},
|
||||
shard_count = 4,
|
||||
deps = [
|
||||
"//jax:pallas_gpu",
|
||||
|
Loading…
x
Reference in New Issue
Block a user