This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.
There are three options for running the tests:
1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
PiperOrigin-RevId: 735765819
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)
Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variables combinations for creating wheels with different versions:
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 730916743
Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:
* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).
* When an object method is called, the method call runs as a colocated Python
function call on the backend.
* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.
This change provides an intial API implementation. Main limitations are as
follows:
* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.
* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.
These limitations will be lifted as the object API implementation is improved.
PiperOrigin-RevId: 729629265
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.
PiperOrigin-RevId: 726920440
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.
The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.
The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:
- Executing DMAs asynchronously.
- Padding in pallas_call.
- Propagating source info.
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.
PiperOrigin-RevId: 719918121
* We don't need to keep a separate thread-local stack of objects: the config state already has a thread local.
* We don't need to keep an explicit stack of contexts at all: we can maintain it in the context manager frames.
* When checking for incompatible nested compute_ons, we can just check the current state: no need to look higher in the stack!
PiperOrigin-RevId: 719892989
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.
In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.
PiperOrigin-RevId: 718950828
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.
During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.
PiperOrigin-RevId: 704911253
Shardy custom_partitioning.
The parsing of the sharding rule string very closely follows how einops parses
their rules in einops/parsing.py.
When a SdyShardingRule object is constructed, we check the syntax of the Einsum
like notation string and its consistency with the user provided factor_sizes,
and report errors accordingly. This is done during f.def_partition.
When SdyShardingRule.build is called, during JAX to MLIR lowering, we check
the consistency between the Einsum like notation string, the factor_sizes
and the MLIR operation, and report errors accordingly.
PiperOrigin-RevId: 703187962
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.
PiperOrigin-RevId: 700443656
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.
Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.
PiperOrigin-RevId: 698493184
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.
This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.
It is currently in an early development stage. Please proceed with a caution
when using the API.
PiperOrigin-RevId: 690705899
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.
PiperOrigin-RevId: 683119193
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.
See https://github.com/openxla/stablehlo/pull/2259
PiperOrigin-RevId: 678649138