diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b629631e..457107d8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. +* New Features + * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for + passing compilation options to XLA. For the moment it's undocumented and + may be in flux. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 8a2d6047e..113f7507c 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -189,6 +189,13 @@ def get_compile_options( compile_options.device_assignment = device_assignment if env_options_overrides is not None: + # Some overrides are passed directly on build_options. + overrides_on_build_options = [ + 'exec_time_optimization_effort', 'memory_fitting_effort'] + env_options_overrides = dict(env_options_overrides) + for name in overrides_on_build_options: + if name in env_options_overrides: + setattr(build_options, name, env_options_overrides.pop(name)) compile_options.env_option_overrides = list(env_options_overrides.items()) debug_options = compile_options.executable_build_options.debug_options diff --git a/tests/api_test.py b/tests/api_test.py index bb1d24729..8ab5d90f6 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -60,7 +60,7 @@ from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_extension +from jax._src.lib import xla_extension, xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1385,6 +1385,26 @@ class JitTest(jtu.BufferDonationTestCase): "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, })(1.0) # doesn't crash. + def test_exec_time_optimization_effort_compiler_option(self): + if xla_extension_version < 294: + raise unittest.SkipTest("test requires newer xla extension version") + + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "exec_time_optimization_effort": 0.0, + })(1.0) # doesn't crash. + + with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + f_jit = jit( + f, + compiler_options={ + "exec_time_compilation_effort": 0.0, + })(1.0) + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1.