Copybara import of the project:

--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
This commit is contained in:
Matthew Johnson 2022-01-10 20:57:56 -08:00 committed by jax authors
parent 7bc51879d4
commit 1cf7d4ab5d
4 changed files with 62 additions and 1 deletions

View File

@ -32,6 +32,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. * Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
file under the given path. file under the given path.
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
## jaxlib 0.1.76 (Unreleased) ## jaxlib 0.1.76 (Unreleased)
* New features * New features

View File

@ -38,6 +38,7 @@ Just-in-time compilation (:code:`jit`)
jit jit
disable_jit disable_jit
ensure_compile_time_eval
xla_computation xla_computation
make_jaxpr make_jaxpr
eval_shape eval_shape

View File

@ -50,6 +50,7 @@ from jax._src.config import (
default_prng_impl as default_prng_impl, default_prng_impl as default_prng_impl,
numpy_rank_promotion as numpy_rank_promotion, numpy_rank_promotion as numpy_rank_promotion,
) )
from .core import eval_context as ensure_compile_time_eval
from jax._src.api import ( from jax._src.api import (
ad, # TODO(phawkins): update users to avoid this. ad, # TODO(phawkins): update users to avoid this.
block_until_ready, block_until_ready,

View File

@ -814,9 +814,67 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.') raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
@contextmanager @contextmanager
def eval_context(): def ensure_compile_time_eval():
"""Context manager to ensure evaluation at trace/compile time (or error).
Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
delaying the evaluation of numerical expressions (like jax.numpy function
applications) so that instead of performing those computations eagerly while
evaluating the corresponding Python expressions, their computation is carried
out separately, e.g. after optimized compilation. But this delay can be
undesirable. For example, numerical values might be needed to evaluate Python
control flow and so their evaluation cannot be delayed. As another example, it
may be beneficial to ensure compile time evaluation (or "constant folding")
for performance reasons.
This context manager ensures that JAX computations are evaluated eagerly. If
eager evaluation is not possible, a ``ConcretizationError`` is raised.
Here's a contrived example::
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
with jax.ensure_compile_time_eval():
y = jnp.sin(3.0)
z = jnp.sin(y)
if z > 0: # the value of z is availble and can be used in control flow
return jnp.sin(x)
else:
return jnp.cos(x)
Here's a real-world example from https://github.com/google/jax/issues/3974::
import jax
import jax.numpy as jnp
from jax import random
@jax.jit
def jax_fn(x):
with jax.ensure_compile_time_eval():
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
y2 = y @ y
x2 = jnp.sum(y2) * x
return x2
A similar behavior can often be achieved simply by 'hoisting' the constant
expression out of the corresponding staging API::
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
@jax.jit
def jax_fn(x):
y2 = y @ y
x2 = jnp.sum(y2)*x
return x2
But in some cases it can be more convenient to use this context manager.
"""
with new_base_main(EvalTrace): with new_base_main(EvalTrace):
yield yield
eval_context = ensure_compile_time_eval # alias, backward compatibility
@contextmanager @contextmanager
def new_sublevel() -> Generator[None, None, None]: def new_sublevel() -> Generator[None, None, None]: