Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.
PiperOrigin-RevId: 395423006
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.
PiperOrigin-RevId: 391272270
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
When checking the data type of the dynamic arguments in jax.value_and_grad the
PyTree is unflattened with `None` (the output of `_check_input_dtype_grad`) as
value for each leaf. This causes an issue if a custom PyTree does not accept
None as a value for the leaves (issue #7546) even though the tree that is
returned from the data type check is never used.
This commit solves this issue by iterating over tree_leaves when checking data
types rather than using tree_map.
This **will** be a **breaking** change, as pxla.ShardedDeviceArray constructor won't be valid anymore:
- for the next Jax release
- on the condition _USE_EXPERIMENTAL_CPP_SDA is switch to `_xla_extension_version > xx` and with the associated jaxlib release.
I am already adding the impact for the users in the CHANGELOG, we can still move it to the next version depending on when it's shipped.
Similarly to JAX.jit, for which we have a C++ `DeviceArray` and a Python `_DeviceArray`, we will introduce 2 objects for ShardedDeviceArray, with the Python object only for JAX extensions not compatible with the C++ object (e.g. Cloud TPU).
- Add `make_sharded_device_array` to be used within JAX and for hackers that need to construct SDA objects.
- Make sure the C++ object is valid by
(a) extending `DeviceArrayBase` (done in Python), as it brings a bunch of methods and enable `isinstance(x, DeviceArray)`
(b) Adding the same methods as the Python SDA.
NOTE: mypy has troubled with the " -> pxla.ShardedDeviceArray` function return type annotation, I had to remove 2.
PiperOrigin-RevId: 389876734
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.
In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.
If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.
Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
- reductions aren't fused into any first-order primitives (e.g. a `pdot`
should have a named contracting axis added rather than being followed by a
`psum`; this can be implemented by putting these primitives into
`reducing_transposes`)
- reductions are performed eagerly, even over axes that are mapped to
hardware resources (the optimal thing to do would be to reduce eagerly
over any vectorized axis component while delaying the reduction over any
hardware-mapped component until the end of the overall backward pass; this
would require a way to represent these partially-reduced values)
PiperOrigin-RevId: 383685336
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:
add 'inline' option to xla_call for jaxpr inlining
--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:
add 'inline' to jit docstring
--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:
new_sublevel in jax2tf
PiperOrigin-RevId: 371542778
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
Adds a new CompiledFunctionCache object that can be passed to the CompiledFunction constructor. Multiple CompiledFunctions can share the same cache capacity.
This change is in preparation for adding jit decorators to many standard library functions. We do not want to drastically increase the number of cached computations, and having a cache shared between functions allows us to avoid this.
Also allow cache entries to persist past the lifetime of the enclosing jax.jit(f) call so long as `f` remains alive. This mirrors the behavior of the existing linear_util cache that JAX uses in Python.
PiperOrigin-RevId: 368664536