For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.
The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.
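The ordering assumption can be sketched with plain Python stand-ins (none of these names are actual JAX internals; the real map comes from `sharding.addressable_devices_indices_map()`):

```python
def pair_addressable_shards(addressable_arrays, addressable_indices_map):
    """Pair each addressable array with its index into the global array.

    Relies on the assumption stated above: addressable arrays appear in
    the same order as the corresponding devices in the indices map, so a
    single zip suffices instead of scanning all array shards.
    """
    assert len(addressable_arrays) == len(addressable_indices_map)
    return [(array, index)
            for array, index in zip(addressable_arrays,
                                    addressable_indices_map.values())]
```

Since Python 3.7 dicts preserve insertion order, so `.values()` yields indices in the same device order as the map itself.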
PiperOrigin-RevId: 624969222
Also take the `batched_device_put` fast path for non-`jax.Array` inputs, since slicing a `jax.Array` can return arrays on multiple devices, which `batched_device_put` doesn't support.
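A rough sketch of the dispatch, where `batched_device_put`, `per_shard_put`, and `host_slice` are placeholder hooks rather than real JAX/XLA bindings:

```python
def put_shards(value, shard_specs, *, is_jax_array,
               batched_device_put, per_shard_put, host_slice):
    """Hypothetical model of the fast-path dispatch described above."""
    if not is_jax_array:
        # Slicing a host value (e.g. a numpy array) yields single-device
        # buffers, so the batched fast path is safe here.
        pieces = [host_slice(value, index) for _, index in shard_specs]
        return batched_device_put(pieces)
    # Slicing a jax.Array can return arrays living on multiple devices,
    # which batched_device_put doesn't support, so place shard by shard.
    return [per_shard_put(value, device, index)
            for device, index in shard_specs]
```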
PiperOrigin-RevId: 624763603
When AlgebraicSimplifier calls `dot->SetupDerivedInstruction(new_lhs);` in HandleDot, the lhs sharding was cleared whenever the dot didn't have a sharding. With this CL, the lhs preserves its sharding, because the condition for clearing it is narrowed down to only the case where the shapes are incompatible.
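The narrowed condition can be modeled like this; the real code is C++ in XLA's `HloInstruction::SetupDerivedInstruction`, so this function and its parameter names are purely illustrative:

```python
def setup_derived_sharding(source_sharding, derived_sharding,
                           shapes_compatible):
    """Return the sharding the derived instruction ends up with.

    Hypothetical model: before the change, a sharding-less source always
    cleared the derived instruction's sharding; after the change, the
    sharding is cleared only when the shapes are incompatible.
    """
    if not shapes_compatible:
        return None              # only remaining case where we clear
    if source_sharding is None:
        return derived_sharding  # the fix: preserve the existing sharding
    return source_sharding       # unchanged: copy the source's sharding
```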
Fixes https://github.com/google/jax/issues/20710
PiperOrigin-RevId: 624731930
* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible.
* `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop.
PiperOrigin-RevId: 624288637
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.
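The stricter check amounts to something like the following sketch; the real validation lives inside `jax.jit`'s argument handling, so this helper is hypothetical:

```python
import inspect

def check_static_argnames(fun, static_argnames):
    """Raise (rather than warn) on static_argnames that don't name a
    parameter of fun. Illustrative only, not the jax.jit code path."""
    params = inspect.signature(fun).parameters
    for name in static_argnames:
        if name not in params:
            raise ValueError(
                f"static_argnames includes {name!r}, which is not a "
                f"parameter of {fun.__name__}()")
```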
PiperOrigin-RevId: 624270701
Thank you to gnecula@ for adding the jax2tf_associative_scan_reductions flag and context manager: 5bfe1852a4
On GPU, the choice of `cumsum` implementation can make a whopping difference: a latency of microseconds versus milliseconds!
Before this change, adjusting the method of lowering `cumsum` via this scope had no effect:
```py
with jax.jax2tf_associative_scan_reductions(True):
    ...
```
... because `cumsum` (and the other reduction methods) has its implementation set when the `jax2tf` library is imported, i.e. when this line executes:
```py
from jax.experimental import jax2tf
```
Thus, any later switch of the implementation (to, say, associative scanning), even if it happened before `jax2tf.convert` executed, had no effect, because methods such as `cumsum` had already been curried at import time.
This change fixes that by varying the implementation based on the current value of `config.jax2tf_associative_scan_reductions`.
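The shape of the fix, with a minimal stand-in config and placeholder lowerings (the real flag is `config.jax2tf_associative_scan_reductions` inside JAX; both implementations below are illustrative):

```python
class _Config:
    """Stand-in for jax's config object."""
    jax2tf_associative_scan_reductions = False

config = _Config()

def _cumsum_sequential(xs):
    # Placeholder for the default lowering.
    out, total = [], 0
    for x in xs:
        total += x
        out.append(total)
    return out

def _cumsum_assoc_scan(xs):
    # Placeholder for the associative-scan lowering: same results,
    # different latency characteristics on GPU.
    return _cumsum_sequential(xs)

def cumsum(xs):
    # The fix: consult the flag at call time, rather than currying one
    # implementation at import time, so changes made via the context
    # manager before jax2tf.convert runs actually take effect.
    impl = (_cumsum_assoc_scan
            if config.jax2tf_associative_scan_reductions
            else _cumsum_sequential)
    return impl(xs)
```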
We rely on existing tests to verify the continued correctness of this latency-affecting CL. We also add TPU to the list of devices with test limitations: one TPU unit test started failing precisely because the scope now works; even though TPUs lower via a different path by default, the context manager above explicitly switches to associative scanning.
PiperOrigin-RevId: 624264567