This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet, and donation also doesn't work in PJRT.
This CL also sets up the plumbing from JAX to PJRT so that support for the missing features can be added easily in the future.
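For context, a minimal sketch of the supported case, assuming a backend that exposes a `pinned_host` memory kind (the sharding setup here is illustrative only):
```
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
# Illustrative: a single-device sharding placed in pinned host memory.
s_host = jax.sharding.SingleDeviceSharding(dev, memory_kind='pinned_host')

x = jax.device_put(jnp.ones(8), s_host)
# pinned_host -> pinned_host copy on the same device: the supported case.
y = jax.device_put(x, s_host)
```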
Fixes https://github.com/jax-ml/jax/issues/24521
PiperOrigin-RevId: 694274616
A noticeable amount of time during JAX tracing is spent getting and setting the value of `config.State` objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.
There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* We can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We currently spend a considerable amount of time eagerly computing cache keys via `update_thread_local_jit_state`, most of which is wasted work. Instead, we can have `jit` simply pull the config items it needs on demand.
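For illustration, a rough sketch of the Python-side pattern being replaced (not JAX's exact internals):
```
import threading

class _ThreadLocalConfig(threading.local):
    def __init__(self):
        self.values = {}  # config name -> per-thread override

_tls = _ThreadLocalConfig()
_global_defaults = {'jax_enable_x64': False}

def get_value(name):
    # Every read pays a thread-local attribute lookup plus a dict lookup;
    # doing this in C++ avoids that overhead on hot tracing paths.
    return _tls.values.get(name, _global_defaults[name])
```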
PiperOrigin-RevId: 693114411
The previous version of this change did not account for the fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as TensorFlow tensors. This version adds new dtype handling aimed at being robust to this case.
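A sketch of the case in question (assuming TensorFlow is installed; illustrative only):
```
import tensorflow as tf
import jax

t = tf.constant([1.0, 2.0, 3.0])
# `t.dtype` is a tf.DType, not a NumPy dtype; `device_get` must not assume
# the `dtype` attribute is NumPy-compatible.
out = jax.device_get(t)
```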
Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d
PiperOrigin-RevId: 691568933
For some reason, `extra_jit_context` was leaking when `pallas.core` no longer imported `pallas.pallas_call`, leading to leaked XLA clients.
PiperOrigin-RevId: 689857071
The end state we want to work towards is to remove `may_alias` and **always copy by default**. But there is some work needed to get to that state.
**Definitions:**
* donate: The input buffer will be marked as deleted (see below for some caveats). The output buffer may or may not reuse the input buffer's underlying memory.
* may_alias: If True, we may return the original buffer depending on the implementation.
**What problem are we solving?**
Eventually, we want `device_put` to always copy, so we are introducing `may_alias` as a transitional state to help reach that goal. We might end up deciding to keep `may_alias`, but now there is an explicit option to **always copy**, i.e. setting `may_alias=False`, which is what some users want.
Adding `donate` allows users to avoid this pattern of code:
```
inp = ...  # some existing pytree of arrays
out = jax.device_put(inp, sharding)
jax.block_until_ready(out)               # wait for the transfer to finish
jax.tree.map(lambda x: x.delete(), inp)  # then manually free the inputs
```
Now it can just be: `jax.device_put(inp, sharding, donate=True)`
**So what are the semantics of these two options?** Let's create a table:
| `may_alias` (default None) | `donate` (default False) | Result | Input marked as deleted? | Output may reuse input buffer? |
| :---- | :---- | :---- | :---- | :---- |
| True | True | Error | — | — |
| True | False | May return the original buffer | No | Maybe |
| False | True | Donation: the original buffer is deleted | Yes | Maybe |
| False | False | Pure copy | No | No |
| None | True | Treated as `may_alias=False`; see the (False, True) row | Yes | Maybe |
| None | False | Treated as `may_alias=True`; see the (True, False) row | No | Maybe |
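For concreteness, a minimal sketch of these semantics in use, assuming the keyword arguments behave as tabulated above (the sharding setup is illustrative):
```
import jax
import jax.numpy as jnp

s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
x = jnp.arange(8)

# Pure copy: never aliases the input, and the input stays alive.
y = jax.device_put(x, s, may_alias=False, donate=False)

# Donation: the input is marked deleted; the output may reuse its memory.
z = jax.device_put(x, s, donate=True)
```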
`donate` is best-effort for now until we fix the following:
* Delete the input when `donate=True` regardless of whether XLA could actually donate. This will affect `jax.jit` too, but it's a good thing to do.
* Plumb `donate` through to the PJRT/IFRT APIs so we can donate even where transfers are not happening via `jit`.
PiperOrigin-RevId: 681073828
This is part of the "stackless" change (#23299). I'm splitting it out into a separate PR because we need it for some work on sharding types.
Changes:
1. Rename `at_least_vspace` to `to_tangent_type`, since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!), but it makes even less sense when you can have a special tangent type for a primal type that's already a vector space itself.
2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion (see the sketch after this list).
3. Add `to_tangent_type` calls in various other places where they're missing.
4. Remove non-support for float0 in custom derivatives?
5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
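A hedged sketch of the renamed API from items 1 and 2 (names as listed above; the exact import path is internal and may differ across versions):
```
from jax.interpreters import ad

# `Zero.from_primal_value` replaces `Zero.from_value`: it maps the primal
# value's aval to its tangent type (what `to_tangent_type` computes) and
# builds a symbolic zero of that tangent type.
zero = ad.Zero.from_primal_value(1.0)
```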
PiperOrigin-RevId: 676115753
This allows us to get more cache hits globally. For example:
Before:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache miss
```
After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache hit
```
Reverts b615266175effe4aefeb903620a19f3719a604da
PiperOrigin-RevId: 675746175
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.
PiperOrigin-RevId: 674324964
This fixes a tracing cache miss that occurred when you called `eval_shape` with a weak-type input, got a strong-type output back, and passed that output back in.
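A sketch of the round-trip that used to miss the cache (illustrative):
```
import jax

f = lambda x: x + 1.

spec = jax.eval_shape(f, 1.)  # Python scalar input: weak_type=True
# The returned spec could be strongly typed; passing it back in used to
# produce a different tracing-cache key and re-trace f.
jax.eval_shape(f, spec)
```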
Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
This allows us to get more cache hits globally. For example:
Before:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache miss
```
After:
```
jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr) # cpp cache hit
```
Also, we can remove the hack (which I didn't like) in multihost_utils.py.
PiperOrigin-RevId: 665574475
Not doing the resharding leads to incorrect outputs on GPU and a crash on TPU, which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
When tracing inner jits, we currently redo a lot of tracing work that we could cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of `_infer_params`.
In passing, fix a bug where `DynamicJaxprTracer`'s `shaped_abstractify` rule sometimes produced concrete avals.
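For illustration, the kind of nested-`jit` call this change caches (hypothetical functions):
```
import jax

@jax.jit
def inner(x):
    return x * 2

@jax.jit
def outer(x):
    # Tracing `outer` also traces the call to `inner`; with the new cache,
    # repeated inner traces with the same argument signature skip most of
    # the work in _infer_params.
    return inner(x) + 1
```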
```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op  new time/op  delta
jit_add_chain  60.3ms ±14%  50.7ms ±11%  -15.99%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 645491650
The motivation for doing this is twofold:
1) It will help with deprecating and eventually deleting `jax.xla_computation`, which allows cross-backend lowering.
2) It allows cross-backend and multi-backend lowering via the JAX AOT APIs, which will help clean up some hacks implemented for `jax.export`.
Note that this is only available via `.trace(*args).lower(lowering_platforms=('tpu',))`; you cannot use `.lower` directly to do cross-lowering. We can introduce top-level APIs in the future to make composable AOT APIs easier if `.trace(*args).lower(lowering_platforms)` is cumbersome to write.
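A minimal sketch of the flow, assuming the `lowering_platforms` parameter as named in this message:
```
import jax
import jax.numpy as jnp

def f(x):
    return x * 2

# Cross-backend lowering via the AOT APIs: trace on the current backend,
# then lower for TPU.
lowered = jax.jit(f).trace(jnp.ones(4)).lower(lowering_platforms=('tpu',))
print(lowered.as_text())
```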
Designed with @froystig!
PiperOrigin-RevId: 644087787
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.
The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
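For example, a single call over a pytree (the sharding setup is illustrative):
```
import jax
import jax.numpy as jnp

s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
params = {'w': jnp.ones((1024, 1024)), 'b': jnp.zeros((1024,))}

# One jax.device_put call for the whole pytree now issues a single
# device_put_p.bind(), letting the backend batch the transfers.
out = jax.device_put(params, s)
```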
PiperOrigin-RevId: 644051624
This CL changes `shard_arg_handlers` to be batched: each handler now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard-arg handler for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copies cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance when there are many arrays.
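A Python sketch of the grouping idea (illustrative only; the real grouping lives in C++, and `group_for_copy` is a hypothetical helper):
```
from collections import defaultdict

def group_for_copy(arrays, dst_shardings):
    # Client::CopyArrays() requires each batch to share source/destination
    # devices and memory kinds, so group arrays by that key first.
    groups = defaultdict(list)
    for arr, dst in zip(arrays, dst_shardings):
        key = (frozenset(arr.sharding.device_set),
               frozenset(dst.device_set),
               arr.sharding.memory_kind, dst.memory_kind)
        groups[key].append((arr, dst))
    return groups
```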
PiperOrigin-RevId: 643097852