9731 Commits

Author SHA1 Message Date
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
7ac088c14f Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2025-03-11 13:53:20 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Peter Hawkins
67aa997f84 Increase the number of iterations in a test that compares rolled versus unrolled HLO for length.
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.

PiperOrigin-RevId: 735875835
2025-03-11 13:45:19 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
jax authors
c2c68c018f Merge pull request #27059 from jakevdp:fix-while-loop
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Adam Paszke
6f7ce9d048 Skip ASAN tests for the big Mosaic GPU tests
They are timing out.

PiperOrigin-RevId: 735804647
2025-03-11 10:30:04 -07:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00
Adam Paszke
30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer monotonically go over all
columns until N=512, but they include a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00
Benjamin Chetioui
7fd32ecc04 [Pallas/Mosaic GPU] Explicitly disable ops_test on Mosaic GPU pre-Hopper.
PiperOrigin-RevId: 735744473
2025-03-11 07:11:09 -07:00
Shraiysh
cb2eb15739 PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true.
Imported from GitHub PR https://github.com/openxla/xla/pull/22800

Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.

Copybara import of the project:

--
e30dea20489b3fb4d03d373fec0391d69486f4aa by Shraiysh Vaishay <svaishay@nvidia.com>:

Change the default value of print_operand_shape_ to false and print_large_constants_ to true.

Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.

--
7008af0dd0ce342ecbe9475f1d0e277319f1705a by Shraiysh Vaishay <svaishay@nvidia.com>:

Handle tests

--
b22d5f95cfb7e15f930a2198279a76c38593cc53 by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

--
d51579cae7359c6426a87ad4a7ff1b4b0c80f74a by Shraiysh Vaishay <svaishay@nvidia.com>:

Fix more tests

Merging this change closes #22800

PiperOrigin-RevId: 735690598
2025-03-11 03:17:04 -07:00
Yash Katariya
76dec38286 Under pjit the with mesh: context will use use_mesh(mesh): jit instead of tracking separately using resource_env.
This would also make it easier to deprecate the `with mesh: pjit` path in the future from user code since the new path would be completely tested.
This will also allow us to remove `resource_env` from JAX and the internal API access of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
Ayaka
988a1208a9 Better error message when raise_if_error() is called within a traced context
PiperOrigin-RevId: 735557928
2025-03-10 16:55:06 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
802cb33bf8 [Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
PiperOrigin-RevId: 735519526
2025-03-10 14:49:34 -07:00
jax authors
261e6e5fdc Merge pull request #27038 from jakevdp:vmap-sentinel
PiperOrigin-RevId: 735510065
2025-03-10 14:21:11 -07:00
jax authors
c942b0fef0 Merge pull request #26977 from jakevdp:fix-expn
PiperOrigin-RevId: 735506133
2025-03-10 14:09:32 -07:00
Praveen Narayanan
b6d4fe5387 Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
PiperOrigin-RevId: 735471245
2025-03-10 12:25:22 -07:00
jax authors
18f2f19c1a Merge pull request #26525 from wenscarl:e2m1fn
PiperOrigin-RevId: 735457804
2025-03-10 11:46:18 -07:00
Jacob Burnim
73d20cd62a [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Jake VanderPlas
8ecadfdf9d Internal: make it easier to detect the vmap sentinel 2025-03-10 11:37:50 -07:00
Nitin Srinivasan
d41e96835b Modify version test to consider "rc" versions as well
I was testing the RC promotion workflow and found that the version test failed as it does not consider pre-releases. Therefore, this commit modifies the `VERSION_PATTERN` to also consider "rc" wheels.

Fixes https://github.com/jax-ml/jax/actions/runs/13705984545/job/38331236497

PiperOrigin-RevId: 735444828
2025-03-10 11:10:18 -07:00
jax authors
ab0ce8a448 Merge pull request #26811 from dfm:direct-lin
PiperOrigin-RevId: 735388827
2025-03-10 08:39:49 -07:00
Dimitar (Mitko) Asenov
d2bf034c47 [Mosaic GPU] Test the wgmma_op lowering when a is in registers.
I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test.

PiperOrigin-RevId: 735384104
2025-03-10 08:25:43 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Dan Foreman-Mackey
36d515ed2c A few more fixes for debug_info tests with direct_linearize. 2025-03-08 07:47:24 -05:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Matthew Johnson
251b93ebd7 fixups that we meant to include in #26427
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-08 00:03:26 +00:00
Jevin Jiang
041f575747 Support MHA in ragged paged attention for packed type
PiperOrigin-RevId: 734695213
2025-03-07 14:47:04 -08:00
jax authors
6095af050f Merge pull request #26427 from mattjj:direct-linearize-fixes
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
jax authors
1870176eb3 Merge pull request #26979 from mattjj:26936
PiperOrigin-RevId: 734674945
2025-03-07 13:43:55 -08:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9 [mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
2025-03-07 20:53:29 +00:00
Yash Katariya
9f37b5197f [sharding_in_types] Fix a bug where empty_array in scan was created with the wrong spec when unroll > 1.
PiperOrigin-RevId: 734591110
2025-03-07 09:47:32 -08:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
402389290c [Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+
PiperOrigin-RevId: 734558364
2025-03-07 07:59:31 -08:00
Adam Paszke
65462fe684 [Mosaic GPU] Add a new layout to help with transposing WGMMA results
PiperOrigin-RevId: 734553651
2025-03-07 07:42:01 -08:00
Yash Katariya
f8b98993b8 Add a divisibility check so that we make sure that sharding evenly divides the shape (until this restriction is lifted) to make sure we don't create bad shardings.
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)

Also make some formatting changes in scan lowering to make it easier to debug.

PiperOrigin-RevId: 734542862
2025-03-07 07:01:34 -08:00
Dan Foreman-Mackey
b7ecfdfd95 Update ad.backward_pass to support non-linear functions of constants. 2025-03-07 09:54:06 -05:00
Adam Paszke
85c6b6a128 [Mosaic GPU] Add support for tiling stores to refs using small tiling
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.

PiperOrigin-RevId: 734517000
2025-03-07 05:19:11 -08:00
Daniel Suo
e6db7a9d99 Dedup non-ref constants closed in cond branch functions.
PiperOrigin-RevId: 734497907
2025-03-07 04:01:42 -08:00
shuw
ccbe9f7cd6 Fix lint 2025-03-07 04:52:58 +00:00
Zac Mustin
8095d842c8 roofline: Support computing flops for unary ops.
PiperOrigin-RevId: 734351741
2025-03-06 17:44:36 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
Yash Katariya
e9486920e8 Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None. So that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
jax authors
cd7f03f272 Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00