diff --git a/CHANGELOG.md b/CHANGELOG.md index 00b8373e9..140867054 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path. + * Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`). ## jaxlib 0.1.76 (Unreleased) * New features diff --git a/docs/jax.rst b/docs/jax.rst index 75a746cab..d0f7b79ae 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -38,6 +38,7 @@ Just-in-time compilation (:code:`jit`) jit disable_jit + ensure_compile_time_eval xla_computation make_jaxpr eval_shape diff --git a/jax/__init__.py b/jax/__init__.py index 83de09d00..af2cc3439 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -50,6 +50,7 @@ from jax._src.config import ( default_prng_impl as default_prng_impl, numpy_rank_promotion as numpy_rank_promotion, ) +from .core import eval_context as ensure_compile_time_eval from jax._src.api import ( ad, # TODO(phawkins): update users to avoid this. block_until_ready, diff --git a/jax/core.py b/jax/core.py index cf8cc9ac1..3b5b0225d 100644 --- a/jax/core.py +++ b/jax/core.py @@ -814,9 +814,67 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]: raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.') @contextmanager -def eval_context(): +def ensure_compile_time_eval(): + """Context manager to ensure evaluation at trace/compile time (or error). + + Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e. + delaying the evaluation of numerical expressions (like jax.numpy function + applications) so that instead of performing those computations eagerly while + evaluating the corresponding Python expressions, their computation is carried + out separately, e.g. after optimized compilation. But this delay can be + undesirable. For example, numerical values might be needed to evaluate Python + control flow and so their evaluation cannot be delayed. As another example, it + may be beneficial to ensure compile time evaluation (or "constant folding") + for performance reasons. + + This context manager ensures that JAX computations are evaluated eagerly. If + eager evaluation is not possible, a ``ConcretizationError`` is raised. + + Here's a contrived example:: + + import jax + import jax.numpy as jnp + + @jax.jit + def f(x): + with jax.ensure_compile_time_eval(): + y = jnp.sin(3.0) + z = jnp.sin(y) + if z > 0: # the value of z is availble and can be used in control flow + return jnp.sin(x) + else: + return jnp.cos(x) + + Here's a real-world example from https://github.com/google/jax/issues/3974:: + + import jax + import jax.numpy as jnp + from jax import random + + @jax.jit + def jax_fn(x): + with jax.ensure_compile_time_eval(): + y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + y2 = y @ y + x2 = jnp.sum(y2) * x + return x2 + + A similar behavior can often be achieved simply by 'hoisting' the constant + expression out of the corresponding staging API:: + + y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100) + + @jax.jit + def jax_fn(x): + y2 = y @ y + x2 = jnp.sum(y2)*x + return x2 + + But in some cases it can be more convenient to use this context manager. + """ with new_base_main(EvalTrace): yield +eval_context = ensure_compile_time_eval # alias, backward compatibility @contextmanager def new_sublevel() -> Generator[None, None, None]: