This follows in a series, starting with #26078 and #26313, adding `debug_info` to more calls to `lu.wrap_init`.
Fixes are in `jet`, stateful code, `key_reuse`, `ode`, `pallas`, and tests.
This follows after #26078, #26313, #26348, adding `debug_info` to more calls to `lu.wrap_init`.
As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.
These changes ensure that all calls to `lu.wrap_init` and all `Jaxpr` constructions carry `debug_info` when running `api_test.py:CustomTransposeTest`.
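For context, the convention this series establishes looks roughly like the following sketch. It uses internal `jax._src` modules whose exact signatures vary across JAX versions, and `my_transpose` is a hypothetical rule, not code from this change:

```python
# Minimal sketch, assuming internal JAX modules (jax._src.linear_util,
# jax._src.api_util); signatures vary across versions, and
# `my_transpose` is a hypothetical transpose rule.
from jax._src import linear_util as lu
from jax._src import api_util

def my_transpose(ct, x):
    # hypothetical transpose rule for a custom_transpose primitive
    return (ct,)

# Build debug info describing what is being traced, and attach it when
# wrapping, instead of passing the bare Callable around:
dbg = api_util.debug_info("custom_transpose", my_transpose, (1.0, 2.0), {})
wrapped = lu.wrap_init(my_transpose, debug_info=dbg)
```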
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high-level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is needed.
In this context, "batch partitioning" refers to the behavior of many FFI calls that can be trivially partitioned along some number of (leading) dimensions, with the same call executed independently on each shard of data. If the data are sharded along non-batch dimensions, the partitioner will still re-shard the data to be replicated along those dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as to some external users of `custom_partitioning`.
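To make the semantics concrete, here is a small pure-JAX illustration of this pattern using ordinary batched linear algebra (not the new FFI machinery):

```python
# "Batch partitioning" in miniature: sharding the leading (batch)
# dimension lets each device run the same per-matrix operation
# independently; the non-batch dims stay unsharded on every shard.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("batch",))
# Batch of 8 independent 4x4 matrices (assumes 8 is divisible by the
# device count):
x = jnp.tile(jnp.eye(4), (8, 1, 1))

inv = jax.jit(jnp.linalg.inv,
              in_shardings=NamedSharding(mesh, P("batch", None, None)))
y = inv(x)  # each device inverts only its own matrices
```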
The approach I'm taking here is to add a new registration function to the XLA client which lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is interpreted by the SPMD partitioner.
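Schematically, the two pieces fit together as in the sketch below; the registration function and attribute names (`register_custom_call_as_batch_partitionable`, `num_batch_dims`) are assumptions based on this description, not a definitive reference for the final API:

```python
# Hedged sketch; names below are assumptions based on the description
# above, not a definitive reference for the final API.
from jax._src.lib import xla_client

# (a) Label the FFI target as batch partitionable so the SPMD
#     partitioner knows it may shard the leading batch dimensions:
xla_client.register_custom_call_as_batch_partitionable("my_ffi_target")

# (b) When lowering the custom call, the number of leading batch
#     dimensions travels as a frontend attribute, e.g.:
frontend_attributes = {"num_batch_dims": "2"}
```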
In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.
PiperOrigin-RevId: 724367877
It would be good to add smaller tests that verify reads and writes to TMEM,
since we depend on them here, but that will come later.
PiperOrigin-RevId: 724328602
This follows in a series, starting with #26078 and #26313, adding `debug_info` to more calls to `lu.wrap_init`.
These changes ensure that all calls to `lu.wrap_init` and all `Jaxpr` constructions carry `debug_info` when running `api_test.py:CustomTransposeTest`, `api_test.py:CustomVmapTest`, and `api_test.py:RematTest`.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.
The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
`io_callback`s to Python functions that simulate these primitives.
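As a simplified sketch of that substitution (with `shared_memory` and `simulate_store` as hypothetical stand-ins for the interpreter's real machinery):

```python
# Simplified sketch; `shared_memory` and `simulate_store` are
# hypothetical stand-ins for the interpreter's real machinery.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

shared_memory = {}  # simulated TPU memory, keyed by buffer id

def simulate_store(buffer_id, value):
    shared_memory[int(buffer_id)] = np.asarray(value)

def interpreted_store(buffer_id, value):
    # In the interpreter, a store primitive is replaced by an ordered
    # callback into Python, so side effects happen in program order:
    io_callback(simulate_store, None, buffer_id, value, ordered=True)

jax.jit(interpreted_store)(jnp.int32(0), jnp.arange(4.0))
print(shared_memory)  # {0: array([0., 1., 2., 3.], dtype=float32)}
```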
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.
The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:
- Executing DMAs asynchronously.
- Padding in pallas_call.
- Propagating source info.
This follows after #26078 and #26313, adding `debug_info` to
more calls to `lu.wrap_init`.
As part of this I have changed the primitives `custom_vjp_call_jaxpr`
and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`,
which carries debug info. Previously, this was a `Callable`, but in
almost all cases it was really `lu.WrappedFun.call_wrapped`.
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX, via a simple two-step
recipe: first run the model with PGLE enabled to populate the compilation
cache, then re-run it under the profiler, reusing the cached
PGLE-optimized compilation:
export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling
An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later; that will be the
recommended way to express user-defined state. The capability of accessing
Python module-level variables is still crucial because a lot of Python code
(including JAX) relies on it to implement caching.
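For reference, the pattern the test exercises might look roughly like the sketch below, using the experimental `colocated_python` API; the decorator usage is an assumption, and the specialization/dispatch details are elided and version-dependent:

```python
# Minimal sketch, assuming jax.experimental.colocated_python; the
# specialization/dispatch steps are elided and version-dependent.
from jax.experimental import colocated_python

_cache = {}  # module-level variable: lives in the executor process and
             # is never pickled back to the controller (JAX) process

@colocated_python.colocated_python
def cached_shape(x):
  key = x.shape
  if key not in _cache:   # read of a module-level variable
    _cache[key] = True    # write to a module-level variable
  return x
```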
PiperOrigin-RevId: 723595727
This change is a part of the initiative to test the JAX wheels in the presubmit properly.
The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` is set during the wheel build. This restriction can be bypassed by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.
2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.
3. The version suffix of the wheel in the build rule output depends on environment variables.
Version suffix chunks that are not reproducible shouldn't be computed as part of the wheel build: for example, the current date changes every day, so wheels built today and tomorrow from the same code version would technically differ. To keep wheel content reproducible, we pass these suffix chunks in via environment variables.
4. Environment variables combinations for creating wheels with different versions:
* `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
* `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
* `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 723552265
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.
Reverts 3b2410f77cdb0acc6951e1770c1229e6689b7409
PiperOrigin-RevId: 723539592
There's no need to require extra arguments. This makes our calling convention
saner, since the logical dimension order stays the same (e.g. for B it's always
k before n in the shape); only the in-memory representation changes.
Other than the API change, this is an NFC.
PiperOrigin-RevId: 723449720
The recent partitionable Threefry upgrade affects binomial sampling under the RBG PRNG scheme because the implementation of `jax.random.binomial` derives internal subkeys with a call to `split`. This led a randomized test to fail by pushing its numeric closeness check just beyond its current relative tolerance. This is very likely a false failure, so we update the rtol.
PiperOrigin-RevId: 723100174