Without this fix, the newly added test fails:
```
File "/Users/necula/Source/jax/jax/_src/lax/lax.py", line 2382, in _convert_element_type_jvp_rule
return ad_util.Zero(tangent.aval.update(dtype=dtypes.float0, weak_type=False))
^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'aval'
```
because `tangent` is a float constant.
This comes up in LLM models, where we trace twice (one for eval_shape (usually the init function) and another during jit) when the output jaxpr is the same. This shouldn't happen and we should cache as much as possible.
The only caveat here is that in eval_shape the `traced_for` on `DebugInfo` is set to `jit`. But maybe it's ok to do that if we want to deprecate eval_shape for a AOT style method on `jax.jit` or have it be a thin wrapper around something like `jax.jit(f).eval_shape`
PiperOrigin-RevId: 599602407
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
As part of making JAX's behavior more transparent, it must be clear not only
when code is slow because it's spending all its time missing caches (and hence
retracing/recompiling), but also _why_ it missed those caches. That is, just
knowing (from e.g. setting jax_log_compiles) that code is retracing a lot
doesn't tell the user what to do to fix things. But once the user knows that
the cache misses are due to changing dtypes, or due to jit being passed a new
callable object on every iteration of a loop, it's often clear what to do. And
JAX can provide that information
The main idea here is that pointing out which parts of the cache key differs
from previously-seen keys can constitute a pretty good explanation.
This PR adds an explanation mechanism. It can be enabled in a few different ways:
* setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy;
* setting the config option `jax.config.update('jax_explain_cache_misses', True)`;
* using the context manager `jax._src.config.explain_cache_misses` context
manager (not in public namespace yet);
* when parsing command line flags with absl, using the
`--jax_explain_cache_misses` flag.
Co-authored-by: Yash Katariya <yashkatariya@google.com>
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Ever since the jit-pjit merge, the "Python" jit test has actually just called the same code as the "C++" jit test. We don't have a C++-free jit path any more. Remove the "Python" tests since they don't test anything.
PiperOrigin-RevId: 581965049
The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes GIL to be acquired/released repeatedly.
To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.
PiperOrigin-RevId: 581302290
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.
PiperOrigin-RevId: 570189648
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.
We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.
Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
This assumes less about whether the thread that destructs `CacheEntry` has GIL or not, which is difficult to reason about due to the `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.
The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:
* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.
PiperOrigin-RevId: 569062402