support exec_time_optimization_effort and memory_fitting_effort xla compilation

options

PiperOrigin-RevId: 692322944
This commit is contained in:
Matthew Johnson 2024-11-01 16:24:42 -07:00 committed by jax authors
parent d38da5d1b4
commit 0f3ba4250d
3 changed files with 33 additions and 1 deletions

View File

@ -24,6 +24,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
for information on migrating to the new API.
* New Features
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
passing compilation options to XLA. For the moment it's undocumented and
may be in flux.
## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes

View File

@ -189,6 +189,13 @@ def get_compile_options(
compile_options.device_assignment = device_assignment
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
'exec_time_optimization_effort', 'memory_fitting_effort']
env_options_overrides = dict(env_options_overrides)
for name in overrides_on_build_options:
if name in env_options_overrides:
setattr(build_options, name, env_options_overrides.pop(name))
compile_options.env_option_overrides = list(env_options_overrides.items())
debug_options = compile_options.executable_build_options.debug_options

View File

@ -60,7 +60,7 @@ from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension, xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
@ -1385,6 +1385,26 @@ class JitTest(jtu.BufferDonationTestCase):
"xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5,
})(1.0) # doesn't crash.
def test_exec_time_optimization_effort_compiler_option(self):
if xla_extension_version < 294:
raise unittest.SkipTest("test requires newer xla extension version")
def f(x):
return jnp.sqrt(x ** 2) + 1.
f_jit = jit(
f,
compiler_options={
"exec_time_optimization_effort": 0.0,
})(1.0) # doesn't crash.
with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"):
f_jit = jit(
f,
compiler_options={
"exec_time_compilation_effort": 0.0,
})(1.0)
def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.