26442 Commits

Author SHA1 Message Date
Dimitar (Mitko) Asenov
d2bf034c47 [Mosaic GPU] Test the wgmma_op lowering when a is in registers.
I had to add support for wgmma layout in vector_load. Not sure if this is useful outside the test.

PiperOrigin-RevId: 735384104
2025-03-10 08:25:43 -07:00
jax authors
5a7ef40416 Merge pull request #27026 from garymm:patch-3
PiperOrigin-RevId: 735382490
2025-03-10 08:20:30 -07:00
Dan Foreman-Mackey
21884d4a14 Move (most) jaxlib linalg custom call registration into JAX.
My motivation here is to fix the plugin support for batch partitionable custom calls. Since plugin support for custom call partitioners is provided via register_plugin_callback in xla_bridge, instead of xla_client itself, it's much more straightforward to register the custom calls in JAX.

It would be possible to refactor things differently, but it actually seems like a reasonable choice to use the supported APIs from `jax.ffi` instead of `xla_client` so that we can take advantage of any new features we might add there in the future.

This is all still a little bit brittle and I'd eventually like to migrate to a version where the XLA FFI library provides a mechanism for exporting handlers, but this change is still compatible with any future changes like that.

PiperOrigin-RevId: 735381736
2025-03-10 08:17:44 -07:00
Dan Foreman-Mackey
4eada56027 Avoid using array operations within lax.py operations. 2025-03-10 11:04:32 -04:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
jax authors
f906d2b6d1 Update XLA dependency to use revision
efb27eb924.

PiperOrigin-RevId: 735322004
2025-03-10 04:14:33 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Gary Miguel
6a718b762f
Update stateful-computations.md
tree_map -> tree.map
2025-03-09 21:35:46 -07:00
jax authors
b9fb69d1fc Update XLA dependency to use revision
e89a6b46bc.

PiperOrigin-RevId: 735080797
2025-03-09 03:55:03 -07:00
jax authors
922935a916 Merge pull request #27006 from dfm:more-direct-lin-debug-info
PiperOrigin-RevId: 734884478
2025-03-08 06:46:44 -08:00
Dan Foreman-Mackey
36d515ed2c A few more fixes for debug_info tests with direct_linearize. 2025-03-08 07:47:24 -05:00
jax authors
04696b4d7b Update XLA dependency to use revision
be68e80894.

PiperOrigin-RevId: 734851053
2025-03-08 03:27:16 -08:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
jax authors
4988adccf1 Merge pull request #27010 from mattjj:direct-linearize-fixes-3
PiperOrigin-RevId: 734747001
2025-03-07 18:15:02 -08:00
Matthew Johnson
fe26c19b92 [direct-linearize] fix name_stack bugs
Surprisingly, the bug was tracked down to #26111 aka cl/730939406, specifically
the new implementation of reset_name_stack in source_info_util.py.

To repro, use the before-this-commit implementation of reset_name_stack (left
commented-out in the file), and run

```
  JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py NameStackTransformationTest.test_nested_jit_stack
```
2025-03-08 01:51:19 +00:00
jax authors
4660d7b6dd Merge pull request #27005 from mattjj:direct-linearize-fixes-2
PiperOrigin-RevId: 734736244
2025-03-07 17:17:45 -08:00
Matthew Johnson
251b93ebd7 fixups that we meant to include in #26427
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-08 00:03:26 +00:00
Jevin Jiang
041f575747 Support MHA in ragged paged attention for packed type
PiperOrigin-RevId: 734695213
2025-03-07 14:47:04 -08:00
jax authors
6095af050f Merge pull request #26427 from mattjj:direct-linearize-fixes
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
jax authors
d849779689 Merge pull request #27001 from mattjj:yash-scan
PiperOrigin-RevId: 734685031
2025-03-07 14:14:30 -08:00
jax authors
1870176eb3 Merge pull request #26979 from mattjj:26936
PiperOrigin-RevId: 734674945
2025-03-07 13:43:55 -08:00
Matthew Johnson
f4f31f89ae [scan] when num_trips==0, don't generate weird size-zero reshapes 2025-03-07 21:35:40 +00:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9 [mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
2025-03-07 20:53:29 +00:00
Hyeontaek Lim
178278863d [JAX] Fix api_benchmark broken by https://github.com/jax-ml/jax/pull/26569
`pjit_check_aval_sharding` expects `names: Sequence[str]`.

PiperOrigin-RevId: 734614264
2025-03-07 10:49:53 -08:00
jax authors
ccf7278292 Add the len(arg) to the error message for static_argnums
Helps reduce the confusion on what is considered an argnum.
Ideally there should be static_argkwg

PiperOrigin-RevId: 734591856
2025-03-07 09:49:49 -08:00
Yash Katariya
9f37b5197f [sharding_in_types] Fix a bug where empty_array in scan was created with the wrong spec when unroll > 1.
PiperOrigin-RevId: 734591110
2025-03-07 09:47:32 -08:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
1bef8b61af [Mosaic GPU] Add a better explanation for the transposed layout
Thanks to @bchetioui for the discussion!

PiperOrigin-RevId: 734564672
2025-03-07 08:19:32 -08:00
Adam Paszke
402389290c [Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+
PiperOrigin-RevId: 734558364
2025-03-07 07:59:31 -08:00
Sergei Lebedev
928caf83ee [pallas:mosaic_gpu] copy_smem_to_gmem now allows skipping cp.async.commit_group
This feature is necessary to fix the SMEM->GMEM waiting behavior in
`emit_pipeline`, which used a pessimistic condition prior to this change,
since every copy was its own commit group.

PiperOrigin-RevId: 734553668
2025-03-07 07:43:54 -08:00
Adam Paszke
65462fe684 [Mosaic GPU] Add a new layout to help with transposing WGMMA results
PiperOrigin-RevId: 734553651
2025-03-07 07:42:01 -08:00
Yash Katariya
f8b98993b8 Add a divisibility check so that we make sure that sharding evenly divides the shape (until this restriction is lifted) to make sure we don't create bad shardings.
Also improve dynamic_update_slice sharding error by printing `aval.str_short()` instead of full sharding because it's concise and gives more info than the current error (i.e. it adds shape too to the error message)

Also make some formatting changes in scan lowering to make it easier to debug.

PiperOrigin-RevId: 734542862
2025-03-07 07:01:34 -08:00
Dan Foreman-Mackey
b7ecfdfd95 Update ad.backward_pass to support non-linear functions of constants. 2025-03-07 09:54:06 -05:00
Adam Paszke
85c6b6a128 [Mosaic GPU] Add support for tiling stores to refs using small tiling
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but now it is obviously not the case. Luckily, it wasn't
too hard to add.

PiperOrigin-RevId: 734517000
2025-03-07 05:19:11 -08:00
jax authors
de78d2cc71 Merge pull request #26950 from lockwo:Owen/add-pmap-typehint
PiperOrigin-RevId: 734500798
2025-03-07 04:10:35 -08:00
Daniel Suo
e6db7a9d99 Dedup non-ref constants closed in cond branch functions.
PiperOrigin-RevId: 734497907
2025-03-07 04:01:42 -08:00
jax authors
bf95bf49d4 Update XLA dependency to use revision
f1213b83af.

PiperOrigin-RevId: 734484617
2025-03-07 03:00:30 -08:00
shuw
ccbe9f7cd6 Fix lint 2025-03-07 04:52:58 +00:00
Zac Mustin
8095d842c8 roofline: Support computing flops for unary ops.
PiperOrigin-RevId: 734351741
2025-03-06 17:44:36 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
Yash Katariya
e9486920e8 Auto complete specs in a sharding if aval.ndim > len(sharding.spec) with None. So that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
jax authors
4cab118344 Merge pull request #26927 from skye:merge_release
PiperOrigin-RevId: 734323206
2025-03-06 16:06:09 -08:00
jax authors
cd7f03f272 Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00
Jake VanderPlas
b441b2b7a5 Prevent tracer leaks in scipy.special.expn 2025-03-06 14:38:11 -08:00
Jevin Jiang
4b49c03523 Open source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.

PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
Dimitar (Mitko) Asenov
5d64b3d2dd [Mosaic GPU] Fix scf.ForOp lowering to put lowered ops at the right place.
Without this fix, lowerings of ops within the `for` body are always appended at the end, even if they have users earlier in the body. This caused an `operand #0 does not dominate this use` error.

The fix was tested in the upcoming (but not yet submitted) `test_realistic_matmul` in Pallas with Workgroup semantics.

PiperOrigin-RevId: 734157829
2025-03-06 08:40:19 -08:00
Ayaka
8c89da7cdc Minor bug fixes in error checking
PiperOrigin-RevId: 734126415
2025-03-06 06:57:52 -08:00
Nitin Srinivasan
623865fe95 Build JAX wheels instead of installing it from the source repository
This change allows us to get rid of extra env vars which used to control whether to install `jax` at head. Now, `jax` will be be built and consumed in the same way as the other wheels in the continuous jobs.

PiperOrigin-RevId: 734123590
2025-03-06 06:48:16 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00