I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,
def f(x: int) -> str: ...
def g(x: int) -> str: ...
callback = f if ... else g # has type object!
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removes some code duplication by consolidate most of the lowering logic into a helper function, and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.
PiperOrigin-RevId: 725223882
The x86 build stopped building completely due to a use of std::filesystem::path, which was added in 10.15.
We've dropped x86 support, but this is an easy enough fix to make and moves x86 to parity with ARM.
This kernel wasn't allowed in export, so no backwards compatibility period is required. Even so, the FFI kernels were added 6 months ago.
PiperOrigin-RevId: 724359996
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.
PiperOrigin-RevId: 724016331
This change is a part of the initiative to test the JAX wheels in the presubmit properly.
The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.
2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.
3. The version suffix of the wheel in the build rule output depends on the environment variables.
The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.
4. Environment variables combinations for creating wheels with different versions:
* `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
* `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
* `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 723552265
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.
See #25196.
A few notes
* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
compilation is done via a new PjRt extension, which uses its own compilation
pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
deprecated and will be removed after 6 months as per
[compatibility guarantees] [*]
[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees
PiperOrigin-RevId: 722773884
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.
PiperOrigin-RevId: 719280601
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.
PiperOrigin-RevId: 718795667
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.
PiperOrigin-RevId: 716844265
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.
PiperOrigin-RevId: 716596848
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.
PiperOrigin-RevId: 716274818
It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens.
This also cleans up an unreachable return statement.
PiperOrigin-RevId: 714847066