- Actually de-allocate buffers after a pl.run_scoped.
- Periodically run an explicit garbage collection after
de-allocating buffers.
- Add no-op implementations for a few internal/testing mosaic primitives
(prng_seed_p, prng_random_bits_p, assume_p, random_p).
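A minimal sketch of the "free, then periodically collect" pattern from the first two bullets. `BufferPool`, `gc_every`, and the counters are hypothetical names used only for illustration; they are not the interpreter's actual data structures.

```python
import gc


class BufferPool:
    """Hypothetical stand-in for an interpreter's buffer store."""

    def __init__(self, gc_every=3):
        self.buffers = {}
        self.gc_every = gc_every      # assumed tuning knob, not a real flag
        self.frees_since_gc = 0
        self.gc_runs = 0

    def allocate(self, key, nbytes):
        self.buffers[key] = bytearray(nbytes)

    def deallocate(self, key):
        # Actually drop the buffer instead of leaving it registered.
        del self.buffers[key]
        self.frees_since_gc += 1
        # Every `gc_every` frees, force a collection so memory is reclaimed
        # promptly rather than whenever the collector next happens to run.
        if self.frees_since_gc >= self.gc_every:
            gc.collect()
            self.gc_runs += 1
            self.frees_since_gc = 0


pool = BufferPool(gc_every=3)
for i in range(7):
    pool.allocate(i, 1024)
for i in range(7):
    pool.deallocate(i)
```

Collecting on every free would be simpler but slower; batching the collection is one way to trade a little memory for less overhead.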
It leads to much more compact kernel definitions; just look at the diff!
The combination of `pl.core_map` and `pl.run_state` is too noisy to easily
follow the kernel logic.
PiperOrigin-RevId: 740479934
OSS JAX builds for GPU backends split `jaxlib` into three wheels. Since we cannot expect a stable C++ ABI across the shared libraries, we refactor to ensure that:
1. C++ objects are not created/consumed by different shared libraries.
2. Static objects are declared and defined appropriately.
This PR:
1. Migrates JAX XLA FFI callback handlers from XLA's Internal FFI API to the [External FFI API](https://github.com/openxla/xla/tree/main/xla/ffi#xla-ffi-external-vs-internal-apis). Note that we update both CPU and GPU handlers because we cannot mix Internal and External APIs.
2. Updates how FFI GPU handlers are registered, now analogous to how the original GPU custom call was registered.
3. Adds an `xla::ffi::ExecutionContext` member to `ifrt::PjRtLoadedExecutable` holding opaque pointers to callbacks.
4. Updates JAX `callback.py` to call the new FFI callback handlers.
PiperOrigin-RevId: 740327296
Before this change, we handled attrs for initial-style primitives like jit/scan
like this:
1. the traceable would form a jaxpr and see what attrs were touched (by
jax_getattr or jax_setattr),
2. for each such attr, the traceable would do jax_getattr to get the current
value, tree-flatten, pass the flat values into the (pure) bind, get the new
values out, tree-unflatten, then jax_setattr the result.
That approach would error if the function called `jax_setattr` to set a
previously nonexistent attr. That is, this would work:
```python
from jax.experimental.attrs import jax_setattr
class Thing: ...
thing = Thing()
jax_setattr(thing, 'x', 1.0)
```
but it wouldn't work under a `jax.jit`.
This commit makes the same code work under a jit. We just
1. in partial_eval.py's `to_jaxpr`, ensure attrs added during jaxpr formation
are deleted, using a special sentinel value `dne_sentinel` to indicate the
attribute initially did not exist before tracing;
2. in pjit.py's `_get_states`, when reading initial attr values before the
pjit_p bind, if the attribute does not exist we don't try to read it and
instead just use `dne_sentinel` as the value, which is a convenient empty
pytree;
3. in pjit.py's `_attr_token` for jit caching, when forming the cache key based
on the current attr states, we map attrs that don't exist to `dne_sentinel`
(rather than just erroring when the attr doesn't exist, as before).
In short, we use a special value to indicate "does not exist".
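The sentinel idea can be sketched in plain Python, without any tracing machinery. Everything here is a simplified assumption: `_DoesNotExist`, `read_attr`, `restore_attr`, and `trace` are hypothetical helpers, and this `dne_sentinel` is an ordinary object rather than the empty pytree used in the actual change.

```python
class _DoesNotExist:
    """Hypothetical sentinel type standing in for `dne_sentinel`."""

    def __repr__(self):
        return "dne_sentinel"


dne_sentinel = _DoesNotExist()


def read_attr(obj, name):
    # Map a missing attribute to the sentinel instead of raising.
    return getattr(obj, name) if hasattr(obj, name) else dne_sentinel


def restore_attr(obj, name, initial):
    # Undo writes made while "tracing": an attr that did not exist
    # beforehand is deleted, anything else is reset to its old value.
    if initial is dne_sentinel:
        if hasattr(obj, name):
            delattr(obj, name)
    else:
        setattr(obj, name, initial)


def trace(fn, obj, names):
    # Record initial states, run fn, record final states, then restore.
    initial = {n: read_attr(obj, n) for n in names}
    try:
        fn()
        final = {n: read_attr(obj, n) for n in names}
    finally:
        for n, v in initial.items():
            restore_attr(obj, n, v)
    return initial, final


class Thing: ...


thing = Thing()
initial, final = trace(lambda: setattr(thing, "x", 1.0), thing, ["x"])
```

After `trace` returns, `thing` has no `x` attribute again, while `initial` records that the attribute did not exist and `final` records the value that was set.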
If `jax_getattr` supported the 'default' argument, the code would be a little
cleaner since we could avoid the `if hasattr` stuff. And that's probably a
useful feature to have anyway. We can add that in a follow-up.
This PR only makes setattr-to-nonexistent-attr work with jit. We'll add scan
etc. in follow-ups.
Kidger's reimplementation of this code notes that the break mode and indent are unused in the _fits function (851379b8f5/wadler_lindig/_wadler_lindig.py (L166)).
We can make the same optimization here.
Without these implementations, `ndim` and `size` were retrieved from the underlying, non-transformed reference and were inconsistent with `TransformedRef.shape`.
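To illustrate the inconsistency, here is a minimal sketch in which `ndim` and `size` are derived from the transformed `shape` rather than from the underlying reference. `SlicedRefSketch`, `base_shape`, and `slices` are hypothetical names and do not mirror the real `TransformedRef` class.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SlicedRefSketch:
    """Hypothetical transformed reference: a static slice of a base ref."""

    base_shape: tuple
    slices: tuple  # one (start, stop) pair per dimension

    @property
    def shape(self):
        # Shape after applying the transform.
        return tuple(stop - start for start, stop in self.slices)

    @property
    def ndim(self):
        # Derived from the transformed shape, not from base_shape.
        return len(self.shape)

    @property
    def size(self):
        # Derived from the transformed shape, not from base_shape.
        return math.prod(self.shape)


ref = SlicedRefSketch(base_shape=(8, 128), slices=((0, 4), (0, 64)))
```

Here `ref.size` agrees with `ref.shape`; reading the size from `base_shape` instead would report the element count of the untransformed reference.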
PiperOrigin-RevId: 739802491
1. At the lax level, before we bind the primitive, we need to insert pbroadcasts if some inputs are varying. This is equivalent to the rewrite rules that shard_map has.
2. In abstract_eval rules of primitives, we need to check if all inputs are varying across the same mesh axes and then add the `varying_manual_axes` to the output ShapedArray.
This in turn requires us to support `pbroadcast2` and `psum2` primitives in shard_map.py. These primitives don't need to insert any pbroadcasts (equivalent to `no_rewrite` in shard_map) but need to do checks and update the output aval in their abstract_eval rules.
* pbroadcast_p: The output vma is the union of the existing `aval.varying_manual_axes` and `axes` (passed to pbroadcast). As a check, the intersection of `aval.varying_manual_axes` and `axes` must be empty.
* psum2_p: The output vma is `aval.varying_manual_axes` with the named axes removed. As a check, the intersection of `aval.varying_manual_axes` and `axes` must NOT be empty.
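The rules above can be sketched with plain frozensets standing in for `varying_manual_axes`. The function names below are hypothetical and only model the set arithmetic and checks described here; they are not the actual abstract_eval rules.

```python
def pbroadcast_vma(vma, axes):
    # Output vma is the union; the input must not already vary over `axes`.
    axes = frozenset(axes)
    if vma & axes:
        raise ValueError(f"already varying over {sorted(vma & axes)}")
    return vma | axes


def psum2_vma(vma, axes):
    # Output vma drops the summed axes; the input must vary over them.
    axes = frozenset(axes)
    if not (vma & axes):
        raise ValueError(f"input does not vary over {sorted(axes)}")
    return vma - axes


def insert_pbroadcasts(*input_vmas):
    # Before binding: bring every input up to the union of all input vmas.
    target = frozenset().union(*input_vmas)
    return [pbroadcast_vma(v, target - v) for v in input_vmas]


def standard_vma_rule(*input_vmas):
    # In abstract_eval: all inputs must vary over the same mesh axes.
    first = input_vmas[0]
    if any(v != first for v in input_vmas[1:]):
        raise ValueError("inputs vary over different mesh axes")
    return first


x_vma, y_vma = frozenset({"i"}), frozenset()
aligned = insert_pbroadcasts(x_vma, y_vma)
out_vma = standard_vma_rule(*aligned)
summed_vma = psum2_vma(frozenset({"i", "j"}), ("i",))
```

In this sketch the second input is broadcast over `"i"` so that both inputs agree before the standard rule runs, and summing over `"i"` leaves only `"j"` in the output vma.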
The majority of primitives should use `standard_insert_pbroadcast` and `standard_vma_rule`; I'll add those to the other primitives in follow-up CLs.
PiperOrigin-RevId: 739225392