9119 Commits

Author SHA1 Message Date
Peter Hawkins
86de4783bb Remove unused function jax._src.interpreters.mlir.xla_computation_to_mlir_module.
PiperOrigin-RevId: 744934776
2025-04-07 19:26:20 -07:00
jax authors
4bae9cdaf3 Merge pull request #27814 from ZacCranko:harden-cache
PiperOrigin-RevId: 744905319
2025-04-07 17:19:11 -07:00
jax authors
9e0368653c Merge pull request #27793 from dfm:lin-out-fwd
PiperOrigin-RevId: 744901130
2025-04-07 17:03:43 -07:00
Zac Cranko
ca6e470d2f harden cache against jaxlib ver
2025-04-07 23:30:31 +00:00
Rachel Han
84e04fe608 Add custom pretty print rule for the unary ops with accuracy s.t. accuracy is not printed if it's None.
PiperOrigin-RevId: 744889524
2025-04-07 16:25:01 -07:00
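A minimal sketch of the pretty-printing idea above, in plain Python rather than JAX's internal pretty-printer: parameters are rendered as usual, but `accuracy` is dropped when it is `None`. The helper `format_unary_eqn` and its output format are hypothetical and only illustrate the rule.

```python
def format_unary_eqn(name, operand, params):
    """Render e.g. 'exp[accuracy=high] x', omitting accuracy when it is None."""
    # Keep every parameter except an unset accuracy.
    shown = {k: v for k, v in params.items()
             if not (k == "accuracy" and v is None)}
    if not shown:
        return f"{name} {operand}"
    rendered = " ".join(f"{k}={v}" for k, v in sorted(shown.items()))
    return f"{name}[{rendered}] {operand}"


print(format_unary_eqn("exp", "x", {"accuracy": None}))
print(format_unary_eqn("exp", "x", {"accuracy": "high"}))
```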
Yash Katariya
0a72e856cf Add **experimental** with_dll_constraint API. This is for cases when the user wants to let SPMD decide the sharding.
But this is a contradiction, since layouts apply to the device-local shape and, without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is; you just want to force a row-major layout (for example). **This API should only be used for those cases**.

PiperOrigin-RevId: 744888557
2025-04-07 16:21:58 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00
jax authors
3420546866 Merge pull request #27716 from jakevdp:jax-array
PiperOrigin-RevId: 744868565
2025-04-07 15:22:19 -07:00
jax authors
48a9ad0796 Reverts 006a6a63feb64bf9984526030ba008186d69d2b4
PiperOrigin-RevId: 744864022
2025-04-07 15:08:33 -07:00
Justin Fu
b18dc1dfd7 [Mosaic GPU] Add scaffolding for a new lowering "axis" (UserThreadSemantics), in addition to the existing ThreadSemantics (renamed to LoweringSemantics).
UserThreadSemantics controls the thread semantics of the Pallas user's code, whereas LoweringSemantics controls the level at which Mosaic GPU emits code.

PiperOrigin-RevId: 744857085
2025-04-07 14:47:59 -07:00
Jake VanderPlas
96e63eaee8 jnp.linalg: add symmetrize_input argument & docs
2025-04-07 14:46:38 -07:00
Jake VanderPlas
d3cfff057f jax.numpy: support __jax_array__ in remaining APIs
2025-04-07 14:08:35 -07:00
Matthew Johnson
9a3e94dec5 [shard-map] add while_map rep rule
fixes #27664
2025-04-07 21:00:59 +00:00
Parker Schuh
b6e4b93851 Add jaxlib_extension_version guard against explicit copying
in jax.device_put.

PiperOrigin-RevId: 744838237
2025-04-07 13:50:15 -07:00
Dan Foreman-Mackey
dc00f9bdae Apply output forwarding in lin rule for pjit.
2025-04-07 15:39:33 -04:00
jax authors
f23dd6429b Merge pull request #26853 from jeffcarp:scalar-event
PiperOrigin-RevId: 744807700
2025-04-07 12:15:28 -07:00
Jevin Jiang
e1e37f8d5e [Mosaic TPU] Forward compatibility requires keeping the previous version for at least one month.
PiperOrigin-RevId: 744796256
2025-04-07 11:44:13 -07:00
Sharad Vikram
fcf5115fdb [Pallas Fuser] Add output_fusion_mask support
Currently, the fusion API assumes by default that all of the outputs of a @fuse-decorated function are computed jointly in one big output fusion.

For example, in the following snippet
```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return g(z1, z2)
```

it assumes that `g` is a single function that operates on z1 and z2 jointly. However, in practice, the fusable may want two separate output fusions:
```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return g1(z1), g2(z2)
```

This is a special case of the general function, but the fusable may not materialize z1 and z2 at the same time, so it may not be able to compute this efficiently with a single function g.

By decorating a fusable with an output fusion prefix (in the above example `(True, True)`), the fusable will now be given a pair of functions `g1` and `g2` if the output fusion is "separable". If it is not separable, we raise an error, as in the following example:

```python
@fuse
def f(x, y):
  z1, z2 = fusable_f(x, y)
  return z1 + z2
```

because z1 and z2 interact with each other in the output fusion.

The rationale for providing a PyTree prefix (as opposed to a more general mechanism) is that the fusable can group its outputs into subtrees that it can identify with the output prefix. This does restrict the types of output groups that are possible (outputs must be part of the same shared subtree, as opposed to arbitrarily scattered throughout the output pytree), but this is an okay restriction because the fusable author is responsible for the grouping and can always construct it that way.

PiperOrigin-RevId: 744784770
2025-04-07 11:11:49 -07:00
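The separability condition described above can be sketched in plain Python. This is not the Pallas fuser's implementation: the function `split_by_group` and the representation of dependencies as sets of names are assumptions made for illustration. Each output of the output fusion lists which fusable outputs it reads, and the fusion is separable only if no output reads from more than one group.

```python
def split_by_group(output_deps, groups):
    """Assign each fusion output to the single group it depends on.

    output_deps: one set per output of the output fusion, naming the
                 fusable outputs (e.g. "z1") that it reads.
    groups:      one set per output group of the fusable.
    Returns a list of output indices per group, or raises ValueError
    if some output mixes values from different groups.
    """
    per_group = [[] for _ in groups]
    for out_idx, deps in enumerate(output_deps):
        touched = [i for i, g in enumerate(groups) if deps & g]
        if len(touched) > 1:
            raise ValueError(
                f"output {out_idx} mixes groups {touched}; not separable")
        if touched:
            per_group[touched[0]].append(out_idx)
        # Outputs that read no fusable output are ignored in this sketch.
    return per_group


groups = [{"z1"}, {"z2"}]
# return g1(z1), g2(z2): each output reads one group, so it is separable.
print(split_by_group([{"z1"}, {"z2"}], groups))
# return z1 + z2: the single output reads both groups, so it is rejected.
try:
    split_by_group([{"z1", "z2"}], groups)
except ValueError as e:
    print(e)
```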
Jacob Burnim
855829e1bc Add int4, uint4 to test_util.suppported_types
To increase test coverage for these types.

PiperOrigin-RevId: 744777880
2025-04-07 10:52:33 -07:00
jax authors
5581e7d0ca Merge pull request #27735 from dfm:lin-fwd
PiperOrigin-RevId: 744757213
2025-04-07 09:57:57 -07:00
Sergei Lebedev
51c224c446 Removed deprecated jax.core.{full_lower,jaxpr_as_fun,lattice_join}
PiperOrigin-RevId: 744754730
2025-04-07 09:50:43 -07:00
Sergei Lebedev
ff00fa91ce Removed unused jax_remat_opt_barrier config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 744754343
2025-04-07 09:48:57 -07:00
Dan Foreman-Mackey
dbc3bcd3ce Apply forwarding in pjit linearization rule to avoid intermediate copies.
2025-04-07 12:13:58 -04:00
Dan Foreman-Mackey
5a3fc606d4 Deprecate public export of mlir.custom_call.
PiperOrigin-RevId: 744722183
2025-04-07 07:58:20 -07:00
Peter Hawkins
70485e31b9 Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
These are available via jax.extend.mlir.dialects.

No deprecation period because jax.interpreters.mlir is not a stable API.

PiperOrigin-RevId: 744712537
2025-04-07 07:20:24 -07:00
Sergei Lebedev
c2aa811cd6 jex.core.Var is no longer ordered
This behavior was only needed for kfac_jax which has been updated *not* to
rely on variable ordering.

PiperOrigin-RevId: 744691114
2025-04-07 05:50:41 -07:00
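For code that previously relied on comparing variables, one option is to sort with an explicit key. The sketch below uses a hypothetical stand-in class named `Var` (not the real `jex.core.Var`) that, like any plain Python object, is hashable by identity but defines no ordering; the position at which each variable was first seen serves as the sort key.

```python
class Var:
    """Hypothetical stand-in: hashable by identity, defines no __lt__."""


a, b, c = Var(), Var(), Var()

# Record the position at which each variable was first seen.
first_seen = {v: i for i, v in enumerate([a, b, c])}

shuffled = [c, a, b]
# sorted(shuffled) would raise TypeError because Var has no ordering;
# an explicit key avoids relying on one.
restored = sorted(shuffled, key=first_seen.__getitem__)
```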
Adam Paszke
4596ee3cc5 Add a missing jaxlib version check in Pallas TPU lowering
PiperOrigin-RevId: 744668747
2025-04-07 04:13:39 -07:00
Sergei Lebedev
6e93fa34f3 Removed unused deprecations
PiperOrigin-RevId: 744659794
2025-04-07 03:39:20 -07:00
George Necula
ce7dc85104 [export] Add support for serializing functions with PRNG keys as inputs/outputs
This introduces version 4 of serialization, fully backwards compatible
with versions 2 and 3.

Fixes: #24143
2025-04-07 11:53:20 +02:00
Sergei Lebedev
245194ffa1 Use contextlib.nullcontext instead of trivial_ctx
I removed `trivial_ctx` from the public `jax.interpreters.partial_eval`
submodule without going through a deprecation cycle, because it is highly
unlikely anyone is using it.

PiperOrigin-RevId: 744645764
2025-04-07 02:40:56 -07:00
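For callers that did use `trivial_ctx`, the standard library's `contextlib.nullcontext` provides a no-op context manager. A minimal sketch, where the helper `run` and its `ctx_factory` parameter are hypothetical names used only for illustration:

```python
import contextlib


def run(fn, ctx_factory=None):
    """Call fn() inside the given context, or inside a no-op context."""
    # Fall back to a no-op context manager when the caller supplies none.
    ctx = ctx_factory() if ctx_factory is not None else contextlib.nullcontext()
    with ctx:
        return fn()


result = run(lambda: 42)
```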
Dimitar (Mitko) Asenov
90cfa99a68 [Mosaic GPU] Support Slice and Transpose in the Pallas WGMMA lowering
This change also fixes the transpose handling in the lowering and completely removes the use of the TransposeTransform. Instead we rely on strides. If we don't discover any issues with this, we will also remove the transpose transform from the MLIR dialect.

PiperOrigin-RevId: 744618241
2025-04-07 00:52:06 -07:00
Yash Katariya
cccc34dc23 Raise an error if the type passed to axis_types argument of Mesh and AbstractMesh is not jax.sharding.AxisType.
PiperOrigin-RevId: 744602037
2025-04-06 23:38:09 -07:00
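The kind of check described above might look like the following sketch. The `AxisType` enum here is a local stand-in for `jax.sharding.AxisType`, and both the helper `check_axis_types` and the choice of `TypeError` are assumptions, not the actual JAX code.

```python
import enum


class AxisType(enum.Enum):
    """Local stand-in for jax.sharding.AxisType (illustration only)."""
    Auto = enum.auto()
    Explicit = enum.auto()
    Manual = enum.auto()


def check_axis_types(axis_types):
    """Reject any entry that is not an AxisType value."""
    for t in axis_types:
        if not isinstance(t, AxisType):
            raise TypeError(
                "axis_types must contain AxisType values, got "
                f"{t!r} of type {type(t).__name__}")
    return tuple(axis_types)


checked = check_axis_types([AxisType.Auto, AxisType.Explicit])
```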
Matthew Johnson
6bae8c75c8 [vmappable] fix trace context bugs
to_elt must run in the parent context, while from_elt must run in the batching
context. We previously had it precisely backward!

Tests didn't catch it because our tests are extremely minimal, and in
particular didn't check a to_elt that binds primitives.
2025-04-06 00:33:40 +00:00
Yash Katariya
fc5d9a4fce Check that memory_kind of an aval is always None
PiperOrigin-RevId: 744136969
2025-04-04 19:23:25 -07:00
Sergei Lebedev
aab6613944 [pallas:mosaic_gpu] Fixed a typo in _barrier_arrive_pp_eqn
PiperOrigin-RevId: 744089477
2025-04-04 15:34:06 -07:00
Gleb Pobudzey
d81c0ffeb7 [Mosaic GPU] Limit the maximum number of registers per thread to 255.
PiperOrigin-RevId: 744083257
2025-04-04 15:10:59 -07:00
Yash Katariya
549f1cd856 Don't set memory_kind to None if the mesh is AbstractMesh and the
PiperOrigin-RevId: 744077517
2025-04-04 14:51:12 -07:00
Georg Stefan Schmid
5d4ac775dd PR #26906: [jax.distributed] Allow explicitly setting slice_index
Imported from GitHub PR https://github.com/jax-ml/jax/pull/26906

Allows overriding the slice index used by XLA.

More explicit control over which slice a device ends up in is desirable:
- Various parts of the ecosystem equate slices with "devices communicating via fast interconnect". With the arrival of NVL72 we want devices managed by multiple hosts to form a single slice.
- For debugging purposes it can be useful to allow devices on the same host (managed in separate processes) to be treated as different slices. For example, [Orbax](https://github.com/google/orbax)'s local checkpointing presumes the existence of at least two slices, so overriding the boot id will allow us to test local checkpointing on a single host.

(Companion PR in XLA: https://github.com/openxla/xla/pull/23347)
Copybara import of the project:

--
45aa7ce316bb05ebcc3f3ed2d888385923285e58 by Georg Stefan Schmid <gschmid@nvidia.com>:

[jax.distributed] Allow overriding XLA slice_index

Merging this change closes #26906

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/26906 from gspschmid:gschmid/jax-override-boot-id 45aa7ce316bb05ebcc3f3ed2d888385923285e58
PiperOrigin-RevId: 744012253
2025-04-04 11:29:17 -07:00
jax authors
e2f67e0ef1 Always force synchronous pipelining when we have vmem storage and trivial
PiperOrigin-RevId: 743993611
2025-04-04 10:33:46 -07:00
jax authors
5a7dc42ad9 Merge pull request #27730 from froystig:out-shard-trunc-normal
PiperOrigin-RevId: 743990366
2025-04-04 10:22:57 -07:00
jax authors
35d75183c7 _attempt_rewriting_take_via_slice(): canonicalize the slice index before checking it's not too long, so that e.g. my_1d_array[:, ...] can be treated as a slice rather than generating a gather operation.
PiperOrigin-RevId: 743986126
2025-04-04 10:10:38 -07:00
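To see why canonicalizing first matters, consider `my_1d_array[:, ...]`: the raw index has two entries for a one-dimensional array, but after the ellipsis is expanded it has one. The sketch below is a simplified pure-Python illustration, not the body of `_attempt_rewriting_take_via_slice()`; the helper `canonicalize_index` is hypothetical and handles only integers, slices, and `...` (no `None` or array indices).

```python
def canonicalize_index(idx, ndim):
    """Expand Ellipsis and pad with full slices to exactly ndim entries."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    ellipsis_positions = [i for i, x in enumerate(idx) if x is Ellipsis]
    if len(ellipsis_positions) > 1:
        raise IndexError("an index can only have a single ellipsis")
    n_explicit = len(idx) - len(ellipsis_positions)
    if n_explicit > ndim:
        raise IndexError("too many indices for array")
    fill = (slice(None),) * (ndim - n_explicit)
    if ellipsis_positions:
        pos = ellipsis_positions[0]
        return idx[:pos] + fill + idx[pos + 1:]
    return idx + fill


raw = (slice(None), Ellipsis)       # my_1d_array[:, ...]
canonical = canonicalize_index(raw, ndim=1)
# len(raw) is 2, which exceeds ndim, but len(canonical) is 1.
```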
Christos Perivolaropoulos
b9007145d7 [mgpu:pallas] Fix swizzling check bug where it was comparing w/ #bytes rather than #elems.
PiperOrigin-RevId: 743953910
2025-04-04 08:33:28 -07:00
Christos Perivolaropoulos
da7b1577e2 [mgpu:pallas] Swizzle elements computed using bitwidth rather than bytewidth.
PiperOrigin-RevId: 743933866
2025-04-04 07:21:40 -07:00
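The distinction between bytes and elements in the two swizzling fixes above can be illustrated with a small calculation. This is a sketch under assumptions, not Mosaic GPU code: the helper `swizzle_elems` is hypothetical. Working in bits keeps the computation valid for sub-byte types such as 4-bit integers, where an integer byte width would round down to zero.

```python
def swizzle_elems(swizzle_bytes, bitwidth):
    """Number of elements spanned by a swizzle of `swizzle_bytes` bytes."""
    total_bits = swizzle_bytes * 8
    if total_bits % bitwidth:
        raise ValueError("swizzle size is not a whole number of elements")
    return total_bits // bitwidth


elems_16bit = swizzle_elems(128, 16)   # 16-bit elements
elems_4bit = swizzle_elems(128, 4)     # 4-bit elements; 4 // 8 == 0 bytes
```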
Christos Perivolaropoulos
cbae2539d4 [mgpu:pallas] Typo in UnswizzleRef.untransform_reshape() check.
PiperOrigin-RevId: 743920665
2025-04-04 06:33:09 -07:00
Adam Paszke
635805e9b0 [Mosaic GPU] Allow replicating data over warps
This extends the tiled layouts further and allows us to replace
WGMMA_COL_LAYOUT implementation with a TiledLayout.

PiperOrigin-RevId: 743909503
2025-04-04 05:43:44 -07:00
Adam Paszke
b0a920dd92 [Mosaic GPU] Don't force TiledLayout.lane_dims to partition data
This allows us to replicate elements across a warp and replace
the special WGMMAFragRowLayout with a TiledLayout.

PiperOrigin-RevId: 743903003
2025-04-04 05:11:47 -07:00
Sergei Lebedev
206dec859d [pallas:mosaic_gpu] Added pretty printing to primitives consuming refs
I also changed existing pretty printers for transforms to use {} instead
of [], so that transforms are visually distinct from slicing.

PiperOrigin-RevId: 743869470
2025-04-04 02:34:19 -07:00
jax authors
e619fc0b72 Avoid double buffering when no windowing info is present.
PiperOrigin-RevId: 743834475
2025-04-04 00:03:47 -07:00
Roy Frostig
97cecdf862 add an out_sharding option to jax.random.truncated_normal
Drop into `Auto` mode in the implementation.
2025-04-03 22:34:08 -07:00
Jevin Jiang
a9bd1e3f9d [Pallas TPU] Support DMA priority in async copy start
For now, we can only specify priority 0 (on-demand) or priority 1 (background) in local DMA.

Also added the priority to the pretty-printed output by changing `dma_start` to `dma_start(px)`, where x is the priority.

Full example:
```
{ lambda ; a:MemRef<any>{int32[8,128]} b:MemRef<any>{int32[8,128]} c:MemRef<any>{int32[8,128]}
    d:MemRef<any>{int32[8,128]} e:MemRef<vmem>{int32[8,128]} f:MemRef<vmem>{int32[8,128]}
    g:MemRef<semaphore_mem>{dma_sem[]} h:MemRef<semaphore_mem>{dma_sem[]}. let
    dma_start(p1) a[...] -> e[...] g[...]
    dma_start(p0) b[...] -> f[...] h[...]
    dma_wait e[...] g[...]
    dma_wait f[...] h[...]
    dma_start(p0) e[...] -> c[...] g[...]
    dma_start(p1) f[...] -> d[...] h[...]
    dma_wait c[...] g[...]
    dma_wait d[...] h[...]
  in () }
```

PiperOrigin-RevId: 743815050
2025-04-03 22:26:40 -07:00
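The naming scheme in the pretty-printed output above can be reproduced with a one-line formatter. The helper `dma_start_name` is hypothetical; it only mirrors the `dma_start(px)` convention and the two priorities named in the commit message.

```python
def dma_start_name(priority):
    """Return the pretty-printed primitive name for a DMA priority."""
    if priority not in (0, 1):
        raise ValueError(
            "only priority 0 (on-demand) or 1 (background) is supported")
    return f"dma_start(p{priority})"


names = [dma_start_name(p) for p in (1, 0)]
```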