This CL only supports lowering a module in which all shardings use the exact same mesh, and loading it with either that same mesh or different meshes.
Note that we will be introducing some restrictions under Shardy for JAX export:
- You can only lower/save the module with meshes that all have the same shape but possibly different axis names (right now this PR only allows the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different shape and different axis names. However, as with the restriction in the previous point, all shardings must use the same axis shapes but can use different axis names (again, this will be relaxed in a follow-up)
We may remove the restriction of having to use the exact same mesh shapes at export saving time and the exact same mesh shapes at export loading time in the future. But for now, we will keep this restriction while no one is using Shardy with JAX export.
PiperOrigin-RevId: 732878916
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
] a
in (b,) }
```
instead of the previous:
```
{ lambda ; a:f32[]. let
b:f32[] = pjit[
name=<lambda>
jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
] a
in (b,) }
```
The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.
Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`).
that it more closely matches the CDF for low-probability events (less than
2**-nmant).
Because -log(-log(x)) is more sensitive close to 1 than to 0, we must use
-log(-log1p(-x)) instead to make better use of the extra range around 0.
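For reference, a minimal sketch of the identity being used, assuming the usual inverse-CDF Gumbel sampler (the function name here is illustrative):
```python
import jax.numpy as jnp

def gumbel_from_uniform(u):
  # For u in (0, 1): -log(-log1p(-u)) == -log(-log(1 - u)) mathematically, but
  # the log1p form is evaluated accurately for tiny u, where floating point has
  # far more resolution than near 1, so the large positive tail of the Gumbel
  # distribution is represented much better.
  return -jnp.log(-jnp.log1p(-u))
```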
PiperOrigin-RevId: 732757388
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.
PiperOrigin-RevId: 732618760
For example, consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`
This will decompose into 2 einsums where the intermediate einsum output will be of rank `5`:
* `'bthj,bthD->bthjD'`
* `'bthjD,bthi->ijD'`
The out_sharding specified (`P('data', None, None)`) is not compatible with the intermediate einsum: `'bthj,bthD->bthjD'` since the `length of spec (3) != out_aval.ndim (5)`.
This change makes it so that out_sharding is only applied to the contraction that leads to the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**
Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
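A hedged sketch of the workaround when such a conflict exists, splitting the einsum by hand so the rank-3 output spec only touches the final contraction (the array names and a mesh with a 'data' axis are assumed from the example above):
```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def split_einsum(dy, i, j):
  # Intermediate rank-5 contraction: no out_sharding here, so there is no
  # spec/rank mismatch. Reshard the operands (or this result) here if needed.
  tmp = jnp.einsum('bthj,bthD->bthjD', j, dy)
  # Final contraction produces the rank-3 output, where the spec applies.
  return jnp.einsum('bthjD,bthi->ijD', tmp, i,
                    out_sharding=P('data', None, None))
```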
PiperOrigin-RevId: 732205849
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.
This is in preparation of changes implementing transform inference.
PiperOrigin-RevId: 732091266
The CUDA 12.8 release significantly improved the MMA docs, letting us
improve upon the previously used "magic number" scheme. Sadly, the docs
are still incorrect, but at least I can begin to make some sense of those
parameters.
PiperOrigin-RevId: 732033585
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
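For context, a conceptual sketch of what index wrapping means (not the actual lowering; the helper name is illustrative):
```python
import jax.numpy as jnp

def wrap_index(idx, axis_size):
  # NumPy-style negative-index support: wrap a negative start index back into
  # range. Scan's internal slice indices are always non-negative, so emitting
  # this select/add for every scanned argument is wasted work.
  return jnp.where(idx < 0, idx + axis_size, idx)
```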
PiperOrigin-RevId: 731812827
* `_partitions` is now canonicalized and only contains tuples, single strings, `None`, or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) or singleton tuples (see the sketch after this list).
* Cache the creation of the sharding on `ShapedArray`, since creating it many times is expensive.
* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
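An illustrative (assumed, not verified against the internals) reading of what the canonicalized entries look like:
```python
from jax.sharding import PartitionSpec as P

# Assumed canonical forms of `_partitions` entries, for illustration only:
#   P((), 'x')     -> stored as (None, 'x')    # empty tuple = unsharded dim, like None
#   P(('x',), 'y') -> stored as ('x', 'y')     # singleton tuple unwrapped to a string
#   P(('x', 'y'),) -> stored as (('x', 'y'),)  # genuine multi-axis tuples are kept
```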
PiperOrigin-RevId: 731745062
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, covering a wide range of use cases.
There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions baked into the kernel. I'm going to fix these in follow-up changes.
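A minimal sketch of the kind of use case this covers, assuming a multi-device runtime (the op and shapes are arbitrary examples, not taken from the change itself):
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard a stack of small matrices along a 'batch' mesh axis. With partitioning
# support, the underlying custom call can run on each shard directly instead of
# all-gathering the batch dimension first.
mesh = Mesh(np.array(jax.devices()), ('batch',))
x = jax.device_put(
    jnp.broadcast_to(jnp.eye(4), (jax.device_count(), 4, 4)),
    NamedSharding(mesh, P('batch', None, None)))
q, r = jax.jit(jnp.linalg.qr)(x)  # batched QR; the batch dimension stays sharded
```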
PiperOrigin-RevId: 731732301
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.
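For reference, a small illustration of the convention (hypothetical helpers, not XLA code): with little-endian packing, the first of each pair of int4 values lands in the low nibble of the packed byte:
```python
def pack_int4_le(first, second):
  # Element 0 -> bits 0-3 (low nibble); element 1 -> bits 4-7 (high nibble).
  return ((second & 0xF) << 4) | (first & 0xF)

def unpack_int4_le(byte):
  # Returns the two packed values in order (sign extension omitted).
  return byte & 0xF, (byte >> 4) & 0xF
```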
PiperOrigin-RevId: 731731530
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
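A rough sketch of the idea (illustrative only, not the actual lowering): load the packed bytes once at contiguous offsets `0, 1, 2, ...`, then split each byte into two nibbles, rather than issuing duplicated offsets `0, 0, 1, 1, ...`:
```python
import jax.numpy as jnp

def unpack_int4_contiguous(packed):  # packed: uint8 array, two int4 values per byte
  lo = packed & 0xF                  # even elements live in the low nibbles
  hi = (packed >> 4) & 0xF           # odd elements live in the high nibbles
  # Interleave back into element order (sign extension omitted for brevity).
  return jnp.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
```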
PiperOrigin-RevId: 731635736
A relatively common pattern I've observed is the following:
```python
_, metrics = some_jax_function()
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```
We are missing an opportunity here to begin the D2H copy of
the metrics more eagerly (e.g. overlap it with closing the "compute_metrics" context
manager, etc.). The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin D2H transfers as early as possible. Adapting the above code:
```python
_, metrics = some_jax_function()
# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)
with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```
PiperOrigin-RevId: 731626446