We can use relayout-insertion pass to insert necessary ops and their layouts for relayout before unrolling in apply-vector-layout pass.
PiperOrigin-RevId: 708143852
For (1, 128) tiling 32-bit input, it assigns (1, 128) tiling at output, which can be invalid (e.g. it should be (1, 256) for bf16)
PiperOrigin-RevId: 708112341
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.
Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.
As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
/usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```
Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.
Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`
PiperOrigin-RevId: 708035062
Adds an extra verification check. Since the source semaphore is used only for remote DMAs, we should check that device or core IDs are also provided when source semaphore is provided.
PiperOrigin-RevId: 707675228
`tile_masks` was updated to use implicit, but we skipped the reshape for `tiles`
Seems like there was even a bug before cl/707025084: `tile_masks` was never reshaped, so if the shape was 1D and a store mask was specified, there would be a mismatch in dimensions.
PiperOrigin-RevId: 707368670
The `jaxlib/cuda_plugin_extension.cc` and `jaxlib/rocm_plugin_extension.cc` files were nearly identical so this change consolidates the shared implementation into a single target.
PiperOrigin-RevId: 704785926