9697 Commits

Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Matthew Johnson
251b93ebd7 fixups that we meant to include in #26427
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-08 00:03:26 +00:00
Jevin Jiang
041f575747 Support MHA in ragged paged attention for packed type
PiperOrigin-RevId: 734695213
2025-03-07 14:47:04 -08:00
jax authors
6095af050f Merge pull request #26427 from mattjj:direct-linearize-fixes
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
jax authors
1870176eb3 Merge pull request #26979 from mattjj:26936
PiperOrigin-RevId: 734674945
2025-03-07 13:43:55 -08:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9 [mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
2025-03-07 20:53:29 +00:00
Yash Katariya
9f37b5197f [sharding_in_types] Fix a bug where empty_array in scan was created with the wrong spec when unroll > 1.
PiperOrigin-RevId: 734591110
2025-03-07 09:47:32 -08:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
402389290c [Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+
PiperOrigin-RevId: 734558364
2025-03-07 07:59:31 -08:00
Adam Paszke
65462fe684 [Mosaic GPU] Add a new layout to help with transposing WGMMA results
PiperOrigin-RevId: 734553651
2025-03-07 07:42:01 -08:00
Yash Katariya
f8b98993b8 Add a divisibility check to ensure that a sharding evenly divides the shape (until this restriction is lifted), so that we don't create bad shardings.
Also improve the dynamic_update_slice sharding error by printing `aval.str_short()` instead of the full sharding, since it's more concise and more informative than the current error (it adds the shape to the message too).

Also make some formatting changes in scan lowering to make it easier to debug.

PiperOrigin-RevId: 734542862
2025-03-07 07:01:34 -08:00
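A hedged sketch of the invariant being enforced here; the helper name and signature are illustrative, not the actual implementation:

```
# Illustrative only: every sharded dimension must be divisible by the
# number of shards the spec assigns to it.
def check_evenly_sharded(shape, shards_per_dim):
  for dim, (size, n_shards) in enumerate(zip(shape, shards_per_dim)):
    if size % n_shards != 0:
      raise ValueError(
          f"Dimension {dim} of size {size} is not evenly divisible "
          f"into {n_shards} shards")
```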
Adam Paszke
85c6b6a128 [Mosaic GPU] Add support for tiling stores to refs using small tiling
The difficulty here is that our register tiling is based on the (64, 8)
shape, while the memory tiling is now (8, swizzle // bytewidth). Before,
we would assume that each register tile fits neatly within a single
memory tile, but that is no longer the case. Luckily, it wasn't
too hard to add.

PiperOrigin-RevId: 734517000
2025-03-07 05:19:11 -08:00
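For intuition, the two tile shapes discussed above, using the formulas from the message; the helper itself is illustrative:

```
# Illustrative: the register and memory tile shapes being reconciled.
def tile_shapes(swizzle_bytes, bytewidth):
  register_tile = (64, 8)
  memory_tile = (8, swizzle_bytes // bytewidth)
  return register_tile, memory_tile

# With a 128-byte swizzle and 2-byte (bf16) elements the memory tile is
# (8, 64): a (64, 8) register tile is 8x taller but 8x narrower, so it
# no longer fits inside a single memory tile.
```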
Daniel Suo
e6db7a9d99 Dedup non-ref constants closed over in cond branch functions.
PiperOrigin-RevId: 734497907
2025-03-07 04:01:42 -08:00
Zac Mustin
8095d842c8 roofline: Support computing flops for unary ops.
PiperOrigin-RevId: 734351741
2025-03-06 17:44:36 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
Yash Katariya
e9486920e8 Auto-complete specs in a sharding with None if aval.ndim > len(sharding.spec), so that for a 2D input, P('data') continues to work.
PiperOrigin-RevId: 734325209
2025-03-06 16:10:14 -08:00
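A minimal example of the behavior this preserves, assuming a 1D mesh whose size divides the leading dimension:

```
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('data',))
x = jnp.zeros((8, 4))  # 2D input

# P('data') has length 1 but x.ndim == 2: the spec is auto-completed
# with None, so dim 0 is sharded over 'data' and dim 1 is replicated.
y = jax.device_put(x, NamedSharding(mesh, P('data')))
```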
jax authors
cd7f03f272 Updates Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently, the serialized data and its length are carried in two separate arrays: a fixed-width bytes array (with a hard-coded max size) and a uint32 array, respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00
Jevin Jiang
4b49c03523 Open source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (e.g., ***5x*** speedup compared to a padded Multi-Query Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in the pre/post kernel. The kernel takes `num_head` in the 2nd minor dimension, as it naturally is. We fold the swapaxes into strided loads/stores in the kernel and apply the transpose on the fly.
* ***No GMM (Grouped Matmul) metadata required!*** We calculate the metadata on the fly in the kernel. This alone can give a ***10%*** speedup!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for the MXU in decode (see the sketch after this entry).
* ***Minimize recompilation:*** The only factors that can cause recompilation are the model specs, `max_num_batched_tokens`, and `max_num_seqs` in the mixed-engine setting.

PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
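A conceptual sketch of the GQA head-grouping trick from the list above; the shapes and names are illustrative, not the kernel's actual code:

```
import jax.numpy as jnp

# Illustrative decode-time shapes: one new token per sequence.
num_q_heads, num_kv_heads, head_dim = 32, 4, 128
q = jnp.zeros((num_q_heads, head_dim))

# Group the q heads that share each KV head: (4, 8, 128). Each KV head
# now sees an (8, 128) block of queries instead of eight separate
# (1, 128) rows, an 8x larger matmul for the MXU to execute at once.
group = num_q_heads // num_kv_heads
q_grouped = q.reshape(num_kv_heads, group, head_dim)
```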
Ayaka
8c89da7cdc Minor bug fixes in error checking
PiperOrigin-RevId: 734126415
2025-03-06 06:57:52 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
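For reference, the op that gains a warpgroup (WG) lowering rule here, via its public `jax.lax` wrapper:

```
import jax
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)
# Reinterpret the raw bits; same-width dtypes preserve the shape.
bits = jax.lax.bitcast_convert_type(x, jnp.uint32)
assert bits.shape == x.shape and bits.dtype == jnp.uint32
```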
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
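A minimal Pallas kernel exercising the new support; this is a sketch, and in particular the int32 output dtype is an assumption about how the int8 inputs accumulate:

```
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
  # pl.dot on int8 operands, assumed here to accumulate into int32.
  o_ref[...] = pl.dot(x_ref[...], y_ref[...])

x = jnp.ones((128, 128), jnp.int8)
y = jnp.ones((128, 128), jnp.int8)
out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32),
)(x, y)
```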
Benjamin Chetioui
fe577b5dc4 [Pallas/Mosaic GPU] Enable ops_test for Mosaic GPU.
For now, most of the tests are skipped.

PiperOrigin-RevId: 734026728
2025-03-06 00:45:05 -08:00
Yash Katariya
a67ab9fade Just use `jit` as the string in error messages instead of choosing between `jit` and `pjit` based on resource_env. This is to start deprecating the need for `with mesh` and replacing it with `use_mesh(mesh)`.
PiperOrigin-RevId: 733959962
2025-03-05 20:09:30 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gary Miguel
69d66f66df vmap mismatch size error message: handle *args
Fixes: https://github.com/jax-ml/jax/issues/26908
2025-03-05 13:08:54 -08:00
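A minimal reproducer of the error path being fixed (illustrative shapes):

```
import jax
import jax.numpy as jnp

def f(*args):  # the *args-style signature this change handles
  return sum(args)

# The mapped axis sizes disagree (3 vs. 4), so vmap raises a
# size-mismatch error; this fix makes that message handle *args.
jax.vmap(f)(jnp.zeros(3), jnp.zeros(4))  # raises ValueError
```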
Adam Paszke
8df00e2666 [Mosaic GPU] Remove support for large tiles on Blackwell
We don't have many Blackwell kernels yet, so let's begin the deprecation there!
Small tiles have clearer semantics when it comes to transposes too, which allows
us to enable more test cases.

PiperOrigin-RevId: 733786884
2025-03-05 10:34:53 -08:00
Dan Foreman-Mackey
4a93c8b30c Reverts 342cb7b99a09180472823a33c7cdad8a8db77875
PiperOrigin-RevId: 733782497
2025-03-05 10:22:40 -08:00
Adam Paszke
4493889cda [Mosaic GPU] Add support for small tiles for (WG)MMA LHS
Thanks to the previous refactor, the change is quite trivial and mostly
focuses on adding tests.

PiperOrigin-RevId: 733754797
2025-03-05 09:01:20 -08:00
Adam Paszke
d119138766 [Mosaic GPU][NFC] Refactor MMA SMEM descriptor creation
This makes the code path uniform for LHS/RHS and greatly clarifies the
magical computation of LBO/SBO. This change should make it significantly
easier for us to enable small tile support for the LHS.

PiperOrigin-RevId: 733737302
2025-03-05 08:06:06 -08:00
Sergei Lebedev
6230ef1d51 Removed unused import 2025-03-05 15:18:43 +00:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
Dan Foreman-Mackey
342cb7b99a Attempt 2 at landing custom_vjp.optimize_remat using custom_dce.
The original change was rolled back because there were real-world use cases of custom_vjp where the fwd function had the wrong signature. To preserve backwards compatibility, we shouldn't resolve the input arguments to fwd using fwd's signature. Instead, we can just ignore the signature, because custom_vjp handles the resolution before we ever get here.

Reverts 1f3176636d304398b00a7d2cb0933859618affd8

PiperOrigin-RevId: 733643149
2025-03-05 02:06:35 -08:00
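For context, the standard custom_vjp pattern whose fwd signature is at issue; a refresher built on the documented API, not the new internals:

```
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sin(x)

def f_fwd(x):
  # fwd returns (primal_out, residuals); it is this signature that the
  # change deliberately stops resolving input arguments against.
  return jnp.sin(x), jnp.cos(x)

def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
```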
Christos Perivolaropoulos
51719a1afe [mgpu] Non-vector untiled stores for tiling layouts.
Useful for storing in memrefs where the minormost stride is >1.

PiperOrigin-RevId: 733551038
2025-03-04 19:41:04 -08:00
Yash Katariya
766315f791 Make sure concat + vmap of sharded input and replicated input works properly.
In this case, the example boils down to:

```
inp1 = f32[16@x, 4]
inp2 = f32[4]

def f(x: f32[4], y: f32[4]):
  return jnp.concat([x, y], axis=-1)

vmap(f, in_axes=(0, None))(inp1, inp2)
```

This example was breaking in the concat batching rule because we didn't broadcast with the right sharding.

PiperOrigin-RevId: 733536944
2025-03-04 18:35:13 -08:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
c145102ef4 Merge pull request #26641 from jakevdp:jnp-ndim
PiperOrigin-RevId: 733484459
2025-03-04 15:21:01 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
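The identity behind the fast path, as a sketch; the actual lowering would emit the hardware log2 instruction directly:

```
import math
import jax.numpy as jnp

def log_via_log2(x):
  # ln(x) = log2(x) * ln(2); log2 maps to a fast GPU instruction.
  return jnp.log2(x) * math.log(2.0)
```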
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
jax authors
8af6f70fe0 [JAX] Disable msan and asan for the profiler test running on nvidia gpu
PiperOrigin-RevId: 733380848
2025-03-04 10:34:11 -08:00
Dan Foreman-Mackey
8b1b039e0d Improve error messages when input argument resolution fails in custom_* APIs. 2025-03-04 10:31:35 -05:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Dan Foreman-Mackey
6c5ef1a404 Update jnp.unique to support upstream interface changes. 2025-03-04 05:24:52 -05:00
Ayaka
ea53c7616b Fix thread safety of JAX error checking
Fix thread safety of JAX error checking by making the global state thread-local.

PiperOrigin-RevId: 733164878
2025-03-03 20:56:01 -08:00
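The general shape of such a fix, sketched with hypothetical names; the real state object lives in JAX's error-checking module:

```
import threading

class _ErrorCheckState(threading.local):
  # threading.local gives each thread its own copy of these fields,
  # so concurrent threads no longer race on shared error state.
  def __init__(self):
    self.error = None

_error_check_state = _ErrorCheckState()  # hypothetical name
```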
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
2c7043f63d Merge pull request #26865 from jakevdp:fix-indexing-error
PiperOrigin-RevId: 733085471
2025-03-03 15:38:20 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Yash Katariya
07c4c03a05 Remove the skip for test_output_streaming_inside_scan
PiperOrigin-RevId: 733070842
2025-03-03 14:54:03 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
2025-03-03 15:25:08 -05:00
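The two constant forms mentioned above, which are equivalent under NumPy semantics but previously took different code paths:

```
import jax.numpy as jnp

x = jnp.ones((2, 3))
# Both forms pad every dimension with 0 before and 2 after, so both
# results have shape (4, 5).
a = jnp.pad(x, (0, 2))
b = jnp.pad(x, ((0, 2),))
assert a.shape == b.shape == (4, 5)
```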