9119 Commits

Author SHA1 Message Date
jax authors
880884dfdd Merge pull request #27086 from Amir-19:tma_reduction
PiperOrigin-RevId: 742886445
2025-04-01 16:44:02 -07:00
jax authors
9bb899dc7a Merge pull request #27651 from mattjj:mutable-array-vmap
PiperOrigin-RevId: 742878778
2025-04-01 16:19:23 -07:00
jax authors
747c5803c3 Merge pull request #27632 from LouisJustinTALLOT:patch-1
PiperOrigin-RevId: 742864570
2025-04-01 15:36:02 -07:00
jax authors
e75c05295f Merge pull request #26674 from nvcastet:split_distributed_gpu_pallas_2
PiperOrigin-RevId: 742860659
2025-04-01 15:23:13 -07:00
Matthew Johnson
05269a8ec9 [mutable-arrays] add vmap rule for mutable_array_p, very basic test 2025-04-01 20:18:32 +00:00
Jake VanderPlas
4908b2f167 cumulative reductions: support __jax_array__ on inputs 2025-04-01 13:02:25 -07:00
Jake VanderPlas
7b04a79fbd jnp.einsum: add support for __jax_array__ 2025-04-01 12:26:26 -07:00
jax authors
efd621a241 Merge pull request #27643 from jakevdp:select-jax-array
PiperOrigin-RevId: 742795353
2025-04-01 12:21:22 -07:00
Vladimir Belitskiy
5370ac2ec5 Remove the try/except for Shardy imports.
Shardy has been been included in JAX for a while now.

PiperOrigin-RevId: 742778405
2025-04-01 11:33:44 -07:00
jax authors
f4c727abb3 Merge pull request #26964 from olupton:auto-pgle-with-graphs
PiperOrigin-RevId: 742774171
2025-04-01 11:22:33 -07:00
Matthew Johnson
a80f6279e9 make random_gamma_grad not a primitive anymore
Fixes #16076

Co-authored-by: Roy Frostig <frostig@google.com>
2025-04-01 17:04:50 +00:00
Jake VanderPlas
a34c462875 jnp.select: support __jax_array__ for inputs 2025-04-01 09:53:29 -07:00
Yash Katariya
76271d638a Add scan_p and cond_p vma rule.
PiperOrigin-RevId: 742737384
2025-04-01 09:50:38 -07:00
Louis-Justin TALLOT
6adb728975
Clarify documentation of jnp.heaviside 2025-04-01 02:46:30 -04:00
jax authors
006a6a63fe [Easy] Make pallas mesh grid handling more resilient to tuple names.
PiperOrigin-RevId: 742531956
2025-03-31 22:02:29 -07:00
Adam Paszke
994af3efb8 [Pallas TPU] Remove forward compatibility code for float -> signed conversions
This will be submitted automatically once the compatibility window has passed

PiperOrigin-RevId: 742464046
2025-03-31 17:40:09 -07:00
Jake VanderPlas
4003e2d0ee jnp.power: support __jax_array__ on inputs 2025-03-31 16:50:04 -07:00
jax authors
5c354541df Merge pull request #27627 from jakevdp:transpose-jax-array
PiperOrigin-RevId: 742447929
2025-03-31 16:42:52 -07:00
Ayaka
f59f615f6f Minor docstring updates for AOT wrappers in error checking
PiperOrigin-RevId: 742431349
2025-03-31 15:55:18 -07:00
Jake VanderPlas
ca36047ac9 __jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_transpose 2025-03-31 15:14:47 -07:00
Nicolas Castet
8cda2a23dd [Mosaic-GPU] [2/3] Add NVSHMEM support to Mosaic-GPU custom call 2025-03-31 22:08:20 +00:00
Yash Katariya
d6b4fed5ed Propagate sharding and vma rule for axis_index_p. There's no need for pbroadcast insertion for axis_index_p in the traceable
PiperOrigin-RevId: 742334213
2025-03-31 11:33:59 -07:00
Olli Lupton
1355e7c650 AutoPGLE: force-disable graphs less
Previously, XLA's command buffers (CUDA graphs) would be disabled both
for PGLE profile collection and when re-compiling using the profile
data. With this change, they are only disabled when collecting the
profile data.
2025-03-31 18:01:56 +00:00
Shu Wang
aaa3ebfb8a
Add optimization barrier. 2025-03-31 12:05:30 -05:00
Jake VanderPlas
200f826398 [array api] return all devices in devices() 2025-03-31 08:50:39 -07:00
Sergei Lebedev
6b719496ed [pallas:mosaic_gpu] Fixed lane-level lowering of lax.optimization_barrier
PiperOrigin-RevId: 742265860
2025-03-31 07:59:58 -07:00
Adam Paszke
d3ed327572 [Pallas:MGPU] Remove (now) unnecessary TransposeTransforms
Now that we always use small tiles, we can lay out the tiled dimension
in arbitrary order so there's no need to swap them during the TMA.

PiperOrigin-RevId: 742206980
2025-03-31 03:48:58 -07:00
Christos Perivolaropoulos
05e15ba032 [pallas:mgpu] Allow more freedom for the user to transform references.
Imlpemented untile_ref and unswizzle_ref in order to allow patterns where we need different transform stacks over the same memref. For example we may want to reg->smem transposed, then smem->gmem sliced and maybe load strided/print in between for sanity checking:

```
# Store registers transposed
o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out)
o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m)
o_smem_t = plgpu.untile_ref(o_smem_t, (n, m))
o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0))
o_smem_t[...] = plgpu.layout_cast((regs, plgpu.Layout.WGMMA_TRANSPOSED)
plgpu.commit_smem()
del o_smem_t

# Now we need different transforms on the same smem to slice and async-store to gmem
o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems,)
o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out)
o_smem = plgpu.tile_ref(o_smem, swizzle_out)
o_smem = o_smem.at[...]
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...],)
```

Which in turn lets us write

PiperOrigin-RevId: 742194519
2025-03-31 02:49:46 -07:00
Adam Paszke
aee27854f0 [Pallas:MGPU] Only allow small tiling in Pallas programs
This is part of the removal of support for large MMA tiling in Mosaic GPU.
It should also let us simplify some of the transpose transforms that are
no longer necessary, but I decided to separate this.

PiperOrigin-RevId: 742168801
2025-03-31 00:54:23 -07:00
Amir Samani
29bd01f830 add reduction support in copy_smem_to_gmem 2025-03-30 17:41:14 -07:00
Christos Perivolaropoulos
0edd715e96 [mgpu/pallas] Expose WGMMA_TRANSPOSED layout
PiperOrigin-RevId: 742084936
2025-03-30 16:12:33 -07:00
Christos Perivolaropoulos
a865b4e437 [mgpu] Register the mosaic_gpu dialect regardless of warpgroup/lane lowering.
In `mgpu.bitwidth()` mosaic_gpu types are being checked even in Lane lowering which fails.

PiperOrigin-RevId: 742044332
2025-03-30 10:50:51 -07:00
Yash Katariya
7ca50844f3 Fix an edge-case in reshape sharding rule where the last splitting/merging dim was 1.
PiperOrigin-RevId: 741740811
2025-03-28 21:43:27 -07:00
jax authors
ebd90e06fa Merge pull request #27585 from jakevdp:default-dtype-doc
PiperOrigin-RevId: 741691513
2025-03-28 17:23:16 -07:00
Yash Katariya
80061ad4c4 Add vma rules for pmin and pmax
PiperOrigin-RevId: 741685454
2025-03-28 16:55:16 -07:00
jeffcarp
123ce5221b Add scalar event logging function 2025-03-28 23:23:42 +00:00
Matthew Johnson
6fba4ecc58 PR #27576: [attrs] experimental appendattr
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.

This PR also includes some fixes for getattr/setattr.
Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:

[attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00
Jake VanderPlas
dafebd0d7f DOC: add documentation note about default dtypes 2025-03-28 15:20:58 -07:00
Yash Katariya
177193662c Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives
PiperOrigin-RevId: 741661360
2025-03-28 15:16:06 -07:00
jax authors
2d63b6e56d Merge pull request #27583 from jakevdp:scan-doc
PiperOrigin-RevId: 741653320
2025-03-28 14:45:52 -07:00
jax authors
6edc31ae1d Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
Jake VanderPlas
91dac631fb scan: improve docs & errors around dynamic length 2025-03-28 14:15:25 -07:00
jax authors
b3a2c5341d [NFC] Fix linter errors in pipeline file
PiperOrigin-RevId: 741644574
2025-03-28 14:14:56 -07:00
jax authors
47876bb3dc Merge pull request #27579 from ZacCranko:nbytes
PiperOrigin-RevId: 741636333
2025-03-28 13:50:40 -07:00
Sergei Lebedev
e838fe19d3 [pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to lane-level lowering
More work is needed to support these in the WG lowering.

PiperOrigin-RevId: 741622096
2025-03-28 13:01:10 -07:00
Sergei Lebedev
fbff338a8e [pallas:mosaic_gpu] GPUMesh now accepts axis names in a more structured way
This is hopefully less confusing then bunching them together in a single argument.

PiperOrigin-RevId: 741580827
2025-03-28 11:01:19 -07:00
Zac Cranko
d4c42d7199 implement nbytes for PRNGKeyArray 2025-03-28 10:54:48 -07:00
Peter Hawkins
ecd9f5ded8 Move aval_to_xla_shape into callback.py, which is its only user.
Specialize it to one shape per aval, since that's the only case that exists.
Remove some pointless assertions using this code.

PiperOrigin-RevId: 741569024
2025-03-28 10:28:04 -07:00
jax authors
679b6102e1 Merge pull request #27488 from jakevdp:array-capabilities
PiperOrigin-RevId: 741565179
2025-03-28 10:16:08 -07:00
Yash Katariya
5950e722e2 Make sure vma on ShapedArray exists by default to make development easier. The field is populated inside shard_map guarded on the varying_axes_in_types config though.
PiperOrigin-RevId: 741554623
2025-03-28 09:44:03 -07:00