This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet, and donation is also not yet supported in PJRT.
This CL also sets up the plumbing from JAX to PJRT so that support for the missing features can be added easily in the future.
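For illustration, a minimal sketch of the kind of copy this enables, using the existing JAX memories API and assuming the backend exposes a `pinned_host` memory space:

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
host_sharding = SingleDeviceSharding(dev, memory_kind="pinned_host")

# Place an array in pinned host memory attached to `dev`.
x_host = jax.device_put(jnp.arange(8), host_sharding)

# pinned_host -> pinned_host copy on the same device (what this change supports).
y_host = jax.device_put(x_host, host_sharding)
```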
Fixes https://github.com/jax-ml/jax/issues/24521
PiperOrigin-RevId: 694274616
Why?
Because users need to know whether an array is committed or not, since JAX raises errors based on the committedness of a jax.Array and also makes dispatch decisions based on it.
But the placement of such arrays on devices is an internal implementation detail.
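A small sketch of the behavior that committedness controls (the second half assumes at least two local devices):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)                        # uncommitted: JAX may move it freely
y = jax.device_put(x, jax.devices()[0])  # committed to devices()[0]

if len(jax.devices()) >= 2:
    z = jax.device_put(x, jax.devices()[1])  # committed to a different device
    try:
        y + z  # mixing arrays committed to different devices raises an error
    except ValueError as e:
        print("committedness error:", e)
```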
PiperOrigin-RevId: 686329828
More improvements and semantics clarifications are coming as we integrate it further into JAX.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
This massively simplifies the number of checks we need and improves dispatch time too. It also fixes a donation bug hit in serving code, related to layouts and the lack of a standardized default layout in JAX.
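For context, the donation referred to here is the standard `donate_argnums` mechanism; a minimal sketch of its use (donation may be ignored on some backends, e.g. CPU):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def update(state, delta):
    # The buffer backing `state` may be reused for the output.
    return state + delta

state = jnp.zeros(1024)
state = update(state, jnp.ones(1024))  # the old `state` buffer is donated
```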
PiperOrigin-RevId: 668527139
Not doing the resharding leads to incorrect outputs on GPU and a crash on TPU, which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
The problem is that squeezing was happening on uncommitted arrays, so list(x) was moving
all the shards to device 0. This can potentially cause OOMs.
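An illustrative reproduction of the code path, assuming multiple local devices; `list(x)` iterates over the leading axis of the sharded array:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

# Shard the leading axis across all local devices.
x = jax.device_put(
    jnp.arange(len(devices) * 4).reshape(len(devices), 4), sharding)

# Each element is one row of `x`; with the fix, the rows are no longer
# gathered onto device 0 just to perform the squeeze.
rows = list(x)
```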
PiperOrigin-RevId: 661408226
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.
To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
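A plain-Python sketch of the convention (helper names are illustrative, not necessarily the actual internals):

```python
def wrap_ir_values(v):
    # Normalize to the general tuple form for code paths that need it.
    return v if isinstance(v, tuple) else (v,)

def unwrap_ir_values(vs):
    # Store a singleton unwrapped; keep a tuple otherwise (including empty).
    vs = tuple(vs)
    return vs[0] if len(vs) == 1 else vs
```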
`sub_byte_element_size_in_bits` is a lowering-only thing for now (since we know the dtype of the aval, JAX can add the appropriate value). We can expose it in the user API if required.
Memory space is exposed via the JAX memories API, so it doesn't have to be in the layout API.
Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.
Add constructors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.
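For reference, `minor_to_major` lists dimensions from fastest- to slowest-varying in memory; a NumPy analogy (not the `xla::Layout` API itself):

```python
import numpy as np

a = np.arange(6, dtype=np.int64).reshape(2, 3)  # C order ~ minor_to_major (1, 0)
b = np.asfortranarray(a)                        # F order ~ minor_to_major (0, 1)
print(a.strides, b.strides)  # (24, 8) vs (8, 16): dim 1 vs dim 0 is contiguous
```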
PiperOrigin-RevId: 647487510
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copies cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance, for the case where we have lots of arrays.
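The grouping idea, sketched as a hypothetical Python helper (the real grouping happens in C++ inside `PyArray::BatchedCopyToDeviceWithSharding()`):

```python
from collections import defaultdict

def group_for_copy(arrays, dst_shardings):
    # Group (array, destination sharding) pairs so that each group shares the
    # same source devices, destination devices, and memory kinds, as a single
    # CopyArrays()-style bulk call requires.
    groups = defaultdict(list)
    for arr, dst in zip(arrays, dst_shardings):
        key = (
            tuple(sorted(d.id for d in arr.sharding.device_set)),  # src devices
            arr.sharding.memory_kind,                              # src memory kind
            tuple(sorted(d.id for d in dst.device_set)),           # dst devices
            dst.memory_kind,                                       # dst memory kind
        )
        groups[key].append((arr, dst))
    return groups
```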
PiperOrigin-RevId: 643097852
* Cache the sharding index comparison in addition to sharding index calculation. This helps when the list of indices is expensive to compare.
* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.
* Use `a is b` instead of `id(a) == id(b)` since the former is more idiomatic and avoids the extra `id()` calls.
PiperOrigin-RevId: 627080325
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.
The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.
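A rough sketch of the idea as a hypothetical Python helper (not the actual implementation): pick one addressable shard per distinct index so that replicas are not all touched during materialization.

```python
def shards_for_materialization(arr):
    # Map each addressable device to the slice of the global array it holds.
    index_map = arr.sharding.addressable_devices_indices_map(arr.shape)
    chosen = {}
    for shard in arr.addressable_shards:
        idx = index_map[shard.device]
        # Build a hashable key from the per-device index (a tuple of slices).
        key = tuple((s.start, s.stop, s.step) for s in idx)
        if key not in chosen:          # keep only the first shard per index
            chosen[key] = shard
    return list(chosen.values())
```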
PiperOrigin-RevId: 624969222
Also take the batched_device_put fast path for non-jax.Arrays, since slicing can return arrays on multiple devices, which batched_device_put doesn't support.
PiperOrigin-RevId: 624763603
* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible.
* `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop.
PiperOrigin-RevId: 624288637