Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.
We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.
Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.
A very basic offloadable policy can look like this:
```
def policy(prim, *avals, **params):
return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014
The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it's better to loosen the tolerance but preserve the original
execution mode.
The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.
Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.
Testing: revised unit test.
PiperOrigin-RevId: 564461564
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.
If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.
PiperOrigin-RevId: 564457393
Those modules have been developed initially for jax2tf
but they do not depend on TF anymore. They are used for JAX
native serialization. We move them under
jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF.
We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later.
PiperOrigin-RevId: 562988740
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.
TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).
The net effect of this change is mostly to tighten up many test tolerances on TPU.
PiperOrigin-RevId: 562953488
Removes callback testing function and uses io_callback
and pure_callback instead. This allows us to remove
some tests from the PureCallbackTest class.
Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest.
PiperOrigin-RevId: 562285255
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!
The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!
The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa
PiperOrigin-RevId: 561805840
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.
Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.
Testing: revised unit test.
PiperOrigin-RevId: 561803875
Metrics:
1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true.
2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation.
3) '/jax/compilation_cache/cache_misses' to track the number of times cache is missed and the compiled executable is written to cache repository.
Created a context manager to register/unregister event listeners.
PiperOrigin-RevId: 561771262