Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.
We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.
To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.
Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.
A very basic offloadable policy can look like this:
```
def policy(prim, *avals, **params):
return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
This flag only exists for the use of JAX's own tests, and doesn't need to exist for most JAX users.
4f805c2d8f allows this move since none of the test comparison utilities now depend on the choice of backend. (That dependency was only an administrative dependency for external users of JAX, since the only public users of the test comparison utilites are the gradient utilities, which always override the default tolerance with their own tolerances.)
PiperOrigin-RevId: 563195002
Sanitizes the name of tests so that a name matches the rules of an
identifier for pytest -k and unittest -k test filtering. Sequences
of problematic characters are replaced with a single "_".
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:
Rename opaque dtype to extended dtype.
This includes three deprecations:
- jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
- jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
- the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.
As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
JAX shape polymorphism relies on implicit assumptions.
For example, when tracing with input specification `(a, a)`,
we assume that the first two dimensions have the same size
greater or equal to 1.
Here we extend the checking that these assumptions hold. When
we call an `Exported` module from jax, with `jax_export.call_exported`
we check these assumptions statically. However, when we
stage an `Exported` using `XlaCallModule` to be called from
TensorFlow, or when we use TF graph serialization we need
to check these assumptions when we execute and compile
the op (that is when the shapes are available).
To prepare for this compile-time shape checking we add
`Exported.shape_check_module` to produce a serialized
MLIR module containing the shape checking code. This
will be added in a future change to `XlaCallModule`.
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).
TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more
places.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670