mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Copybara import of the project:
-- 4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>: add jax.ensure_compile_time_eval to public api aka jax.core.eval_context COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311 PiperOrigin-RevId: 420928687
This commit is contained in:
parent
7bc51879d4
commit
1cf7d4ab5d
@ -32,6 +32,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
|
||||
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
|
||||
file under the given path.
|
||||
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
|
||||
|
||||
## jaxlib 0.1.76 (Unreleased)
|
||||
* New features
|
||||
|
@ -38,6 +38,7 @@ Just-in-time compilation (:code:`jit`)
|
||||
|
||||
jit
|
||||
disable_jit
|
||||
ensure_compile_time_eval
|
||||
xla_computation
|
||||
make_jaxpr
|
||||
eval_shape
|
||||
|
@ -50,6 +50,7 @@ from jax._src.config import (
|
||||
default_prng_impl as default_prng_impl,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
)
|
||||
from .core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
block_until_ready,
|
||||
|
60
jax/core.py
60
jax/core.py
@ -814,9 +814,67 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
||||
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
def ensure_compile_time_eval():
|
||||
"""Context manager to ensure evaluation at trace/compile time (or error).
|
||||
|
||||
Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
|
||||
delaying the evaluation of numerical expressions (like jax.numpy function
|
||||
applications) so that instead of performing those computations eagerly while
|
||||
evaluating the corresponding Python expressions, their computation is carried
|
||||
out separately, e.g. after optimized compilation. But this delay can be
|
||||
undesirable. For example, numerical values might be needed to evaluate Python
|
||||
control flow and so their evaluation cannot be delayed. As another example, it
|
||||
may be beneficial to ensure compile time evaluation (or "constant folding")
|
||||
for performance reasons.
|
||||
|
||||
This context manager ensures that JAX computations are evaluated eagerly. If
|
||||
eager evaluation is not possible, a ``ConcretizationError`` is raised.
|
||||
|
||||
Here's a contrived example::
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
with jax.ensure_compile_time_eval():
|
||||
y = jnp.sin(3.0)
|
||||
z = jnp.sin(y)
|
||||
if z > 0: # the value of z is availble and can be used in control flow
|
||||
return jnp.sin(x)
|
||||
else:
|
||||
return jnp.cos(x)
|
||||
|
||||
Here's a real-world example from https://github.com/google/jax/issues/3974::
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
with jax.ensure_compile_time_eval():
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
y2 = y @ y
|
||||
x2 = jnp.sum(y2) * x
|
||||
return x2
|
||||
|
||||
A similar behavior can often be achieved simply by 'hoisting' the constant
|
||||
expression out of the corresponding staging API::
|
||||
|
||||
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||
|
||||
@jax.jit
|
||||
def jax_fn(x):
|
||||
y2 = y @ y
|
||||
x2 = jnp.sum(y2)*x
|
||||
return x2
|
||||
|
||||
But in some cases it can be more convenient to use this context manager.
|
||||
"""
|
||||
with new_base_main(EvalTrace):
|
||||
yield
|
||||
eval_context = ensure_compile_time_eval # alias, backward compatibility
|
||||
|
||||
@contextmanager
|
||||
def new_sublevel() -> Generator[None, None, None]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user