This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.
Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.
PiperOrigin-RevId: 643097852
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 642954763
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" with "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.
--
5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb by Rahul Batra <rahbatra@amd.com>:
Pallas bitwise_left_shift unit test fix
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21780 from ROCm:fix_pallas_bitwise_left_shift_test 5d4d1fa8f89451b1a11476ab0cfbadbaa476cbbb
PiperOrigin-RevId: 642636198