But this is a contradiction: layouts apply to the device-local shape, and without knowing the sharding you can't decide the layout. There are cases, however, where you don't care what the sharding is and just want to force, say, a row-major layout. **This API should only be used for those cases**.
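A minimal sketch of what that might look like, assuming the `jax.experimental.layout` API; the spellings `Layout`, `DeviceLocalLayout`, `major_to_minor`, and passing the layout via `out_shardings` vary across JAX versions and are assumptions here, not the exact API this change introduces:
```python
from functools import partial

import jax
from jax.experimental.layout import Layout, DeviceLocalLayout

# Force a row-major device-local layout on the output while leaving
# the sharding unconstrained.
row_major = Layout(DeviceLocalLayout(major_to_minor=(0, 1)))

@partial(jax.jit, out_shardings=row_major)
def f(x):
    return x * 2
```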
PiperOrigin-RevId: 744888557
UserThreadSemantics controls the thread semantics of the Pallas user's code, whereas LoweringSemantics controls the level at which Mosaic GPU emits code.
PiperOrigin-RevId: 744857085
Currently, the fusion API assumes by default that all of the outputs of a `@fuse`-decorated function are computed jointly in one big output fusion.
For example, in the following snippet
```python
@fuse
def f(x, y):
    z1, z2 = fusable_f(x, y)
    return g(z1, z2)
```
it assumes that `g` is a single function that operates on `z1` and `z2` jointly. In practice, however, the fusable may want two separate output fusions:
```python
@fuse
def f(x, y):
    z1, z2 = fusable_f(x, y)
    return g1(z1), g2(z2)
```
This is a special case of the general form, but the fusable may not materialize `z1` and `z2` at the same time, so it may not be able to compute this efficiently with a single function `g`.
By decorating a fusable with an output fusion prefix (`(True, True)` in the above example), the fusable will instead be given a pair of functions `g1` and `g2`, provided the output fusion is "separable". For example, the following would be an error:
```python
@fuse
def f(x, y):
    z1, z2 = fusable_f(x, y)
    return z1 + z2
```
because `z1` and `z2` interact with each other in the output fusion.
The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This restricts the possible output groups (outputs must belong to the same shared subtree, rather than being arbitrarily scattered throughout the output pytree), but this is an acceptable restriction because the fusable author is responsible for the grouping and can always construct it that way.
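As a sketch of the intent (the decorator name, the keyword, and receiving the output fusions as an argument are all assumptions here, not the actual API):
```python
# Hypothetical sketch: `fusable`, `output_fusion_prefix`, and the
# `out_fusions` argument are assumed spellings; `compute_z1`/`compute_z2`
# are hypothetical helpers.
@fusable(output_fusion_prefix=(True, True))
def fusable_f(x, y, out_fusions):
    g1, g2 = out_fusions   # one output fusion per `True` subtree in the prefix
    z1 = compute_z1(x, y)
    o1 = g1(z1)            # z1 is consumed before z2 is ever materialized
    z2 = compute_z2(x, y)
    o2 = g2(z2)
    return o1, o2
```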
PiperOrigin-RevId: 744784770
These are available via jax.extend.mlir.dialects.
No deprecation period because jax.interpreters.mlir is not a stable API.
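For example:
```python
# Previously under jax.interpreters.mlir; now:
from jax.extend.mlir import dialects
```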
PiperOrigin-RevId: 744712537
I removed `trivial_ctx` from the public `jax.interpreters.partial_eval`
submodule without going through a deprecation cycle, because it is highly
unlikely anyone is using it.
PiperOrigin-RevId: 744645764
This change also fixes the transpose handling in the lowering and completely removes the use of the TransposeTransform; instead we rely on strides. If we don't discover any issues with this, we will also remove the transpose transform from the MLIR dialect.
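As an illustration of the idea (plain NumPy, not Mosaic itself): a transpose can be expressed purely by swapping strides, with no explicit transform and no data movement.
```python
import numpy as np

x = np.arange(8 * 128, dtype=np.int32).reshape(8, 128)
# Row-major strides in bytes: (128 * 4, 4) == (512, 4).
# The transpose is just the same buffer viewed with swapped strides.
xt = np.lib.stride_tricks.as_strided(x, shape=(128, 8), strides=(4, 512))
assert (xt == x.T).all()
```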
PiperOrigin-RevId: 744618241
`to_elt` must run in the parent context, while `from_elt` must run in the batching
context. We previously had it precisely backward!
Tests didn't catch it because our tests are extremely minimal and, in
particular, didn't check a `to_elt` that binds primitives.
Imported from GitHub PR https://github.com/jax-ml/jax/pull/26906
Allows overriding the slice index used by XLA.
More explicit control over which slice a device ends up in is desirable:
- Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice.
- For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host.
(Companion PR in XLA: https://github.com/openxla/xla/pull/23347)
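A hypothetical sketch of the override (the keyword name `slice_index` is an assumption based on the PR title; check `jax.distributed.initialize` for the actual signature):
```python
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=0,
    slice_index=0,  # place this process's devices in slice 0 (assumed kwarg)
)
```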
Copybara import of the project:
--
45aa7ce316bb05ebcc3f3ed2d888385923285e58 by Georg Stefan Schmid <gschmid@nvidia.com>:
[jax.distributed] Allow overriding XLA slice_index
Merging this change closes #26906
COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/26906 from gspschmid:gschmid/jax-override-boot-id 45aa7ce316bb05ebcc3f3ed2d888385923285e58
PiperOrigin-RevId: 744012253
I also changed the existing pretty printers for transforms to use `{}` instead
of `[]`, so that transforms are visually distinct from slicing.
PiperOrigin-RevId: 743869470
For now, only priority 0 (on-demand) and priority 1 (background) can be specified for local DMAs.
The priority is also included in the pretty print: `dma_start` is now printed as `dma_start(px)`, meaning priority `x`.
Full example:
```
{ lambda ; a:MemRef<any>{int32[8,128]} b:MemRef<any>{int32[8,128]} c:MemRef<any>{int32[8,128]}
d:MemRef<any>{int32[8,128]} e:MemRef<vmem>{int32[8,128]} f:MemRef<vmem>{int32[8,128]}
g:MemRef<semaphore_mem>{dma_sem[]} h:MemRef<semaphore_mem>{dma_sem[]}. let
dma_start(p1) a[...] -> e[...] g[...]
dma_start(p0) b[...] -> f[...] h[...]
dma_wait e[...] g[...]
dma_wait f[...] h[...]
dma_start(p0) e[...] -> c[...] g[...]
dma_start(p1) f[...] -> d[...] h[...]
dma_wait c[...] g[...]
dma_wait d[...] h[...]
in () }
```
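A sketch of how a priority might be passed from a Pallas TPU kernel; `pltpu.make_async_copy` and its `start`/`wait` methods exist, but the `priority` keyword is an assumption about how this feature is exposed, and `src_ref`, `dst_ref`, and `sem` stand in for refs provided to the kernel:
```python
from jax.experimental.pallas import tpu as pltpu

# Inside a kernel body, with src_ref/dst_ref refs and a DMA semaphore sem:
copy = pltpu.make_async_copy(src_ref, dst_ref, sem)
copy.start(priority=1)  # priority 1: background (assumed kwarg)
copy.wait()
```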
PiperOrigin-RevId: 743815050