Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.
This is in preparation of changes implementing transform inference.
PiperOrigin-RevId: 732091266
* Python wheels follow a naming convention: standard wheels use the pattern `*-cp<python_version>-cp<python_version>-*`, while free-threaded wheels use `*-cp<python_version>-cp<python_version>t-*`. Update the pytest workflows to look for free-threaded wheels and ensure that standard wheel tests exclude free-threaded wheels.
* Skip zstandard for python3.13-nogil due to compilation failure https://github.com/indygreg/python-zstandard/issues/231.
PiperOrigin-RevId: 732070585
This test configuration dates back to the time when we were very unsure
about how to use TMA. At this point we have plenty of experience and it
makes more sense to focus the test in question on verifying WGMMA. This
also simplifies adding support for smaller RHS tiling.
PiperOrigin-RevId: 732040900
The CUDA 12.8 release significantly improved the MMA docs, letting us
improve upon the previously used "magic number" scheme. Sadly, the docs
are still incorrect, but at least I can begin to make some sense of those
parameters.
PiperOrigin-RevId: 732033585
In general, this is a good feature, but it assumed that the packing type utilized here was exclusively for backcompat, and so always applied the adjustment.
PiperOrigin-RevId: 731954456
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
PiperOrigin-RevId: 731812827
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try and load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.
There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones.
PiperOrigin-RevId: 731808733
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.
* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
PiperOrigin-RevId: 731745062
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU will insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.
There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.
PiperOrigin-RevId: 731732301
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.
PiperOrigin-RevId: 731731530
We need to set them as `min(num_cpu_cores, num_gpus * max_tests_per_gpu, total ram in GB/6)` where max_tests_per_gpu = (GPU memory / 2GB)
PiperOrigin-RevId: 731730857
This change replicates the old method of building `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files.
PiperOrigin-RevId: 731721522
The CUDNN_VERSION is defined as (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL).
Therefore cuDNN 9.1.0 is represented as 90100 - not as 91000.
PiperOrigin-RevId: 731641814
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
PiperOrigin-RevId: 731635736
A relatively common pattern I've observed is the following:
```python
_, metrics = some_jax_function()
with profiler.Trace('compute_metrics'):
jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
metrics = jax.device_get(metrics)
```
We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:
```python
_, metrics = some_jax_function()
# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)
with profiler.Trace('compute_metrics'):
jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
metrics = jax.device_get(metrics)
```
PiperOrigin-RevId: 731626446