16405 Commits

Author SHA1 Message Date
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gary Miguel
69d66f66df vmap mismatch size error message: handle *args
Fixes: https://github.com/jax-ml/jax/issues/26908
2025-03-05 13:08:54 -08:00
Owen Lockwood
3e4dc0d490 add pmap axes hints 2025-03-05 12:14:24 -08:00
Adam Paszke
8df00e2666 [Mosaic GPU] Remove support for large tiles on Blackwell
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.

PiperOrigin-RevId: 733786884
2025-03-05 10:34:53 -08:00
Dan Foreman-Mackey
4a93c8b30c Reverts 342cb7b99a09180472823a33c7cdad8a8db77875
PiperOrigin-RevId: 733782497
2025-03-05 10:22:40 -08:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
Adam Paszke
4493889cda [Mosaic GPU] Add support for small tiles for (WG)MMA LHS
Thanks to the previous refactor the change is quite trivial and mostly
focuses on adding tests.

PiperOrigin-RevId: 733754797
2025-03-05 09:01:20 -08:00
Adam Paszke
d119138766 [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.

PiperOrigin-RevId: 733737302
2025-03-05 08:06:06 -08:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
Dan Foreman-Mackey
342cb7b99a Attempt 2 at landing custom_vjp.optimize_remat using custom_dce.
The original change was rolled back because there were real world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwds signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here.

Reverts 1f3176636d304398b00a7d2cb0933859618affd8

PiperOrigin-RevId: 733643149
2025-03-05 02:06:35 -08:00
Christos Perivolaropoulos
51719a1afe [mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
2025-03-04 19:41:04 -08:00
Skye Wanderman-Milne
cebedb9f1a Update version number after 0.5.2 release 2025-03-04 18:49:12 -08:00
Yash Katariya
766315f791 Make sure concat + vmap of sharded input and replicated input works properly.
In this case, the example boils down to:

```
inp1 = f32[16@x, 4]
inp2 = f32[4]

def f(x: f32[4], y: f32[4])
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1)
```

This example was breaking in concat batching rule because we didn't broadcast with the right sharding.

PiperOrigin-RevId: 733536944
2025-03-04 18:35:13 -08:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
c145102ef4 Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
2025-03-04 15:21:01 -08:00
jax authors
b238bad703 Merge pull request #26901 from NeilGirdhar:etils
PiperOrigin-RevId: 733466732
2025-03-04 14:28:51 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
jax authors
4a73134b2f Merge pull request #26912 from dfm:resolve-args-error-message
PiperOrigin-RevId: 733378431
2025-03-04 10:26:43 -08:00
Neil Girdhar
52ab8c4cc2 Fix detection of epath
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])

[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent
call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.
2025-03-04 11:44:27 -05:00
jax authors
97db925a7d Merge pull request #26765 from Qwlouse:patch-1
PiperOrigin-RevId: 733339465
2025-03-04 08:30:45 -08:00
Adam Paszke
cdae5fcfc7 [Mosaic GPU] Make sure to do the async proxy fence before wargroup sync
This is the ordering we want for a proper release of generic SMEM stores
into the async proxy. The old order was problematic: once the warpgroup
barrier was complete, some warps could get deselected before they get to
the fence. For as long as the first warp would make progress, it could go
through the fence along and start issuing TMA copies before other warps
have synchronized with the async proxy.

I have not observed this problem in any of our kernels so far, but this
order seems safer to me.

PiperOrigin-RevId: 733333814
2025-03-04 08:11:15 -08:00
Dan Foreman-Mackey
8b1b039e0d Improve error messages when input argument resolution fails in custom_* APIs. 2025-03-04 10:31:35 -05:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
Ayaka
ea53c7616b Fix thread safety of JAX error checking
Fix thread safety of JAX error checking by making the global states thread local

PiperOrigin-RevId: 733164878
2025-03-03 20:56:01 -08:00
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
d32e282ff9 Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
2025-03-03 17:26:44 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
2c7043f63d Merge pull request #26865 from jakevdp:fix-indexing-error
PiperOrigin-RevId: 733085471
2025-03-03 15:38:20 -08:00
jax authors
f9f47217df Merge pull request #26862 from jakevdp:logsumexp-docs
PiperOrigin-RevId: 733080943
2025-03-03 15:24:10 -08:00
jax authors
4944dcb977 Merge pull request #26897 from jakevdp:cond-doc
PiperOrigin-RevId: 733077065
2025-03-03 15:13:23 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Jake VanderPlas
84ca80d215 doc: in lax.cond, note that both branches will be traced 2025-03-03 13:05:24 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
carlosgmartin
897e1a1310 Fix linalg.norm to return zero for proper norms of empty matrices. 2025-03-03 15:02:34 -05:00
Adam Paszke
e9f95cc3a7 [Mosaic GPU] Make the small WGMMA tile independent of transpose flags
Now the small tiling is always `(8, swizzle // bytewidth(dtype))`, no matter whether the input
is transposed or not. This should simply the follow-up refactoring of the code and make it easier
to enable small tiling for LHS too.

PiperOrigin-RevId: 732933005
2025-03-03 08:30:57 -08:00
Bart Chrzaszcz
ed4a7bbab1 #sdy Add JAX backwards compatibility test.
This tests saving a module with one set of axis names, but loading it with another set of axis names.

This does also test the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils is tested here doesn't allow me to set `out_shardings` for example. So JAX can rely on the existence of those tests as stability guarantees just like for StableHLO.

PiperOrigin-RevId: 732893432
2025-03-03 06:01:34 -08:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes during export saving time and exact same mesh shaped during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
Christos Perivolaropoulos
b9ebd9188f [mgpu] Forach in tiled layout.
PiperOrigin-RevId: 732872906
2025-03-03 04:31:59 -08:00
Adam Paszke
11e6cfbc6a [Mosaic GPU][NFC] Move the calculation of group strides into _validate_mma
This allows us to unify this logic between Hopper and Blackwell.

PiperOrigin-RevId: 732862875
2025-03-03 03:51:20 -08:00
jax authors
bbadf99054 Merge pull request #26697 from gnecula:pp_aliased_var_names
PiperOrigin-RevId: 732860010
2025-03-03 03:36:50 -08:00
Adam Paszke
3038348f23 [Mosaic GPU][NFC] Clean up the computation of group strides
PiperOrigin-RevId: 732849235
2025-03-03 02:50:48 -08:00
George Necula
a6c47d6f36 Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`.
2025-03-03 11:38:51 +01:00
Parker Schuh
b8b690e594 Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
that it more closely matches the CDF for low probably events (less than
2**-nmant).

Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-logp1(-x)) instead to make better use of the extra range around 0.

PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
Dimitar (Mitko) Asenov
3b305c6617 [Mosaic GPU] Infer layouts (transforms) on memrefs that directly feed into the dialect wgmma op.
This change detects a situation where a gmem_memref is read via `async_load` and directly used in a wgmma. In such cases, we insert a cast before the load to add tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
2025-03-02 03:17:13 -08:00
Dimitar (Mitko) Asenov
c60ef5a2a1 [Mosaic GPU] Wire up the slice_lengths and indices operands in lowering of the MLIR dialect.
This enables slicing via TMA and is needed for pipelining.

PiperOrigin-RevId: 732613803
2025-03-02 02:43:47 -08:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
jax authors
2a1eeb0ce8 Chnages for kernel export
PiperOrigin-RevId: 732383028
2025-03-01 00:32:39 -08:00
Anton Osokin
1f3176636d Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00