24618 Commits

Author SHA1 Message Date
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
Henning Becker
cb6881d9e8 Reverts bdadc53ebcd40a5091d66d2586deba82fe5e01ca
PiperOrigin-RevId: 704758075
2024-12-10 10:25:27 -08:00
jax authors
e6dfe8f380 [AutoPGLE] Share FDO profile even when compilation cache disabled.
PiperOrigin-RevId: 704757991
2024-12-10 10:23:42 -08:00
jax authors
6dbafed7bc Fix mypy failure
PiperOrigin-RevId: 704748889
2024-12-10 10:01:43 -08:00
jax authors
8813973d96 [AutoPGLE] Cleanup compiler code.
PiperOrigin-RevId: 704741308
2024-12-10 09:37:35 -08:00
Peter Hawkins
acae2f0546 Remove code in jax2tf for compatibility with TF 2.10 or earlier. 2024-12-10 15:18:59 +00:00
jax authors
263d4d1462 Merge pull request #25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00
jax authors
8e7aaa792b Merge pull request #25374 from traversaro:patch-1
PiperOrigin-RevId: 704673954
2024-12-10 06:00:05 -08:00
Silvio Traversaro
09309e6452
Update conda-forge installation docs after CUDA 12 upgrade 2024-12-10 12:11:28 +01:00
jax authors
90de28cd63 Merge pull request #25335 from gnecula:export_doc_call
PiperOrigin-RevId: 704589764
2024-12-10 00:45:20 -08:00
Gunhyun Park
12c30578b2 Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` is default to all replicas (e.g. there is only one group and covers all axis indices aka iota like the example above).

The current API is inspired from https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all to all.

PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Yash Katariya
944d822ce6 Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
jax authors
1743f2c41e Merge pull request #25371 from jakevdp:mypy-numpy-version
PiperOrigin-RevId: 704504407
2024-12-09 19:06:56 -08:00
Jake VanderPlas
a36af966fd CI: temporarily pin numpy version for mypy check 2024-12-09 19:01:46 -08:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
jax authors
7b5cb56fc9 Merge pull request #25368 from hawkinsp:postrelease
PiperOrigin-RevId: 704480301
2024-12-09 17:47:52 -08:00
Peter Hawkins
820f51dc53 Merge branch 'release/0.4.37' into main. 2024-12-09 20:21:43 -05:00
Dan Foreman-Mackey
978d35f697 Fix expected exception type in pallas grad tests.
PiperOrigin-RevId: 704408603
2024-12-09 14:02:07 -08:00
jax authors
71c48cba1c Update XLA dependency to use revision
a041e1b155.

PiperOrigin-RevId: 704406817
2024-12-09 13:56:14 -08:00
Gleb Pobudzey
e1e174fbc4 Adding more tests for multi-head attention 2024-12-09 20:49:06 +00:00
Dan Foreman-Mackey
32df37e6e4 Port symmetric tridiagonal reduction GPU kernel to FFI.
PiperOrigin-RevId: 704382200
2024-12-09 12:41:23 -08:00
Peter Hawkins
ffb07cdadb Update versions for v0.4.37 release. 2024-12-09 15:39:59 -05:00
Dougal
95892fdac8 Use private names for args in api_util to avoid shadowing kwargs keys.
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
2024-12-09 15:36:53 -05:00
IvyZX
65b6088411 Avoid index out of range error in carry structure check 2024-12-09 15:36:32 -05:00
Kanglan Tang
66b900540a Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e
PiperOrigin-RevId: 704378732
2024-12-09 12:30:36 -08:00
jax authors
1c07ec6183 Merge pull request #25272 from justinjfu:pallas_tpu_docs_update
PiperOrigin-RevId: 704376603
2024-12-09 12:22:31 -08:00
jax authors
4db533be41 Merge pull request #25355 from IvyZX:loop-fix
PiperOrigin-RevId: 704373501
2024-12-09 12:11:58 -08:00
jax authors
dba3358dd4 Merge pull request #25294 from jakevdp:array-api-tests
PiperOrigin-RevId: 704363905
2024-12-09 11:46:00 -08:00
jax authors
56fcd38d46 Merge pull request #25351 from jax-ml:dependabot/github_actions/actions/cache-4.2.0
PiperOrigin-RevId: 704363738
2024-12-09 11:44:08 -08:00
Andrey Portnoy
cc22334c21 [Mosaic GPU] Add CUPTI profiler alongside events-based implementation 2024-12-09 14:31:20 -05:00
Dan Foreman-Mackey
092d2a0db5 Add error message when using custom_vmap with reverse-mode AD, and add docstrings.
The `custom_vmap` API is discussed in https://github.com/jax-ml/jax/issues/9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs.

One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem.

PiperOrigin-RevId: 704353963
2024-12-09 11:17:44 -08:00
IvyZX
bd77a703fd Avoid index out of range error in carry structure check 2024-12-09 10:44:28 -08:00
Hyeontaek Lim
296d1670bf [JAX] Add concurrent execution support in colocated Python
This change makes asynchronous execution run without holding a mutex. This
allows colocated Python executions from multiple Python threads to run
concurrently.

PiperOrigin-RevId: 704340663
2024-12-09 10:43:30 -08:00
jax authors
d908e0add9 Merge pull request #25349 from jax-ml:fix-25329
PiperOrigin-RevId: 704311844
2024-12-09 09:20:54 -08:00
dependabot[bot]
b6863dfcb5
Bump actions/cache from 4.1.2 to 4.2.0
Bumps [actions/cache](https://github.com/actions/cache) from 4.1.2 to 4.2.0.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](6849a64899...1bd1e32a3b)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-12-09 17:14:14 +00:00
Bart Chrzaszcz
6f69774c00 #sdy enable test_compute_offload_mesh_with_linear_layout for Shardy.
PiperOrigin-RevId: 704301465
2024-12-09 08:46:48 -08:00
Berkin Ilbeyi
f17b2bc2d3 Reenable for_loop_test on TPU v5p.
PiperOrigin-RevId: 704298792
2024-12-09 08:38:41 -08:00
Tomás Longeri
b76d264fe7 [Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by always using implicit shapes
PiperOrigin-RevId: 704297805
2024-12-09 08:35:14 -08:00
Dougal
dd74394e63 Use private names for args in api_util to avoid shadowing kwargs keys.
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
2024-12-09 11:24:11 -05:00
Ayaka
9c98c0cbbf [Pallas TPU] Improve lowerings for boolean comparison operations
The error when negating a boolean value (https://github.com/jax-ml/jax/issues/24243) has been fixed, so we can lower the boolean comparison operations using boolean algebra instead of using the previous workaround.

Besides, the original tests uses `allclose` on boolean arrays, which is wrong. I have changed them to `assertArraysEqual`.

PiperOrigin-RevId: 704294742
2024-12-09 08:23:51 -08:00
Sunita Nadampalli
e370deee0f add mkldnn+acl build config for aarch64 platform 2024-12-09 16:03:14 +00:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
cc258f5f61 Merge pull request #25320 from ROCm:gh-9948-fix-kernel-build-upstream
PiperOrigin-RevId: 704279150
2024-12-09 07:29:31 -08:00
Sergei Lebedev
1ac6b762dd Ensured that JAX type checks under pytype on Python 3.12
Some errors uncovered by pytype look genuine and need to be revisited in
the in the future.

PiperOrigin-RevId: 704268742
2024-12-09 06:53:08 -08:00
jax authors
5a1c4c5783 Merge pull request #25338 from carlosgmartin:fix_numpy_linalg_matrix_norm_ord_type_annotation
PiperOrigin-RevId: 704245037
2024-12-09 05:17:26 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
Adam Paszke
adb2bf629c [Mosaic TPU] Allow downgrading the IR during serialization for forward compat
This is to uphold the monthly stability promise made by jax.export.

PiperOrigin-RevId: 704233290
2024-12-09 04:32:41 -08:00
Chris Jones
a94474d016 [pallas] Add DotAlgorithmPreset note to CHANGELOG.
PiperOrigin-RevId: 704216341
2024-12-09 03:26:20 -08:00
Chris Jones
3ec55c7723 [pallas:triton] Add support for DotAlgorithmPreset precision arguments to dot.
PiperOrigin-RevId: 704208558
2024-12-09 02:52:47 -08:00
carlosgmartin
efa35ea9f9 Fix type annotation for numpy.linalg.matrix_norm argument 'ord'. 2024-12-08 20:11:06 -05:00