This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 685689593
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.
PiperOrigin-RevId: 685679646
Cases where we error
* batch dimensions not having consistent sharding (ignore None)
* contracting dimensions not having consistent sharding (ignore None)
* lhs.mesh != rhs.mesh
* if batch dimension and tensor dimension sharding match -> Error
PiperOrigin-RevId: 684983567
Change in preparation for removing HLO ops from the XLA Python bindings.
In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.
PiperOrigin-RevId: 684892102
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.
PiperOrigin-RevId: 684797980
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only test vector version of operations, but not the scalar version of them. This PR adds tests for scalar elementwise operations.
The structure of the test is similar to the vector version of the tests above.
PiperOrigin-RevId: 684569107
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
This is the first step as part of the JAX CI rework project.
Changes:
* Adds new `ci_{os_name}_{arch}` configs that consolidates the different configs that we use in CI builds under a single config.
* Consolidates Python specific RBE Linux CPU and RBE Linux CUDA configs into Python agnostic `rbe_linux_x86_64` and `rbe_linux_x86_64_cuda`. These new RBE configs inherit the settings in the corresponding `ci_` config and pass in additional RBE specific flags such as platform details, remote execution backend, and authentication details. Hermetic Python version details will now be passed directly in the CI build scripts.
* Adds new RBE Windows configs.
* Removes JAVA flags from RBE configs. These are ignored from Bazel 5+. (See related TF PR: https://github.com/tensorflow/tensorflow/pull/54547)
* Renames some configs: `cuda_nvcc` is now `build_cuda_with_nvcc`, `cuda_clang` is now `build_cuda_with_clang`, `rbe_cross_compile_macos_x86` is now `rbe_cross_compile_darwin_x86_64`, `rbe_cross_compile_linux_arm64` is now `rbe_cross_compile_linux_aarch64`.
* Separates platform specific configs and feature specific configs into their own section.
* Removes unused `--define` configs
* Adds new test configs that will be used when running `bazel test`. `non_multiaccelerator` will be used in RBE Linux CUDA test builds, `non_multiaccelerator_local` and `multiaccelerator_local` will be used in Linux CUDA test builds which depend on local jaxlib and plugin wheels instead of building them along with the rest of the test targets.
* Replaces `--spawn_strategy=standalone` with `--spawn_strategy=local`. `standalone` has been [deprecated by Bazel](https://bazel.build/docs/user-manual#spawn-strategy).
PiperOrigin-RevId: 684532777