This is deprecated as of https://github.com/google/jax/pull/15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.
PiperOrigin-RevId: 522168275
This change re-introduces symbolic zero support for `custom_vjp`.
This time:
* The forward rule API is slightly different, accepting two-field
records at pytree leaves rather than pairs.
* In the default setting where symbolic_zeros is not set, there are no
new requirements from pytree node definitions that are involved in
the primal arguments. This avoids any change in behavior on the
default path. In particular, custom pytree node definitions that
aren't completely polymorphic in unflattening can remain as is.
* There is an additional test involving a custom pytree node.
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
This is the first step in a revision to how we handle the debug info pertaining
to staged functions' parameter names and result pytree paths. To limit
complexity, this first step adds machinery required to make our MLIR lowerings'
parameter and result names work, but it does *not* yet unify it with existing
arg-name machinery used at tracing time (in partial_eval.py, e.g.
partial_eval.DebugInfo etc). That unification will come in a follow up commits.
(I wrote the unified version first, then broke it down into this sequence of
commits.)
Another thing that will arrive in follow-up commits is pmap support (handling
static_broadcasted_argnames). This PR doesn't include support for pmap because
pmap's final style implementation requires slightly different machinery than
jit/pjit's initial style implementation. Indeed this PR removes the previous
support for pmap arg/result info, and skips the corresponding tests, because
the previous support didn't handle pmap's static_broadcasted_argnums (and I
think it could even lead to silently incorrect annotations when pmap was not at
the top-level, though I didn't work out an example case to be sure that was
possible).
This commit includes the changes from PR #15079, so that PR should be merged first.
Here's the _why_ of this change:
* The pre-existing solution (from PRs #14702, #14764, and #14813) did not
handle static_argnums or static_argnames correctly. Instead it would fail,
resulting in debug info being dropped from the jaxpr and ultimately the MLIR
computation (but no Exception raised). We need to handle
static_argnums/argnames because while the corresponding parameters remain on
the Python callable signature, they are excluded from the args/kwargs
pytrees; the previous solution didn't account for that divergence.
* The best way to handle static_argnums/argnames is to work out this debug info
when we still have the original args/kwargs in hand, i.e. much earlier than
the previous mechanism. We then just have to pass this debug info to the
right places. Indeed we often already had to work out some debug-related
information at these call sites (e.g. whether the function is being staged
out for jit, or scan, or whatever), so after this change we're working out
all the debug info at the same time.
* A side benefit is that now to get this debug info we no longer need to
unflatten user pytree defs with dummy objects (to reconstruct dummy
args/kwargs trees so that we can call inspect.signature(fun).bind), since we
just use the original args/kwargs instead. Since some user pytree node types
are not fully polymorphic in their element types (e.g. their __init__ methods
sometimes contained assertions about their elements' shapes, expecting them
to be arrays), that means the new mechanism is fundamentally more compatible
with custom pytree node types.
More concretely, effecting those high-level changes led to:
* replacing the previous `core.DebugInfo` with a class `core.JaxprDebugInfo`,
which in addition to the more precise name has fields like
`arg_names: Tuple[Optional[str], ...]` and
`result_paths: Tuple[Optional[str], ...]`, rather than
`in_tree: Optional[PyTreeDef]`, reflecting the fact that we work out the
actual debug info more eagerly than before and we don't need pytrees for
dummy-unflattening;
* introducing the new `partial_eval.TracingDebugInfo` class representing the
debug info about inputs which we have available at tracing time; in a
follow-up PR, we'll adapt partial_eval.py to use this new class and we'll
delete `partial_eval.DebugInfo` and its corresponding helper methods (not
done in this commit just to reduce complexity of each change);
* moving the old `core.DebugInfo`, which before #14702 lived in
partial_eval.py, back to partial_eval.py pending cleanup (deletion) of that
partial_eval.py debug info code;
* making specific jaxpr-processing functions produce an appropriately updated
`core.JaxprDebugInfo` object for their output (e.g. `pe.dce_jaxpr` prunes
elements from the `arg_names` field), maintaining now-checked invariants like
a Jaxpr's `debug_info` should have the same number of argument names as the
jaxpr has invars (the jaxpr-processing functions updated here are enough for
top-level jit jaxprs to have debug info attached, handling the original
intended use case of jit(f).lower, but not e.g. grad-of-jit cases, which can
be handled later by updating `ad.jvp_jaxpr` and the like to produce updated
debug info on their outputs);
* add some tests for static_argnums/static_argnames.
Phew! Can't wait to land those follow-ups too :P
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.