Jake VanderPlas
6541a62099
jax.core: deprecate a number of APIs
2024-12-10 11:11:32 -08:00
Henning Becker
cb6881d9e8
Reverts bdadc53ebcd40a5091d66d2586deba82fe5e01ca
...
PiperOrigin-RevId: 704758075
2024-12-10 10:25:27 -08:00
jax authors
e6dfe8f380
[AutoPGLE] Share FDO profile even when compilation cache disabled.
...
PiperOrigin-RevId: 704757991
2024-12-10 10:23:42 -08:00
jax authors
6dbafed7bc
Fix mypy failure
...
PiperOrigin-RevId: 704748889
2024-12-10 10:01:43 -08:00
jax authors
8813973d96
[AutoPGLE] Cleanup compiler code.
...
PiperOrigin-RevId: 704741308
2024-12-10 09:37:35 -08:00
Peter Hawkins
acae2f0546
Remove code in jax2tf for compatibility with TF 2.10 or earlier.
2024-12-10 15:18:59 +00:00
jax authors
263d4d1462
Merge pull request #25369 from jax-ml:mutable-arrays-ad
...
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00
jax authors
8e7aaa792b
Merge pull request #25374 from traversaro:patch-1
...
PiperOrigin-RevId: 704673954
2024-12-10 06:00:05 -08:00
Silvio Traversaro
09309e6452
Update conda-forge installation docs after CUDA 12 upgrade
2024-12-10 12:11:28 +01:00
jax authors
90de28cd63
Merge pull request #25335 from gnecula:export_doc_call
...
PiperOrigin-RevId: 704589764
2024-12-10 00:45:20 -08:00
Gunhyun Park
12c30578b2
Introduce lax.ragged_all_to_all
primitive
...
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
%0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
return %0 : tensor<6xf32>
}
}
```
For now, the API assumes `split_axis` and `concat_axis` of `all_to_all` to be the outermost (ragged) dim, and `axis_index_groups` is default to all replicas (e.g. there is only one group and covers all axis indices aka iota like the example above).
The current API is inspired from https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html which essentially also does a ragged all to all.
PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Yash Katariya
944d822ce6
Add a no-op batching rule for optimization_barrier_p
...
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
jax authors
1743f2c41e
Merge pull request #25371 from jakevdp:mypy-numpy-version
...
PiperOrigin-RevId: 704504407
2024-12-09 19:06:56 -08:00
Jake VanderPlas
a36af966fd
CI: temporarily pin numpy version for mypy check
2024-12-09 19:01:46 -08:00
Dougal
fc2edbfac8
Add a freeze
primitive to delimit ref lifetimes for AD.
...
Also some basic AD through mutable_array/freeze.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
jax authors
7b5cb56fc9
Merge pull request #25368 from hawkinsp:postrelease
...
PiperOrigin-RevId: 704480301
2024-12-09 17:47:52 -08:00
Peter Hawkins
820f51dc53
Merge branch 'release/0.4.37' into main.
2024-12-09 20:21:43 -05:00
Dan Foreman-Mackey
978d35f697
Fix expected exception type in pallas grad tests.
...
PiperOrigin-RevId: 704408603
2024-12-09 14:02:07 -08:00
jax authors
71c48cba1c
Update XLA dependency to use revision
...
a041e1b155
.
PiperOrigin-RevId: 704406817
2024-12-09 13:56:14 -08:00
Gleb Pobudzey
e1e174fbc4
Adding more tests for multi-head attention
2024-12-09 20:49:06 +00:00
Dan Foreman-Mackey
32df37e6e4
Port symmetric tridiagonal reduction GPU kernel to FFI.
...
PiperOrigin-RevId: 704382200
2024-12-09 12:41:23 -08:00
Peter Hawkins
ffb07cdadb
Update versions for v0.4.37 release.
2024-12-09 15:39:59 -05:00
Dougal
95892fdac8
Use private names for args in api_util to avoid shadowing kwargs keys.
...
This is a quick fix for #25329 . We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
2024-12-09 15:36:53 -05:00
IvyZX
65b6088411
Avoid index out of range error in carry structure check
2024-12-09 15:36:32 -05:00
Kanglan Tang
66b900540a
Disable pjit ArrayPjitTest.test_device_put_grad test on TPU v5e
...
PiperOrigin-RevId: 704378732
2024-12-09 12:30:36 -08:00
jax authors
1c07ec6183
Merge pull request #25272 from justinjfu:pallas_tpu_docs_update
...
PiperOrigin-RevId: 704376603
2024-12-09 12:22:31 -08:00
jax authors
4db533be41
Merge pull request #25355 from IvyZX:loop-fix
...
PiperOrigin-RevId: 704373501
2024-12-09 12:11:58 -08:00
jax authors
dba3358dd4
Merge pull request #25294 from jakevdp:array-api-tests
...
PiperOrigin-RevId: 704363905
2024-12-09 11:46:00 -08:00
jax authors
56fcd38d46
Merge pull request #25351 from jax-ml:dependabot/github_actions/actions/cache-4.2.0
...
PiperOrigin-RevId: 704363738
2024-12-09 11:44:08 -08:00
Andrey Portnoy
cc22334c21
[Mosaic GPU] Add CUPTI profiler alongside events-based implementation
2024-12-09 14:31:20 -05:00
Dan Foreman-Mackey
092d2a0db5
Add error message when using custom_vmap with reverse-mode AD, and add docstrings.
...
The `custom_vmap` API is discussed in https://github.com/jax-ml/jax/issues/9073 , and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs.
One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem.
PiperOrigin-RevId: 704353963
2024-12-09 11:17:44 -08:00
IvyZX
bd77a703fd
Avoid index out of range error in carry structure check
2024-12-09 10:44:28 -08:00
Hyeontaek Lim
296d1670bf
[JAX] Add concurrent execution support in colocated Python
...
This change makes asynchronous execution run without holding a mutex. This
allows colocated Python executions from multiple Python threads to run
concurrently.
PiperOrigin-RevId: 704340663
2024-12-09 10:43:30 -08:00
jax authors
d908e0add9
Merge pull request #25349 from jax-ml:fix-25329
...
PiperOrigin-RevId: 704311844
2024-12-09 09:20:54 -08:00
dependabot[bot]
b6863dfcb5
Bump actions/cache from 4.1.2 to 4.2.0
...
Bumps [actions/cache](https://github.com/actions/cache ) from 4.1.2 to 4.2.0.
- [Release notes](https://github.com/actions/cache/releases )
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md )
- [Commits](6849a64899...1bd1e32a3b
)
---
updated-dependencies:
- dependency-name: actions/cache
dependency-type: direct:production
update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-12-09 17:14:14 +00:00
Bart Chrzaszcz
6f69774c00
#sdy enable test_compute_offload_mesh_with_linear_layout
for Shardy.
...
PiperOrigin-RevId: 704301465
2024-12-09 08:46:48 -08:00
Berkin Ilbeyi
f17b2bc2d3
Reenable for_loop_test on TPU v5p.
...
PiperOrigin-RevId: 704298792
2024-12-09 08:38:41 -08:00
Tomás Longeri
b76d264fe7
[Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by always using implicit shapes
...
PiperOrigin-RevId: 704297805
2024-12-09 08:35:14 -08:00
Dougal
dd74394e63
Use private names for args in api_util to avoid shadowing kwargs keys.
...
This is a quick fix for #25329 . We probably shouldn't use kwargs in linear_util.
We probably shouldn't use linear_util...
2024-12-09 11:24:11 -05:00
Ayaka
9c98c0cbbf
[Pallas TPU] Improve lowerings for boolean comparison operations
...
The error when negating a boolean value (https://github.com/jax-ml/jax/issues/24243 ) has been fixed, so we can lower the boolean comparison operations using boolean algebra instead of using the previous workaround.
Besides, the original tests uses `allclose` on boolean arrays, which is wrong. I have changed them to `assertArraysEqual`.
PiperOrigin-RevId: 704294742
2024-12-09 08:23:51 -08:00
Sunita Nadampalli
e370deee0f
add mkldnn+acl build config for aarch64 platform
2024-12-09 16:03:14 +00:00
Peter Hawkins
79318a08cf
Remove dead code after minimum jaxlib version bump to v0.4.36.
...
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.
PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
cc258f5f61
Merge pull request #25320 from ROCm:gh-9948-fix-kernel-build-upstream
...
PiperOrigin-RevId: 704279150
2024-12-09 07:29:31 -08:00
Sergei Lebedev
1ac6b762dd
Ensured that JAX type checks under pytype on Python 3.12
...
Some errors uncovered by pytype look genuine and need to be revisited in
the in the future.
PiperOrigin-RevId: 704268742
2024-12-09 06:53:08 -08:00
jax authors
5a1c4c5783
Merge pull request #25338 from carlosgmartin:fix_numpy_linalg_matrix_norm_ord_type_annotation
...
PiperOrigin-RevId: 704245037
2024-12-09 05:17:26 -08:00
Paweł Paruzel
d474feda9e
Activate Tridiagonal Reduction to XLA's FFI
...
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.
PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
Adam Paszke
adb2bf629c
[Mosaic TPU] Allow downgrading the IR during serialization for forward compat
...
This is to uphold the monthly stability promise made by jax.export.
PiperOrigin-RevId: 704233290
2024-12-09 04:32:41 -08:00
Chris Jones
a94474d016
[pallas] Add DotAlgorithmPreset
note to CHANGELOG.
...
PiperOrigin-RevId: 704216341
2024-12-09 03:26:20 -08:00
Chris Jones
3ec55c7723
[pallas:triton] Add support for DotAlgorithmPreset
precision
arguments to dot
.
...
PiperOrigin-RevId: 704208558
2024-12-09 02:52:47 -08:00
carlosgmartin
efa35ea9f9
Fix type annotation for numpy.linalg.matrix_norm argument 'ord'.
2024-12-08 20:11:06 -05:00