This follows after #26078, and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.
As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases if was really ` lu.WrappedFun.call_wrapped`.
A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling
An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be a
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.
PiperOrigin-RevId: 723595727
This change is a part of the initiative to test the JAX wheels in the presubmit properly.
The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.
2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.
3. The version suffix of the wheel in the build rule output depends on the environment variables.
The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.
4. Environment variables combinations for creating wheels with different versions:
* `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
* `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
* `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 723552265
Hopper and Blackwell MMA instructions can share a lot of the same logic, which is
why I ended up splitting out a large fraction of WGMMA implementation into a common
utility. This should be an NFC for WGMMA, but it allows us to concisely implement
unrolling of MMAs of different sizes into a number of tcgen05.mma instructions.
PiperOrigin-RevId: 723544349
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.
Reverts 3b2410f77cdb0acc6951e1770c1229e6689b7409
PiperOrigin-RevId: 723539592
The previous example implementation loaded TMEM in a layout that was very hard to
efficiently store into SMEM or GMEM. With the new TMEMRef abstraction, we can implement
loads that yield a FragmentedArray with a new tiled layout that allows for efficient
swizzled stores to SMEM.
The new layout is very similar to the one we've been using for WGMMA on Hopper, only the
initial row tiling is increased to 128 (making each warp hold 32 rows, not 16 as previously).
PiperOrigin-RevId: 723506876
The previous impelmentation depends on LLVM intrinsics that have not been submitted
yet. This replaces them with inline PTX (as far as I can tell there's no downside to
that) that's encapsulated into convenience functions.
PiperOrigin-RevId: 723498248
There's no need to require extra arguments. This makes our calling convention
saner since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape), only the in-memory representation changes.
Other than the API change, this is a NFC.
PiperOrigin-RevId: 723449720
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].