As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually land in time for the v0.4.34 release, so I've moved it to
the v0.4.35 section.
This is the first step as part of the JAX CI rework project.
Changes:
* Adds new `ci_{os_name}_{arch}` configs that consolidate the different configs that we use in CI builds under a single config.
* Consolidates Python specific RBE Linux CPU and RBE Linux CUDA configs into Python agnostic `rbe_linux_x86_64` and `rbe_linux_x86_64_cuda`. These new RBE configs inherit the settings in the corresponding `ci_` config and pass in additional RBE specific flags such as platform details, remote execution backend, and authentication details. Hermetic Python version details will now be passed directly in the CI build scripts.
* Adds new RBE Windows configs.
* Removes JAVA flags from RBE configs. These have been ignored since Bazel 5. (See related TF PR: https://github.com/tensorflow/tensorflow/pull/54547)
* Renames some configs: `cuda_nvcc` is now `build_cuda_with_nvcc`, `cuda_clang` is now `build_cuda_with_clang`, `rbe_cross_compile_macos_x86` is now `rbe_cross_compile_darwin_x86_64`, `rbe_cross_compile_linux_arm64` is now `rbe_cross_compile_linux_aarch64`.
* Separates platform specific configs and feature specific configs into their own sections.
* Removes unused `--define` configs.
* Adds new test configs that will be used when running `bazel test`. `non_multiaccelerator` will be used in RBE Linux CUDA test builds, `non_multiaccelerator_local` and `multiaccelerator_local` will be used in Linux CUDA test builds which depend on local jaxlib and plugin wheels instead of building them along with the rest of the test targets.
* Replaces `--spawn_strategy=standalone` with `--spawn_strategy=local`. `standalone` has been [deprecated by Bazel](https://bazel.build/docs/user-manual#spawn-strategy).
PiperOrigin-RevId: 684532777
This type is unused by JAX, so there is no replacement.
(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)
PiperOrigin-RevId: 684451556
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
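A minimal sketch of the approach described above, using only the standard library: a standalone `IntEnum` stays interchangeable with raw integers without depending on another library's type. The member names and integer values here are assumptions for illustration, and `fft_type_name` is a hypothetical helper, not part of any JAX API.

```python
import enum


class FftType(enum.IntEnum):
    """Illustrative stand-in for a public FFT-type enum (values are assumed)."""
    FFT = 0    # forward complex-to-complex
    IFFT = 1   # inverse complex-to-complex
    RFFT = 2   # forward real-to-complex
    IRFFT = 3  # inverse complex-to-real


def fft_type_name(fft_type):
    """Hypothetical helper: accept an enum member or a raw int, return its name."""
    # IntEnum lookup by value raises ValueError for unknown integers.
    return FftType(fft_type).name
```

Because the members are integers, code that previously passed plain ints keeps working, while callers get readable names in error messages and reprs.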
PiperOrigin-RevId: 684447186
* Separate MXU size into MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.
PiperOrigin-RevId: 684116306
Fixes https://github.com/jax-ml/jax/issues/23972.
In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).
In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.
This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
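To illustrate why the predicate matters, here is a small pure-Python sketch, not the Pallas lowering itself. `cmpi_predicate` and `less_equal` are hypothetical helpers; they show that the same 32-bit pattern compares differently depending on whether a signed or unsigned predicate is chosen.

```python
def to_signed_32(bits):
    """Interpret a 32-bit pattern as a two's-complement signed integer."""
    bits &= 0xFFFFFFFF
    return bits - (1 << 32) if bits & 0x80000000 else bits


def cmpi_predicate(op, is_unsigned):
    """Hypothetical helper: pick the predicate name for an integer comparison."""
    if op in ("eq", "ne"):
        return op  # equality does not depend on signedness
    return ("u" if is_unsigned else "s") + op


def less_equal(lhs_bits, rhs_bits, is_unsigned):
    """Evaluate lhs <= rhs on raw 32-bit patterns under the chosen predicate."""
    if cmpi_predicate("le", is_unsigned) == "ule":
        return (lhs_bits & 0xFFFFFFFF) <= (rhs_bits & 0xFFFFFFFF)
    return to_signed_32(lhs_bits) <= to_signed_32(rhs_bits)
```

For example, the pattern `0xFFFFFFFF` is the largest `uint32` but is `-1` as an `int32`, so `less_equal(0xFFFFFFFF, 1, is_unsigned=True)` and `less_equal(0xFFFFFFFF, 1, is_unsigned=False)` disagree.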
PiperOrigin-RevId: 684065893
Even when the total size of manual axes is 1, and we can skip creating the `ManualComputationOp`, we need to have the body of what was supposed to be the `shard_map` operate under this new context.
PiperOrigin-RevId: 684055903
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)
2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO
This is because there is a conflict in such cases, which almost always leads to the wrong layout being chosen by the compiler. For example, consider (1): since the input layout is AUTO, the output layout is default, and the two are aliased, XLA will end up choosing the default layout for the input, which is not what you want in the majority of cases.
Erroring is best in such cases, and the user can mark the input layout as default if that is what they want.
The correct choice is to always make both of them AUTO, since you want the compiler to choose the best possible layout rather than adopting the non-AUTO side's layout for the other one.
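The rule above can be sketched as a small validation function. This is an illustration under stated assumptions, not the actual JAX check: `AUTO` is a hypothetical sentinel, `None` stands in for the default layout, and `check_aliased_layouts` is a made-up name.

```python
AUTO = object()  # hypothetical sentinel meaning "let the compiler choose"


def check_aliased_layouts(in_layout, out_layout):
    """Raise if exactly one side of an aliased input/output pair is AUTO."""
    in_auto = in_layout is AUTO
    out_auto = out_layout is AUTO
    # Cases (1) and (2): one side is AUTO and the other is default or concrete.
    if in_auto != out_auto:
        raise ValueError(
            "Aliased input and output layouts must both be AUTO or both be "
            "non-AUTO (default or concrete)."
        )
```

Both-AUTO and both-non-AUTO pairs pass; only the mixed cases raise.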
PiperOrigin-RevId: 683688470
There is currently an issue with the Mosaic compiler that prevents emitting code that returns semaphores in the presence of the grid argument.
PiperOrigin-RevId: 683681627
A reshape function that folds/unfolds by touching a minimal number of
dimensions, to potentially circumvent issues with strided memrefs.
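One way to find the dimensions a fold/unfold has to touch is to strip the leading and trailing dimensions that the old and new shapes share. The sketch below is only an illustration of that idea on plain shape tuples; `changed_dim_range` is a hypothetical helper and is not taken from the Mosaic code.

```python
def changed_dim_range(old_shape, new_shape):
    """Return (start, old_end, new_end).

    Old dims [start, old_end) are replaced by new dims [start, new_end);
    every dimension outside those ranges is left untouched.
    """
    # Skip the common leading dimensions.
    start = 0
    limit = min(len(old_shape), len(new_shape))
    while start < limit and old_shape[start] == new_shape[start]:
        start += 1
    # Skip the common trailing dimensions without crossing the prefix.
    old_end, new_end = len(old_shape), len(new_shape)
    while (old_end > start and new_end > start
           and old_shape[old_end - 1] == new_shape[new_end - 1]):
        old_end -= 1
        new_end -= 1
    return start, old_end, new_end
```

For a fold from `(2, 3, 4, 5)` to `(2, 12, 5)`, only the middle two dimensions are involved, so the outer dimensions (and their strides) can be left alone.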
PiperOrigin-RevId: 683663541
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.
PiperOrigin-RevId: 683629935
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing them now would break internal pytype checking. We will remove them in the near future.
See https://github.com/google/jax/issues/20385.
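A minimal sketch of the "keep the name, raise a helpful error" pattern described above. The helper `_removed`, the exception type, and the message wording are assumptions for illustration, not the actual JAX code.

```python
def _removed(name):
    """Hypothetical factory: build a stub that raises an informative error."""
    def stub(*args, **kwargs):
        # Assumed exception type and wording; the point is that the name
        # still resolves, so callers get guidance rather than AttributeError.
        raise NotImplementedError(
            f"{name} has been removed. "
            "See https://github.com/google/jax/issues/20385."
        )
    stub.__name__ = name
    return stub


# The public names remain importable, so static type checkers still find them.
id_tap = _removed("id_tap")
call = _removed("call")
```

Calling either stub fails immediately with a message that points at the tracking issue.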
PiperOrigin-RevId: 683564340