Add a wheel build, with pyproject.toml and setup.py.
The directory structure in the jax repo is:
```
jax/
└── plugins/
    └── cuda/
        ├── __init__.py
        ├── pyproject.toml
        └── setup.py
```
Installed package structure is:
```
jax_plugins/
└── xla_cuda_cu12/
    ├── __init__.py
    └── xla_cuda_plugin.so
```
The major CUDA version is part of the package name.
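For illustration, the installed package's `__init__.py` might register the bundled binary roughly as follows. This is a hedged sketch: the `load_pjrt_plugin_dynamically` call and the internal import are assumptions, not the actual implementation.
```
# Hypothetical sketch of jax_plugins/xla_cuda_cu12/__init__.py. JAX discovers
# jax_plugins namespace packages and calls their initialize() at startup.
import os

from jax._src.lib import xla_client  # internal API; subject to change


def initialize():
    # Register the bundled PJRT plugin binary with JAX.
    path = os.path.join(os.path.dirname(__file__), "xla_cuda_plugin.so")
    xla_client.load_pjrt_plugin_dynamically("cuda", path)
```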
The plugin wheel can be built with:
```
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"
```
PiperOrigin-RevId: 565187954
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again, and performance regressions ought to be minimal.
We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in jaxprs, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
To do this, we add a new internal pytree registry that only dispatch respects.
This registry includes all items in the default registry, plus the RNG key
array type.
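For a sense of the leaf behavior outside dispatch, here is a short sketch using only public APIs (not the new internal registry, which is private):
```
import jax

key = jax.random.key(0)  # new-style typed RNG key array
leaves, _ = jax.tree_util.tree_flatten(key)
print(leaves)                    # the key array itself is the only leaf
print(jax.random.key_data(key))  # its underlying uint32 data
```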
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transfers to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.
Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple (`Offloadable`) that a policy can return. Currently policies return a bool: True means save the value, False means recompute it in the backward pass. This change is backwards compatible, i.e. policies can still return a bool.
A very basic offloadable policy can look like this:
```
from jax import ad_checkpoint

def policy(prim, *avals, **params):
    return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
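A policy like this would then be passed to remat in the usual way; a hedged usage sketch, where the rematerialized function is illustrative:
```
import jax

# Activations matched by the policy are offloaded to 'unpinned_host' in the
# forward pass and loaded back to 'tpu_hbm' for the backward pass.
f = jax.checkpoint(lambda x: x * x, policy=policy)
```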
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
The previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014
The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it better to loosen the tolerance but preserve the original
execution mode.
The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
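For context, the decoration mentioned above looks like this; a minimal sketch, with an illustrative function body:
```
import jax
import jax.numpy as jnp

# Forces full float32 matmul precision instead of TF32 on Ampere GPUs.
@jax.default_matmul_precision("float32")
def matmul_f32(a, b):
    return jnp.dot(a, b)
```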
Some flags do not affect the compilation output. These should
not be part of the cache key; otherwise changing them will change
the key, causing an unnecessary cache miss.
Synchronize the exclusions between the command-line flags and
DebugOptions. Add if-this-then-that lint checks to keep them
in sync.
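A hedged sketch of the idea; the flag names and helper are illustrative, not the actual exclusion list:
```
# Flags such as dump paths change side effects but not the compiled output,
# so they are stripped before the remaining flags are hashed into the key.
NON_CACHE_AFFECTING_FLAGS = {"xla_dump_to", "xla_dump_hlo_as_text"}

def flags_for_cache_key(flags):
    return {k: v for k, v in sorted(flags.items())
            if k not in NON_CACHE_AFFECTING_FLAGS}
```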
PiperOrigin-RevId: 564474189
The original cache key generation hashes the devices and backend. This
is not future-proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription, which is supported for all accelerators.
Note:
- CPU and the PjRt C API are not supported yet.
- Stream Executor will not be supported.
Testing: revised unit test.
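Conceptually, the revised key hashes the serialized topology; a rough sketch, where the `serialize()` accessor is an assumption about the binding rather than the actual code:
```
import hashlib

def topology_fingerprint(topology):
    # Hash the serialized PjRtTopologyDescription instead of a
    # device/backend repr, so any accelerator with a topology works.
    return hashlib.sha256(topology.serialize()).hexdigest()
```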
PiperOrigin-RevId: 564461564
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows existing code to work without any changes.
If a non-default memory kind is present in the jaxpr, then we allow the custom calls to be inserted.
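For illustration, only the second sharding below carries a non-default memory kind and would cause the custom calls to be inserted; a hedged sketch:
```
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), ("x",))
s_default = NamedSharding(mesh, PartitionSpec("x"))  # default memory kind
s_host = NamedSharding(mesh, PartitionSpec("x"), memory_kind="unpinned_host")
```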
PiperOrigin-RevId: 564457393