26648 Commits

Author SHA1 Message Date
jax authors
006a6a63fe [Easy] Make pallas mesh grid handling more resilient to tuple names.
PiperOrigin-RevId: 742531956
2025-03-31 22:02:29 -07:00
Adam Paszke
994af3efb8 [Pallas TPU] Remove forward compatibility code for float -> signed conversions
This will be submitted automatically once the compatibility window has passed

PiperOrigin-RevId: 742464046
2025-03-31 17:40:09 -07:00
jax authors
16b2b91115 Merge pull request #27540 from jakevdp:pow-jax-array
PiperOrigin-RevId: 742464022
2025-03-31 17:38:02 -07:00
Jake VanderPlas
4003e2d0ee jnp.power: support __jax_array__ on inputs 2025-03-31 16:50:04 -07:00
jax authors
5c354541df Merge pull request #27627 from jakevdp:transpose-jax-array
PiperOrigin-RevId: 742447929
2025-03-31 16:42:52 -07:00
Ayaka
f59f615f6f Minor docstring updates for AOT wrappers in error checking
PiperOrigin-RevId: 742431349
2025-03-31 15:55:18 -07:00
Jake VanderPlas
ca36047ac9 __jax_array__: add support in jnp.reshape, jnp.transpose, jnp.matrix_transpose 2025-03-31 15:14:47 -07:00
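The `__jax_array__` commits above concern a duck-typing protocol: an object that defines a `__jax_array__` method can be passed where an array is expected. The sketch below illustrates only the general idea in plain Python. It is not JAX's implementation; `Wrapped`, `unwrap`, and `power` are hypothetical names, and plain lists stand in for arrays.

```python
class Wrapped:
    """Hypothetical user type that exposes its data via __jax_array__."""

    def __init__(self, data):
        self._data = data

    def __jax_array__(self):
        # Return the underlying array-like value.
        return self._data


def unwrap(x):
    """Hypothetical helper: call x.__jax_array__() if present, else return x."""
    method = getattr(x, "__jax_array__", None)
    return method() if callable(method) else x


def power(x, y):
    """Elementwise power over lists, accepting wrapped inputs on either side."""
    x, y = unwrap(x), unwrap(y)
    return [a ** b for a, b in zip(x, y)]


result = power(Wrapped([2, 3]), [3, 2])
```

Here both a `Wrapped` instance and a plain list are accepted, because `unwrap` is applied to every input before the computation runs.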
jax authors
e2ee2625b4 Merge pull request #27621 from jax-ml:dependabot/github_actions/actions/setup-python-5.5.0
PiperOrigin-RevId: 742345857
2025-03-31 12:03:36 -07:00
Yash Katariya
d6b4fed5ed Propagate sharding and vma rule for axis_index_p. There's no need for pbroadcast insertion for axis_index_p in the traceable
PiperOrigin-RevId: 742334213
2025-03-31 11:33:59 -07:00
dependabot[bot]
5d69e6b64d Bump actions/setup-python from 5.4.0 to 5.5.0
Bumps [actions/setup-python](https://github.com/actions/setup-python) from 5.4.0 to 5.5.0.
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](42375524e2...8d9ed9ac5c)

---
updated-dependencies:
- dependency-name: actions/setup-python
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-03-31 17:49:47 +00:00
jax authors
e2df37471e Merge pull request #27616 from jakevdp:array-api-devices
PiperOrigin-RevId: 742302906
2025-03-31 10:04:37 -07:00
Jake VanderPlas
200f826398 [array api] return all devices in devices() 2025-03-31 08:50:39 -07:00
Sergei Lebedev
6b719496ed [pallas:mosaic_gpu] Fixed lane-level lowering of lax.optimization_barrier
PiperOrigin-RevId: 742265860
2025-03-31 07:59:58 -07:00
Dan Foreman-Mackey
95497ca2f0 Remove legacy GPU kernel for LU decomposition.
Following the compatibility timeline described here: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

It has been 6 months since the release of 0.4.33 which is the relevant release for this kernel.

PiperOrigin-RevId: 742261532
2025-03-31 07:43:08 -07:00
Daniel Suo
b3d851d722 Add JAX tracing micro benchmarks.
Add a first benchmark for tracing/lowering Pallas splash attention.

Sample results below were taken on a GCP n2d-standard-128 instance with 512 GB RAM and 128 vCPUs (AMD EPYC Milan).

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18

PiperOrigin-RevId: 742259409
2025-03-31 07:35:22 -07:00
Daniel Suo
12526ea116 [jaxlib] Pack/unpack subbyte types to/from numpy arrays to support int2, uint2, int4, uint4, float4_e2m1fn subbyte types in CPU/GPU callbacks.
PiperOrigin-RevId: 742253272
2025-03-31 07:09:46 -07:00
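To illustrate what packing a subbyte type means, here is a minimal plain-Python sketch that stores two signed 4-bit integers per byte. It is not the jaxlib implementation: the function names are hypothetical, and the nibble order (first value in the low nibble) is an assumption made for this example.

```python
def pack_int4(values):
    """Pack signed 4-bit integers (-8..7) into bytes, two values per byte.

    Assumption: the first value of each pair occupies the low nibble.
    """
    if any(not -8 <= v <= 7 for v in values):
        raise ValueError("int4 values must be in [-8, 7]")
    nibbles = [v & 0xF for v in values]  # two's complement, 4 bits
    if len(nibbles) % 2:
        nibbles.append(0)  # pad odd-length input with a zero nibble
    return bytes(lo | (hi << 4) for lo, hi in zip(nibbles[0::2], nibbles[1::2]))


def unpack_int4(packed, count):
    """Inverse of pack_int4: recover `count` signed 4-bit integers."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # Nibbles 8..15 represent the negative values -8..-1.
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out[:count]


packed = pack_int4([1, -1, 7])
restored = unpack_int4(packed, 3)
```

The element count is passed to `unpack_int4` explicitly because the padding nibble cannot otherwise be told apart from a real zero.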
Sergei Lebedev
cb51682691 [pallas:mosaic_gpu] Run all Mosaic GPU-specific tests under WG semantics
We do skip quite a few due to missing features. I tried to make the reason
for skipping clear in each case.

PiperOrigin-RevId: 742252858
2025-03-31 07:07:54 -07:00
jax authors
fc01058ee4 Update XLA dependency to use revision
f4a53456b0.

PiperOrigin-RevId: 742228024
2025-03-31 05:16:11 -07:00
Adam Paszke
d3ed327572 [Pallas:MGPU] Remove (now) unnecessary TransposeTransforms
Now that we always use small tiles, we can lay out the tiled dimension
in arbitrary order so there's no need to swap them during the TMA.

PiperOrigin-RevId: 742206980
2025-03-31 03:48:58 -07:00
Christos Perivolaropoulos
05e15ba032 [pallas:mgpu] Allow more freedom for the user to transform references.
Implemented untile_ref and unswizzle_ref to allow patterns where we need different transform stacks over the same memref. For example, we may want to store reg->smem transposed, then smem->gmem sliced, and maybe load strided or print in between for sanity checking:

```
# Store registers transposed
o_smem_swizzled = plgpu.unswizzle_ref(o_smem_raw, swizzle_out)
o_smem_t = o_smem_swizzled.reshape(1, 1, config.block_n, config.block_m)
o_smem_t = plgpu.untile_ref(o_smem_t, (n, m))
o_smem_t = plgpu.transpose_ref(o_smem_t, (1, 0))
o_smem_t[...] = plgpu.layout_cast(regs, plgpu.Layout.WGMMA_TRANSPOSED)
plgpu.commit_smem()
del o_smem_t

# Now we need different transforms on the same smem to slice and async-store to gmem
o_smem = o_smem_raw.reshape(n, m // swizzle_elems, swizzle_elems,)
o_smem = plgpu.unswizzle_ref(o_smem, swizzle_out)
o_smem = plgpu.tile_ref(o_smem, swizzle_out)
o_smem = o_smem.at[...]
plgpu.copy_smem_to_gmem(o_smem, o_ref.at[...],)
```

Which in turn lets us write

PiperOrigin-RevId: 742194519
2025-03-31 02:49:46 -07:00
jax authors
c0562861d4 Merge pull request #27446 from jakevdp:core-deps
PiperOrigin-RevId: 742190714
2025-03-31 02:32:16 -07:00
Adam Paszke
aee27854f0 [Pallas:MGPU] Only allow small tiling in Pallas programs
This is part of the removal of support for large MMA tiling in Mosaic GPU.
It should also let us simplify some of the transpose transforms that are
no longer necessary, but I decided to separate this.

PiperOrigin-RevId: 742168801
2025-03-31 00:54:23 -07:00
Jake VanderPlas
10425ae6a9 jax.core: finalize a number of deprecations for JAX v0.6.0 2025-03-30 19:32:22 -07:00
Christos Perivolaropoulos
0edd715e96 [mgpu/pallas] Expose WGMMA_TRANSPOSED layout
PiperOrigin-RevId: 742084936
2025-03-30 16:12:33 -07:00
Christos Perivolaropoulos
a865b4e437 [mgpu] Register the mosaic_gpu dialect regardless of warpgroup/lane lowering.
In `mgpu.bitwidth()`, mosaic_gpu types are checked even under lane lowering, which fails.

PiperOrigin-RevId: 742044332
2025-03-30 10:50:51 -07:00
jax authors
5fda4c1b0e Update XLA dependency to use revision
8df9390dc9.

PiperOrigin-RevId: 741998725
2025-03-30 04:44:45 -07:00
jax authors
e7ec418eba Update XLA dependency to use revision
f50746ab31.

PiperOrigin-RevId: 741809075
2025-03-29 05:19:54 -07:00
Yash Katariya
7ca50844f3 Fix an edge-case in reshape sharding rule where the last splitting/merging dim was 1.
PiperOrigin-RevId: 741740811
2025-03-28 21:43:27 -07:00
jax authors
ebd90e06fa Merge pull request #27585 from jakevdp:default-dtype-doc
PiperOrigin-RevId: 741691513
2025-03-28 17:23:16 -07:00
Yash Katariya
80061ad4c4 Add vma rules for pmin and pmax
PiperOrigin-RevId: 741685454
2025-03-28 16:55:16 -07:00
jax authors
6b8821148d Merge pull request #27309 from jax-ml:discord
PiperOrigin-RevId: 741678068
2025-03-28 16:22:50 -07:00
Zac Cranko
93c6bb72d3 add discord release action
Update community_release_actions.yml
2025-03-28 16:12:35 -07:00
jax authors
eb54cd2c61 Remove GPU-specific dependencies from backend-independent tests.
The GPU-specific deps were added to the backend-independent tests by mistake [here](https://github.com/jax-ml/jax/pull/27113). These tests should pass using `jax` and `jaxlib` wheels only.

PiperOrigin-RevId: 741663266
2025-03-28 15:23:31 -07:00
Matthew Johnson
6fba4ecc58 PR #27576: [attrs] experimental appendattr
Imported from GitHub PR https://github.com/jax-ml/jax/pull/27576

This is an experimental extension to attrs. Attrs should be considered both experimental and deprecated.

This PR also includes some fixes for getattr/setattr.
Copybara import of the project:

--
3b1ea1a5f90b28744522670d0498ce5a6b194274 by Matthew Johnson <mattjj@google.com>:

[attrs] experimental appendattr

Merging this change closes #27576

COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/27576 from mattjj:appendattr b93795201b39b8f75890c9228368c994ae1e38e8
PiperOrigin-RevId: 741662724
2025-03-28 15:21:12 -07:00
Jake VanderPlas
dafebd0d7f DOC: add documentation note about default dtypes 2025-03-28 15:20:58 -07:00
Yash Katariya
177193662c Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter primitives
PiperOrigin-RevId: 741661360
2025-03-28 15:16:06 -07:00
jax authors
b719ac00c6 Use f32 scratch for output so we only need to transfer output with desired dtype back to HBM.
We use f32 as the dtype inside the kernel. Before writing the result from VMEM to HBM, we convert it to the desired dtype (e.g. bf16), which saves memory bandwidth.

Also made a minor change: sliding window and logit soft capping are now checked in the function that checks the static values.

PiperOrigin-RevId: 741660728
2025-03-28 15:13:33 -07:00
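The bandwidth saving described in the commit above can be estimated with simple arithmetic: f32 uses 4 bytes per element and bf16 uses 2, so writing the output as bf16 halves the bytes sent back to HBM. The sketch below assumes a hypothetical output shape and a hypothetical helper `output_bytes`; it does not model the kernel itself.

```python
import math

# Bytes per element for the two dtypes discussed in the commit message.
ITEMSIZE = {"float32": 4, "bfloat16": 2}


def output_bytes(shape, dtype):
    """Bytes needed to write an output of `shape` and `dtype` to memory."""
    return math.prod(shape) * ITEMSIZE[dtype]


shape = (8, 1024, 128)  # hypothetical output shape, chosen for illustration
f32_bytes = output_bytes(shape, "float32")
bf16_bytes = output_bytes(shape, "bfloat16")
saved = f32_bytes - bf16_bytes
```

For this shape the f32 output would take 4,194,304 bytes and the bf16 output 2,097,152 bytes.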
jax authors
2d63b6e56d Merge pull request #27583 from jakevdp:scan-doc
PiperOrigin-RevId: 741653320
2025-03-28 14:45:52 -07:00
jax authors
6edc31ae1d Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
PiperOrigin-RevId: 741651222
2025-03-28 14:38:38 -07:00
Jake VanderPlas
91dac631fb scan: improve docs & errors around dynamic length 2025-03-28 14:15:25 -07:00
jax authors
b3a2c5341d [NFC] Fix linter errors in pipeline file
PiperOrigin-RevId: 741644574
2025-03-28 14:14:56 -07:00
jax authors
47876bb3dc Merge pull request #27579 from ZacCranko:nbytes
PiperOrigin-RevId: 741636333
2025-03-28 13:50:40 -07:00
jax authors
6395a22d30 Merge pull request #27575 from hawkinsp:domain
PiperOrigin-RevId: 741635283
2025-03-28 13:47:24 -07:00
Sergei Lebedev
e838fe19d3 [pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to lane-level lowering
More work is needed to support these in the WG lowering.

PiperOrigin-RevId: 741622096
2025-03-28 13:01:10 -07:00
Sergei Lebedev
fbff338a8e [pallas:mosaic_gpu] GPUMesh now accepts axis names in a more structured way
This is hopefully less confusing than bunching them together in a single argument.

PiperOrigin-RevId: 741580827
2025-03-28 11:01:19 -07:00
Zac Cranko
d4c42d7199 implement nbytes for PRNGKeyArray 2025-03-28 10:54:48 -07:00
Peter Hawkins
ecd9f5ded8 Move aval_to_xla_shape into callback.py, which is its only user.
Specialize it to one shape per aval, since that's the only case that exists.
Remove some pointless assertions using this code.

PiperOrigin-RevId: 741569024
2025-03-28 10:28:04 -07:00
Peter Hawkins
829deb68f6 Set NB_DOMAIN=jax
This is a precautionary measure to prevent conflicts with other packages
using nanobind and registering the same types. We don't want JAX's
nanobind registrations to conflict on, say, XLA types with other
projects.
2025-03-28 17:19:34 +00:00
jax authors
fde7d16c60 Clean up: num_groups = num_q_heads // num_kv_heads
No functional change in this commit.

PiperOrigin-RevId: 741566312
2025-03-28 10:19:16 -07:00
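The relation `num_groups = num_q_heads // num_kv_heads` from the commit title above can be sketched as follows. The divisibility check and the mapping from query head to key/value head (contiguous groups) are assumptions made for illustration, not taken from the commit; both function names are hypothetical.

```python
def num_groups(num_q_heads, num_kv_heads):
    """Number of query heads that share each key/value head."""
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a positive multiple of num_kv_heads")
    return num_q_heads // num_kv_heads


def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Assumed mapping: consecutive query heads share one key/value head."""
    return q_head // num_groups(num_q_heads, num_kv_heads)


groups = num_groups(32, 8)
kv_index = kv_head_for(5, 32, 8)
```

With 32 query heads and 8 key/value heads, each key/value head serves 4 query heads, and under the assumed mapping query head 5 uses key/value head 1.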
jax authors
679b6102e1 Merge pull request #27488 from jakevdp:array-capabilities
PiperOrigin-RevId: 741565179
2025-03-28 10:16:08 -07:00