9761 Commits

Author SHA1 Message Date
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Tzu-Wei Sung
21f5f2d45e [Pallas] Increase #rows when casting to x2.
There is a bug in XLA on v5p.

PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Zac Mustin
0c8e601f90 Support convolution in roofline.
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.

PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322 Allow make_array_from_callback to construct nonaddressable arrays.
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Sergei Lebedev
97bbc37e83 [dlpack] Support more DLPack dtypes now that we target DLPack 1.1
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.

PiperOrigin-RevId: 736882904
2025-03-14 09:10:56 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Ilya Tikhonovskiy
43b78c539f [JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.

PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b Add explicit support for float8_e4m3b11fnuz in pl.dot
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Benjamin Chetioui
d09df7c8ab [Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}.
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Tzu-Wei Sung
e235fb9760 [Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.

PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca Merge pull request #26944 from wenscarl:wenscarl/nvfp4
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Yash Katariya
14b9f48535 Allow late binding out_shardings and in_shardings in auto_axes and explicit_axes API
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
Yash Katariya
8674495fd7 [sharding_in_types] Make reshard work with np.array.
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
Pearu Peterson
f608a8c502 Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan. 2025-03-12 17:48:45 +02:00
Dan Foreman-Mackey
8b7cfcb33c Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
7ac088c14f Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2025-03-11 13:53:20 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Peter Hawkins
67aa997f84 Increase the number of iterations in a test that compares rolled versus unrolled HLO for length.
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.

PiperOrigin-RevId: 735875835
2025-03-11 13:45:19 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
shuw
f9aef8a189 Support nvfp4 2025-03-11 19:33:25 +00:00
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
jax authors
c2c68c018f Merge pull request #27059 from jakevdp:fix-while-loop
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Adam Paszke
6f7ce9d048 Skip ASAN tests for the big Mosaic GPU tests
They are timing out.

PiperOrigin-RevId: 735804647
2025-03-11 10:30:04 -07:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00
Adam Paszke
30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer monotonically go over all
columns until N=512, but they include a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00
Benjamin Chetioui
7fd32ecc04 [Pallas/Mosaic GPU] Explicitly disable ops_test on Mosaic GPU pre-Hopper.
PiperOrigin-RevId: 735744473
2025-03-11 07:11:09 -07:00
Shraiysh
cb2eb15739 PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Imported from GitHub PR https://github.com/openxla/xla/pull/22800

Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.

Copybara import of the project:

--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:

Change the default value of print_operand_shape_ to false and print_large_constants_ to true.

Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.

--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:

Handle tests

--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

Merging this change closes #22800

PiperOrigin-RevId: 735690598
2025-03-11 03:17:04 -07:00
Yash Katariya
76dec38286 Under pjit the with mesh: context will use use_mesh(mesh): jit instead of tracking separately using resource_env.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
Ayaka
988a1208a9 Better error message when raise_if_error() is called within a traced context
PiperOrigin-RevId: 735557928
2025-03-10 16:55:06 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
802cb33bf8 [Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
PiperOrigin-RevId: 735519526
2025-03-10 14:49:34 -07:00
jax authors
261e6e5fdc Merge pull request #27038 from jakevdp:vmap-sentinel
PiperOrigin-RevId: 735510065
2025-03-10 14:21:11 -07:00
jax authors
c942b0fef0 Merge pull request #26977 from jakevdp:fix-expn
PiperOrigin-RevId: 735506133
2025-03-10 14:09:32 -07:00