This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. The resulting slice can be used in a view that is passed to a DMA. For example:
```python
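from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  # The copy length is a runtime value read from a scalar ref in SMEM.
  size = size_smem_ref[0]
  pltpu.async_copy(
      x_hbm_ref.at[pl.ds(0, size)],
      o_hbm_ref.at[pl.ds(0, size)], sem).wait()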
```
We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the `Slice` class to support dynamic sizes, which are then plumbed down into indexing primitives like `pl.load` and `pltpu.async_copy`.
However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically sized DMAs between statically shaped Refs. While this doesn't amount to arbitrary dynamic shapes, we hope it enables some interesting new kernels.
PiperOrigin-RevId: 618322737
1. move MutableArray to core.py, and some handlers to their respective files
2. fix a bug in aliasing setup (it was just broken before, now better test coverage)
3. add eager support by enabling get_p, swap_p, and addupdate_p impls
4. improve tests slightly
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
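The four steps above can be illustrated with a toy model in plain Python. This is only a sketch of the idea, not JAX's implementation: `ToyMutableArray`, `discharge`, and `dispatch` are hypothetical names, and the real `discharge_state` operates on jaxprs rather than on Python callables.

```python
from typing import Callable


class ToyMutableArray:
    """Hypothetical stand-in for a mutable array: a box holding a list."""

    def __init__(self, value):
        self.value = list(value)


def stateful_kernel(ref: ToyMutableArray, x: int) -> int:
    # Reads and writes the ref in place, and also returns a regular output.
    ref.value[0] += x
    return sum(ref.value)


def discharge(fn: Callable) -> Callable:
    """Turn fn(ref, *args) -> out into pure_fn(value, *args) -> (out, new_value)."""

    def pure_fn(value, *args):
        scratch = ToyMutableArray(value)  # private copy; caller's data is untouched
        out = fn(scratch, *args)
        return out, scratch.value  # final ref contents become an extra output

    return pure_fn


def dispatch(fn: Callable, ref: ToyMutableArray, *args):
    """Run the discharged function, write back the extra output, then drop it."""
    out, new_value = discharge(fn)(ref.value, *args)
    ref.value = new_value  # plays the role of input/output aliasing
    return out  # the caller only sees the original outputs


ref = ToyMutableArray([1, 2, 3])
result = dispatch(stateful_kernel, ref, 10)
```

In this toy version, the write-back in `dispatch` stands in for the aliasing set up in `lower_sharding_computation`, and returning only `out` stands in for the output side of the dispatch path dropping the extra outputs.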
This simplifies a lot of the Pallas tracing and lowering logic because memory spaces are passed through the Ref type instead of through the BlockMapping.
PiperOrigin-RevId: 599670626
Some notes about this change:
* This change upgrades the `RefView` abstraction to store multiple indexers.
This allows doing things like `ref.at[0].at[0]` to recursively create a view
of a `Ref`. `RefView`s therefore encapsulate multiple `NDIndexer`s.
* This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p)
but does *not* generalize their rules. Most of the rules will raise a
NotImplementedError if you use multiple `NDIndexer`s. Adding support will be
done in a future CL.
* With the above in mind, this change only preserves existing public-facing APIs;
adding actual support will involve updating the rules.
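To make the idea of a view that stores several indexers concrete, here is a minimal sketch using NumPy. `ToyRefView` and `_At` are hypothetical names invented for illustration; they are not the Pallas/state classes, and plain ints and slices stand in for `NDIndexer`s.

```python
import numpy as np


class _At:
    """Helper so that `view.at[idx]` returns a new, deeper view."""

    def __init__(self, view):
        self._view = view

    def __getitem__(self, idx):
        # Append the new indexer instead of applying it eagerly.
        return ToyRefView(self._view.buffer, self._view.indexers + (idx,))


class ToyRefView:
    """Hypothetical view over a buffer that records a chain of indexers."""

    def __init__(self, buffer, indexers=()):
        self.buffer = buffer
        self.indexers = tuple(indexers)

    @property
    def at(self):
        return _At(self)

    def get(self):
        out = self.buffer
        for idx in self.indexers:
            out = out[idx]
        return out

    def set(self, value):
        if not self.indexers:
            self.buffer[...] = value
            return
        target = self.buffer
        for idx in self.indexers[:-1]:
            target = target[idx]  # int/slice indexing returns a NumPy view
        target[self.indexers[-1]] = value  # writes through to the buffer


buf = np.arange(24).reshape(2, 3, 4)
view = ToyRefView(buf).at[1].at[0]  # two stored indexers, like ref.at[0].at[0]
view.set(0)
```

The point of the sketch is that `.at[...]` only records indexers; they are applied when the view is finally read or written, which is why rules that assume a single indexer need updating separately.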
PiperOrigin-RevId: 595229523
This PR is a follow up to #18881.
The changes were generated by adding

    from __future__ import annotations

to the files which did not already have it and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require an understanding of raggedness
and dynamic-shapes internals to fix properly.
This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transferring to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.
Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple, `Offloadable`, that each policy can return. Currently policies return a bool: True means saveable, False means recompute on the backward pass. This change is backwards compatible, i.e. policies can still return a bool.
A very basic offloadable policy can look like this:
```
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
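The backwards-compatibility point can be sketched in plain Python. Everything below is a toy stand-in: the `Recompute`, `Saveable`, and `Offloadable` objects here are local imitations rather than the JAX classes, `normalize` is a hypothetical helper, and how JAX converts legacy bool results internally is an assumption.

```python
from typing import NamedTuple


class _Marker:
    """Tiny singleton-style marker with a readable repr."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


# Toy stand-ins for the two singletons and the NamedTuple described above.
Recompute = _Marker("Recompute")
Saveable = _Marker("Saveable")


class Offloadable(NamedTuple):
    src: str
    dst: str


def normalize(decision):
    """Map a legacy bool result onto the new return values (assumed behavior)."""
    if decision is True:
        return Saveable
    if decision is False:
        return Recompute
    return decision  # already Recompute, Saveable, or an Offloadable


def legacy_policy(prim, *avals, **params):
    # Old-style policy: save only the results of the primitive named "dot".
    return prim == "dot"


def offload_policy(prim, *avals, **params):
    # New-style policy: offload "dot" results, recompute everything else.
    if prim == "dot":
        return Offloadable(src="tpu_hbm", dst="unpinned_host")
    return Recompute


decisions = [normalize(p("dot")) for p in (legacy_policy, offload_policy)]
```

With this shape, code consuming policy results only has to handle three cases, regardless of whether the policy was written before or after the change.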
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...
PiperOrigin-RevId: 529385973
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397