9119 Commits

Author SHA1 Message Date
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Yash Katariya
f8b98993b8 Add a divisibility check to make sure that the sharding evenly divides the shape (until this restriction is lifted), so that we don't create bad shardings.
Also improve the dynamic_update_slice sharding error by printing `aval.str_short()` instead of the full sharding, because it is concise and gives more information than the current error (i.e. it also adds the shape to the error message).

Also make some formatting changes in scan lowering to make it easier to debug.
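
As a rough illustration of the divisibility check described in this commit, here is a minimal plain-Python sketch. The helper name `check_sharding_divides` and its arguments (a shape tuple, a per-dimension spec of mesh axis names or None, and a dict of mesh axis sizes) are hypothetical and are not JAX APIs.

```python
def check_sharding_divides(shape, spec, mesh_shape):
    """Raise ValueError if a sharded dimension is not evenly divisible.

    shape: tuple of ints, one per array dimension.
    spec: tuple with one entry per dimension, either None (replicated)
      or the name of a mesh axis (hypothetical representation).
    mesh_shape: dict mapping mesh axis name to its size.
    """
    for dim, (size, axis) in enumerate(zip(shape, spec)):
        if axis is None:
            continue  # replicated dimensions always pass
        parts = mesh_shape[axis]
        if size % parts != 0:
            raise ValueError(
                f"dimension {dim} of size {size} is not evenly divisible "
                f"by mesh axis {axis!r} of size {parts}"
            )
```

With a mesh axis `x` of size 8, a shape of `(16, 4)` sharded as `("x", None)` passes, while `(10, 4)` raises.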

PiperOrigin-RevId: 734542862
2025-03-07 07:01:34 -08:00
Dan Foreman-Mackey
b7ecfdfd95 Update ad.backward_pass to support non-linear functions of constants. 2025-03-07 09:54:06 -05:00
jax authors
de78d2cc71 Merge pull request #26950 from lockwo:Owen/add-pmap-typehint
PiperOrigin-RevId: 734500798
2025-03-07 04:10:35 -08:00
Daniel Suo
e6db7a9d99 Dedup non-ref constants closed in cond branch functions.
PiperOrigin-RevId: 734497907
2025-03-07 04:01:42 -08:00
shuw
ccbe9f7cd6 Fix lint 2025-03-07 04:52:58 +00:00
Yash Katariya
e9486920e8 Auto-complete specs in a sharding with None if aval.ndim > len(sharding.spec), so that for a 2D input, P('data') continues to work.
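
The completion rule can be sketched in plain Python as follows. The helper `complete_spec` is hypothetical and represents a spec as a simple tuple; it is an illustration of the rule, not the code in this change.

```python
def complete_spec(spec, ndim):
    """Pad a partition spec with None up to ndim entries.

    spec: tuple of mesh axis names or None (hypothetical representation).
    ndim: number of dimensions of the array being sharded.
    """
    if len(spec) > ndim:
        raise ValueError(f"spec {spec} has more entries than ndim={ndim}")
    # Trailing dimensions that the spec does not mention are left unsharded.
    return tuple(spec) + (None,) * (ndim - len(spec))
```

For a 2D input, `complete_spec(("data",), 2)` yields `("data", None)`.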
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
Jake VanderPlas
b441b2b7a5 Prevent tracer leaks in scipy.special.expn 2025-03-06 14:38:11 -08:00
Ayaka
8c89da7cdc Minor bug fixes in error checking
PiperOrigin-RevId: 734126415
2025-03-06 06:57:52 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
Yash Katariya
a67ab9fade Use jit as the string in error messages instead of choosing between jit and pjit based on resource_env. This is to start deprecating the need for `with mesh` and replace it with `use_mesh(mesh)`.
PiperOrigin-RevId: 733959962
2025-03-05 20:09:30 -08:00
Yash Katariya
ba5349f896 Add a note about uneven sharding and with_sharding_constraint. Fixes https://github.com/jax-ml/jax/issues/26946
PiperOrigin-RevId: 733953836
2025-03-05 19:35:03 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gary Miguel
69d66f66df vmap mismatch size error message: handle *args
Fixes: https://github.com/jax-ml/jax/issues/26908
2025-03-05 13:08:54 -08:00
Owen Lockwood
3e4dc0d490 add pmap axes hints 2025-03-05 12:14:24 -08:00
Dan Foreman-Mackey
4a93c8b30c Reverts 342cb7b99a09180472823a33c7cdad8a8db77875
PiperOrigin-RevId: 733782497
2025-03-05 10:22:40 -08:00
shuw
c099e8081d support e2m1fn 2025-03-05 17:44:34 +00:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
Dan Foreman-Mackey
342cb7b99a Attempt 2 at landing custom_vjp.optimize_remat using custom_dce.
The original change was rolled back because there were real-world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwd's signature. Instead, we can just ignore the signature because custom_vjp handles the resolution before we ever get here.

Reverts 1f3176636d304398b00a7d2cb0933859618affd8

PiperOrigin-RevId: 733643149
2025-03-05 02:06:35 -08:00
Yash Katariya
766315f791 Make sure concat + vmap of sharded input and replicated input works properly.
In this case, the example boils down to:

```
inp1 = f32[16@x, 4]
inp2 = f32[4]

def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1, inp2)
```

This example was breaking in concat batching rule because we didn't broadcast with the right sharding.

PiperOrigin-RevId: 733536944
2025-03-04 18:35:13 -08:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
c145102ef4 Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
2025-03-04 15:21:01 -08:00
jax authors
b238bad703 Merge pull request #26901 from NeilGirdhar:etils
PiperOrigin-RevId: 733466732
2025-03-04 14:28:51 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
jax authors
4a73134b2f Merge pull request #26912 from dfm:resolve-args-error-message
PiperOrigin-RevId: 733378431
2025-03-04 10:26:43 -08:00
Neil Girdhar
52ab8c4cc2 Fix detection of epath
Unfortunately, the old detection code doesn't guarantee that `epath` is
installed:
```
[utM] In [7]: importlib.util.find_spec("etils.epath")
Out[7]: ModuleSpec(name='etils.epath',
loader=<_frozen_importlib_external.SourceFileLoader object at
0x73b8492a7230>,
origin='/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath/__init__.py',
submodule_search_locations=['/home/neil/src/cmm/.venv/lib/python3.12/site-packages/etils/epath'])

[utM] In [8]: import etils.epath
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent
call last)
Cell In[8], line 1
----> 1 import etils.epath
...
ModuleNotFoundError: No module named 'importlib_resources'
```
This happened every time I ran jax with a clean environment.
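
The failure above arises because `importlib.util.find_spec` can locate a package without executing it. A minimal sketch of a stricter check that actually imports the module follows; `is_importable` is a hypothetical helper and not necessarily how this fix is implemented.

```python
import importlib


def is_importable(name):
    """Return True only if importing the module actually succeeds.

    Unlike a find_spec-based check, this also catches missing
    transitive dependencies, which surface as ImportError at import time.
    """
    try:
        importlib.import_module(name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False
    return True
```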
2025-03-04 11:44:27 -05:00
jax authors
97db925a7d Merge pull request #26765 from Qwlouse:patch-1
PiperOrigin-RevId: 733339465
2025-03-04 08:30:45 -08:00
Dan Foreman-Mackey
8b1b039e0d Improve error messages when input argument resolution fails in custom_* APIs. 2025-03-04 10:31:35 -05:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
Ayaka
ea53c7616b Fix thread safety of JAX error checking
Fix thread safety of JAX error checking by making the global states thread local
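
The general technique of moving global state into thread-local storage can be sketched with the standard library. The names `_state`, `set_error`, `get_error`, and `run_in_thread` are hypothetical and do not correspond to JAX internals.

```python
import threading

# Each thread sees its own independent set of attributes on this object.
_state = threading.local()


def set_error(message):
    _state.error = message


def get_error():
    # Threads that never set an error see None rather than another
    # thread's value.
    return getattr(_state, "error", None)


def run_in_thread(message):
    """Set an error in a new thread and return what that thread observed."""
    seen = []

    def worker():
        set_error(message)
        seen.append(get_error())

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    return seen[0]
```

An error set inside `run_in_thread` is visible to the worker thread only; calling `get_error()` from the main thread afterwards still returns None.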

PiperOrigin-RevId: 733164878
2025-03-03 20:56:01 -08:00
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
d32e282ff9 Add fuser to jax.experimental.pallas
Note that fuser is considered experimental within Pallas and APIs are subject to change

PiperOrigin-RevId: 733117882
2025-03-03 17:26:44 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
2c7043f63d Merge pull request #26865 from jakevdp:fix-indexing-error
PiperOrigin-RevId: 733085471
2025-03-03 15:38:20 -08:00
jax authors
f9f47217df Merge pull request #26862 from jakevdp:logsumexp-docs
PiperOrigin-RevId: 733080943
2025-03-03 15:24:10 -08:00
jax authors
4944dcb977 Merge pull request #26897 from jakevdp:cond-doc
PiperOrigin-RevId: 733077065
2025-03-03 15:13:23 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Jake VanderPlas
84ca80d215 doc: in lax.cond, note that both branches will be traced 2025-03-03 13:05:24 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.
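
To illustrate the in-order semantics described above, here is a minimal pure-Python sketch for the 2D constant-padding case. `pad2d_constant` is a hypothetical helper operating on nested lists, not NumPy's or JAX's implementation. Under this reading, corner cells take the values given for the later dimension.

```python
def pad2d_constant(rows, pad_width, constant_values):
    """Pad a 2D nested list, applying dimension 0 first, then dimension 1.

    pad_width: ((top, bottom), (left, right))
    constant_values: ((v_top, v_bottom), (v_left, v_right))
    """
    (top, bottom), (left, right) = pad_width
    (v_top, v_bottom), (v_left, v_right) = constant_values
    width = len(rows[0])
    # Dimension 0: add full rows above and below the original data.
    out = (
        [[v_top] * width for _ in range(top)]
        + [list(r) for r in rows]
        + [[v_bottom] * width for _ in range(bottom)]
    )
    # Dimension 1: extend every row, including the rows just added,
    # so the corners are filled with the dimension-1 values.
    return [[v_left] * left + r + [v_right] * right for r in out]
```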

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
carlosgmartin
897e1a1310 Fix linalg.norm to return zero for proper norms of empty matrices. 2025-03-03 15:02:34 -05:00
Bart Chrzaszcz
ed4a7bbab1 #sdy Add JAX backwards compatibility test.
This tests saving a module with one set of axis names, but loading it with another set of axis names.

This also tests the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils work here doesn't allow me to set `out_shardings`, for example. So JAX can rely on the existence of those tests as stability guarantees, just like for StableHLO.

PiperOrigin-RevId: 732893432
2025-03-03 06:01:34 -08:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes during export saving time and the exact same mesh shapes during export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
George Necula
a6c47d6f36 Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`).
2025-03-03 11:38:51 +01:00
Parker Schuh
b8b690e594 Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
that it more closely matches the CDF for low probability events (less than
2**-nmant).

Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-log1p(-x)) instead to make better use of the extra range around 0.
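
The numerical point can be illustrated with scalar Python floats and the standard `math.log1p` function. This is a sketch only: the function names are hypothetical, and it does not reproduce the flag or the sampling code in this change.

```python
import math


def gumbel_naive(u):
    # u is uniform in (0, 1). Floats are coarse near 1, so the low
    # probability tail (u very close to 1) is poorly resolved.
    return -math.log(-math.log(u))


def gumbel_high_range(x):
    # x is uniform in (0, 1) and stands in for 1 - u. Floats are dense
    # near 0, and log1p(-x) computes log(1 - x) without ever forming
    # 1 - x, so tiny values of x keep their precision.
    return -math.log(-math.log1p(-x))
```

For example, with `x = 2.0**-60`, the double-precision value `1.0 - x` rounds to exactly `1.0`, so the naive form would need `log(0.0)`, while `gumbel_high_range(x)` stays finite.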

PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
jax authors
2a1eeb0ce8 Changes for kernel export
PiperOrigin-RevId: 732383028
2025-03-01 00:32:39 -08:00
Anton Osokin
1f3176636d Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00