This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
3. Add `to_tangent_type` calls in various other places they're missing.
4. Remove non-support for float0 in custom deriviatives?
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
This allows us to get more cache hits globally. For example:
Before:
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache miss
After:
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache hit
Reverts b615266175effe4aefeb903620a19f3719a604da
PiperOrigin-RevId: 675746175
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.
PiperOrigin-RevId: 674324964
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss.
Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
This allows us to get more cache hits globally. For example:
Before:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache miss
```
After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache hit
```
Also, we can remove the hack (which I didn't like) in multihost_utils.py.
PiperOrigin-RevId: 665574475
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.
In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.
```
name old cpu/op new cpu/op delta
jit_add_chain 59.1ms ±14% 49.4ms ±10% -16.32% (p=0.008 n=5+5)
name old time/op new time/op delta
jit_add_chain 60.3ms ±14% 50.7ms ±11% -15.99% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
The motivation for doing this is 2-fold:
1) This will help with deprecating and eventually deleting `jax.xla_computation` which allows for cross backend lowering.
2) Allow for cross-backend and multi-backend lowering via jax AOT APIs which will help cleanup some hacks implemented for `jax.export`.
Note that this is only available by `.trace.lower(lowering_platforms=('tpu',))`. You cannot use `.lower` to do cross-lowering. We can introduce top-level APIs in the future to allow for composable aot apis to make this easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.
Designed with @froystig!
PiperOrigin-RevId: 644087787
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.
The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
PiperOrigin-RevId: 644051624
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.
PiperOrigin-RevId: 643097852
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't.
Since we can keep the existing behavior and still merge the implementation is a good cleanup!
Fixes https://github.com/google/jax/issues/21116
PiperOrigin-RevId: 641347140
These changes are necessary to ensure that `Lowered` carries all the
information that is needed for export and serialization.
These are in preparation of a cleanup of the exporting and serialization APIs
to integrate them with the AOT APIs. In particular, exporting will start
with a `Lowered` object and will not include anymore its own lowering code.
We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`)
to `Lowered`,
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).
The function name is useful for better error messages when exporting and
serializating. The Jaxpr is useful for exporting also the VJP of the function
and obtaining an `Exported` that can be differentiated.