mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
support exec_time_optimization_effort and memory_fitting_effort xla compilation
options PiperOrigin-RevId: 692322944
This commit is contained in:
parent
d38da5d1b4
commit
0f3ba4250d
@ -24,6 +24,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export)
|
||||
for information on migrating to the new API.
|
||||
|
||||
* New Features
|
||||
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
||||
passing compilation options to XLA. For the moment it's undocumented and
|
||||
may be in flux.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
* Breaking Changes
|
||||
|
@ -189,6 +189,13 @@ def get_compile_options(
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if env_options_overrides is not None:
|
||||
# Some overrides are passed directly on build_options.
|
||||
overrides_on_build_options = [
|
||||
'exec_time_optimization_effort', 'memory_fitting_effort']
|
||||
env_options_overrides = dict(env_options_overrides)
|
||||
for name in overrides_on_build_options:
|
||||
if name in env_options_overrides:
|
||||
setattr(build_options, name, env_options_overrides.pop(name))
|
||||
compile_options.env_option_overrides = list(env_options_overrides.items())
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
|
@ -60,7 +60,7 @@ from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.compilation_cache import is_persistent_cache_enabled
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension, xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
@ -1385,6 +1385,26 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
"xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5,
|
||||
})(1.0) # doesn't crash.
|
||||
|
||||
def test_exec_time_optimization_effort_compiler_option(self):
|
||||
if xla_extension_version < 294:
|
||||
raise unittest.SkipTest("test requires newer xla extension version")
|
||||
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
||||
f_jit = jit(
|
||||
f,
|
||||
compiler_options={
|
||||
"exec_time_optimization_effort": 0.0,
|
||||
})(1.0) # doesn't crash.
|
||||
|
||||
with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"):
|
||||
f_jit = jit(
|
||||
f,
|
||||
compiler_options={
|
||||
"exec_time_compilation_effort": 0.0,
|
||||
})(1.0)
|
||||
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
Loading…
x
Reference in New Issue
Block a user