mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Copybara import of the project:
-- 4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>: add jax.ensure_compile_time_eval to public api aka jax.core.eval_context COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311 PiperOrigin-RevId: 420928687
This commit is contained in:
parent
7bc51879d4
commit
1cf7d4ab5d
@ -32,6 +32,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
|
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
|
||||||
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
|
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
|
||||||
file under the given path.
|
file under the given path.
|
||||||
|
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
|
||||||
|
|
||||||
## jaxlib 0.1.76 (Unreleased)
|
## jaxlib 0.1.76 (Unreleased)
|
||||||
* New features
|
* New features
|
||||||
|
@ -38,6 +38,7 @@ Just-in-time compilation (:code:`jit`)
|
|||||||
|
|
||||||
jit
|
jit
|
||||||
disable_jit
|
disable_jit
|
||||||
|
ensure_compile_time_eval
|
||||||
xla_computation
|
xla_computation
|
||||||
make_jaxpr
|
make_jaxpr
|
||||||
eval_shape
|
eval_shape
|
||||||
|
@ -50,6 +50,7 @@ from jax._src.config import (
|
|||||||
default_prng_impl as default_prng_impl,
|
default_prng_impl as default_prng_impl,
|
||||||
numpy_rank_promotion as numpy_rank_promotion,
|
numpy_rank_promotion as numpy_rank_promotion,
|
||||||
)
|
)
|
||||||
|
from .core import eval_context as ensure_compile_time_eval
|
||||||
from jax._src.api import (
|
from jax._src.api import (
|
||||||
ad, # TODO(phawkins): update users to avoid this.
|
ad, # TODO(phawkins): update users to avoid this.
|
||||||
block_until_ready,
|
block_until_ready,
|
||||||
|
60
jax/core.py
60
jax/core.py
@ -814,9 +814,67 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
|
|||||||
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
|
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def eval_context():
|
def ensure_compile_time_eval():
|
||||||
|
"""Context manager to ensure evaluation at trace/compile time (or error).
|
||||||
|
|
||||||
|
Some JAX APIs like ``jax.jit`` and ``jax.lax.scan`` involve staging, i.e.
|
||||||
|
delaying the evaluation of numerical expressions (like jax.numpy function
|
||||||
|
applications) so that instead of performing those computations eagerly while
|
||||||
|
evaluating the corresponding Python expressions, their computation is carried
|
||||||
|
out separately, e.g. after optimized compilation. But this delay can be
|
||||||
|
undesirable. For example, numerical values might be needed to evaluate Python
|
||||||
|
control flow and so their evaluation cannot be delayed. As another example, it
|
||||||
|
may be beneficial to ensure compile time evaluation (or "constant folding")
|
||||||
|
for performance reasons.
|
||||||
|
|
||||||
|
This context manager ensures that JAX computations are evaluated eagerly. If
|
||||||
|
eager evaluation is not possible, a ``ConcretizationError`` is raised.
|
||||||
|
|
||||||
|
Here's a contrived example::
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def f(x):
|
||||||
|
with jax.ensure_compile_time_eval():
|
||||||
|
y = jnp.sin(3.0)
|
||||||
|
z = jnp.sin(y)
|
||||||
|
if z > 0: # the value of z is availble and can be used in control flow
|
||||||
|
return jnp.sin(x)
|
||||||
|
else:
|
||||||
|
return jnp.cos(x)
|
||||||
|
|
||||||
|
Here's a real-world example from https://github.com/google/jax/issues/3974::
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import random
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def jax_fn(x):
|
||||||
|
with jax.ensure_compile_time_eval():
|
||||||
|
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||||
|
y2 = y @ y
|
||||||
|
x2 = jnp.sum(y2) * x
|
||||||
|
return x2
|
||||||
|
|
||||||
|
A similar behavior can often be achieved simply by 'hoisting' the constant
|
||||||
|
expression out of the corresponding staging API::
|
||||||
|
|
||||||
|
y = random.randint(random.PRNGKey(0), (1000,1000), 0, 100)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def jax_fn(x):
|
||||||
|
y2 = y @ y
|
||||||
|
x2 = jnp.sum(y2)*x
|
||||||
|
return x2
|
||||||
|
|
||||||
|
But in some cases it can be more convenient to use this context manager.
|
||||||
|
"""
|
||||||
with new_base_main(EvalTrace):
|
with new_base_main(EvalTrace):
|
||||||
yield
|
yield
|
||||||
|
eval_context = ensure_compile_time_eval # alias, backward compatibility
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def new_sublevel() -> Generator[None, None, None]:
|
def new_sublevel() -> Generator[None, None, None]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user