For example, consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`
This decomposes into 2 einsums, where the intermediate einsum output has rank `5`:
* `'bthj,bthD->bthjD'`
* `'bthjD,bthi->ijD'`
The specified `out_sharding` (`P('data', None, None)`) is not compatible with the intermediate einsum `'bthj,bthD->bthjD'`, since the `length of spec (3) != out_aval.ndim (5)`.
This change makes it so that `out_sharding` is applied only to the contraction that produces the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the inputs or split the operation into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**
Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
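A minimal sketch of such a manual split, reusing `dy`, `i`, `j` and the `'data'` axis from the example above (dummy operands and the ambient mesh are assumptions for illustration):
```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Dummy operands matching 'bthD', 'bthi', 'bthj' (sizes are arbitrary);
# assumes an ambient mesh with a 'data' axis, as in the example above.
b, t, h, D, I, J = 2, 3, 4, 5, 6, 7
dy = jnp.ones((b, t, h, D))
i = jnp.ones((b, t, h, I))
j = jnp.ones((b, t, h, J))

# Intermediate contraction: no out_sharding, so no rank-3 spec is forced
# onto this rank-5 result.
tmp = jnp.einsum('bthj,bthD->bthjD', j, dy)
# Final contraction: the out_sharding applies here, where the spec length
# matches the rank-3 'ijD' output.
out = jnp.einsum('bthjD,bthi->ijD', tmp, i,
                 out_sharding=P('data', None, None))
```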
PiperOrigin-RevId: 732205849
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas module context, which should be enforced by Mosaic GPU.
This is in preparation for changes implementing transform inference.
PiperOrigin-RevId: 732091266
The goal of this change is to avoid generating code that wraps negative indices back into range in cases where we know it doesn't matter. `scan` now passes `allow_negative_indices=False` so that index-wrapping code is not emitted for each scan argument; the loop index scan uses to slice its arguments is always nonnegative.
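For intuition, the wrapping being skipped looks roughly like this (an illustrative helper, not the actual lowering):
```python
import jax.numpy as jnp

def wrap_index(i, size):
    # Map an index in [-size, size) into [0, size); this is the extra
    # select that allow_negative_indices=False lets us skip emitting.
    return jnp.where(i < 0, i + size, i)
```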
PiperOrigin-RevId: 731812827
* `_partitions` is now canonicalized and only contains tuples, plain strings, `None`, or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) or singleton tuples; see the sketch after this list.
* Cache the creation of the sharding on `ShapedArray`, since constructing it many times is expensive.
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
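Illustrative equivalences implied by the canonicalization (my reading of the bullet above, stated as assumptions rather than tested behavior):
```python
from jax.sharding import PartitionSpec as P

# Assumption: an empty tuple entry means "unsharded", so after
# canonicalization P((), 'x') should behave like P(None, 'x').
# Assumption: a singleton tuple unwraps to the bare axis name, so
# P(('x',), 'y') should behave like P('x', 'y').
print(P((), 'x'), P(None, 'x'))
print(P(('x',), 'y'), P('x', 'y'))
```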
PiperOrigin-RevId: 731745062
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before executing when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, covering a wide range of use cases.
There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions baked into the kernel. I'm going to fix these in follow-up changes.
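A hedged usage sketch (mesh and axis names here are made up for illustration): with the batch dimension sharded, the factorization below can now run shard-local instead of gathering first.
```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
# A batch of identity matrices, sharded along the batch dimension.
x = jnp.stack([jnp.eye(4)] * (2 * len(jax.devices())))
x = jax.device_put(x, NamedSharding(mesh, P('batch', None, None)))
lu, pivots, permutation = jax.lax.linalg.lu(x)  # no all-gather needed
```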
PiperOrigin-RevId: 731732301
LLVM uses a little-endian format for int4 packing. To avoid converting between the two layouts, we should also use little-endian packing in XLA.
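A small sketch of what little-endian int4 packing means here (the helper is illustrative, not the actual XLA or LLVM packing code): the first of each pair of 4-bit values occupies the low nibble of the byte.
```python
def pack_int4_le(values):
    # Little-endian nibble order: element 0 in the low nibble,
    # element 1 in the high nibble of each packed byte.
    assert len(values) % 2 == 0
    return bytes(((hi & 0xF) << 4) | (lo & 0xF)
                 for lo, hi in zip(values[0::2], values[1::2]))

assert pack_int4_le([0x1, 0x2]) == b'\x21'  # 0x2 high nibble, 0x1 low
```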
PiperOrigin-RevId: 731731530
The existing `int4` loading code is very generic. When reading contiguous data, it reads with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider such accesses contiguous in memory and emits much less efficient code than for reads of truly contiguous blocks.
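An index-pattern illustration of the difference (plain Python, not Triton code):
```python
n = 8  # number of int4 elements to load

# Generic load: each packed byte is addressed twice, once per nibble.
generic_offsets = [k // 2 for k in range(n)]   # [0, 0, 1, 1, 2, 2, 3, 3]

# Contiguous load: read n // 2 whole bytes once, then unpack the nibbles.
packed_offsets = list(range(n // 2))           # [0, 1, 2, 3]
```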
PiperOrigin-RevId: 731635736
A relatively common pattern I've observed is the following:
```python
_, metrics = some_jax_function()
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```
We are missing an opportunity here to begin the D2H copy of the
metrics more eagerly (e.g. overlapping it with closing the
"compute_metrics" context manager, etc.). The intention of
`jax.copy_to_host_async(x)` is to make it simple to begin D2H
transfers as early as possible. Adapting the above code:
```python
_, metrics = some_jax_function()
# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```
PiperOrigin-RevId: 731626446
While the predicate helps us avoid branching, it only needs to be created
once per block. Its creation uses `*.sync` instructions, which are not
DCEd by LLVM and end up polluting the final code.
PiperOrigin-RevId: 731253109
When `dma_execution_mode='on_wait'`, we wait to execute DMAs until we are interpreting a `dma_wait` instruction. In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.
PiperOrigin-RevId: 731103569
`jax_prng.PRNGKeyArray` is not exposed in the public `jax` API, resulting in type-check errors when sampling outside of tests.
PiperOrigin-RevId: 731008883
Depending on the platform and the linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On Colab, performance is not dramatically improved for most input shapes, but on my Mac, this speeds up batched triangular solves by anywhere from a few times up to an order of magnitude across all the problems I tried.
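For context, a rough micro-benchmark in the spirit of that comparison (the shapes and timing approach are mine, not from the change):
```python
import time

import jax
import jax.numpy as jnp

# A well-conditioned batch of lower-triangular systems.
a = jnp.tril(jnp.ones((64, 128, 128))) + 128.0 * jnp.eye(128)
b = jnp.ones((64, 128, 4))
solve = jax.jit(lambda a, b: jax.lax.linalg.triangular_solve(
    a, b, left_side=True, lower=True))

solve(a, b).block_until_ready()  # warm-up / compile
start = time.perf_counter()
solve(a, b).block_until_ready()
print(f"batched triangular solve: {time.perf_counter() - start:.4f}s")
```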
PiperOrigin-RevId: 730971127
`contextdecorator` turns out to be slower than just writing a decorator class explicitly. Since we apply many decorators per equation, this causes a measurable speed difference in certain benchmarks.
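A minimal sketch of the pattern (class names are illustrative): the `contextlib.ContextDecorator` version routes every decorated call through the context-manager protocol, while the explicit class wraps the function by hand and skips that machinery.
```python
import contextlib

class SyncCtx(contextlib.ContextDecorator):
    # Decoration via ContextDecorator: __call__ wraps fn in `with self:`.
    def __enter__(self):
        return self
    def __exit__(self, *exc_info):
        return False

class SyncDecorator:
    # Explicit decorator class: the wrapper is a plain closure, with
    # enter/exit logic inlined rather than a full context-manager dance.
    def __call__(self, fn):
        def wrapped(*args, **kwargs):
            return fn(*args, **kwargs)  # setup/teardown would go here
        return wrapped

@SyncDecorator()
def step():
    return 42
```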
PiperOrigin-RevId: 730939406
(Part of general cleanups of the lax.linalg submodule.)
The `alpha` argument is always set to 1, and I don't see any benefit to keeping it around. Removing it can be done in a forward- and backward-compatible way by following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility
We start by updating the FFI handler to remove the explicit `alpha` argument, while allowing it to accept (but ignore) extra input arguments. Then we only pass `alpha` when lowering in forward-compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff, assuming that this change doesn't make it into the upcoming release).
Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.
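A hedged sketch of that gating (the helper names and the version check are placeholders, not the actual lowering rule):
```python
jaxlib_version = (0, 5, 2)  # placeholder for the real version lookup

def ffi_call(*args, **kwargs):  # placeholder for the real FFI call
    return args

def solve_lowering(ctx, *args):
    if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1):
        # Forward-compat mode or old jaxlib: the kernel still expects alpha.
        return ffi_call(*args, alpha=1.0)
    # New kernel accepts (and ignores) extra inputs, so stop passing alpha.
    return ffi_call(*args)
```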
PiperOrigin-RevId: 730928808
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin`, and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126).
Previously the `jax` wheel was built via the `python3 -m build` command. The resulting wheel contained the Python package files in the `jax` folder (i.e. the files in the subdirectories that have an `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variable combinations for creating wheels with different versions (example invocation after this list):
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
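For instance, a nightly build combines the wheel target with the flags above (the date placeholder is kept as in the list):
```
bazel build //:jax_wheel \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```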
PiperOrigin-RevId: 730916743
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality except the `arrive_tx` optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.
PiperOrigin-RevId: 730824239