25125 Commits

Author SHA1 Message Date
jax authors
743872dfed Merge pull request #25817 from jakevdp:deprecation-utils
PiperOrigin-RevId: 714132238
2025-01-10 12:10:53 -08:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
jax authors
5d0ee43222 Merge pull request #25741 from jakevdp:solve-dep
PiperOrigin-RevId: 714124347
2025-01-10 11:46:33 -08:00
jax authors
aed79707e2 Merge pull request #25791 from mattjj:logsumexp-where-grad-nan
PiperOrigin-RevId: 714118085
2025-01-10 11:27:35 -08:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
jax authors
016fca79ca Merge pull request #25787 from dfm:tri-diag-jvp
PiperOrigin-RevId: 714109627
2025-01-10 11:03:18 -08:00
jax authors
f7b6deeb01 Merge pull request #25838 from jakevdp:simplify-nightly-reporting
PiperOrigin-RevId: 714104820
2025-01-10 10:50:42 -08:00
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00
jax authors
1cc07dd392 Update XLA dependency to use revision
08dcaad271.

PiperOrigin-RevId: 714096988
2025-01-10 10:29:50 -08:00
Jake VanderPlas
7f55cccaea Simplify failure reporting for nightly CI job 2025-01-10 10:03:03 -08:00
Dan Foreman-Mackey
167a48f677 Add a JVP rule for lax.linalg.tridiagonal_solve + various fixes. 2025-01-10 12:57:37 -05:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
jax authors
4f106b8a27 Merge pull request #25831 from jax-ml:avoid-float0-tracers
PiperOrigin-RevId: 714058085
2025-01-10 08:16:43 -08:00
Benjamin Chetioui
1893881b5f [Mosaic GPU] Add initial layout mismatch resolution for splat/strided layouts.
When it is possible to annotate an operation using both a `strided` and a
`splat` layout, we pick the `strided` layout. This is the correct choice when
propagating layouts down from parameters to the root; e.g.

```
? = add(strided, splat)
```

becomes

```
strided = add(strided, strided)
```

and requires a re-layout for the right-hand side argument.

The logic needs to be reversed to handle propagation in the opposite direction.
For example, code like

```
c : ?
use(c) : strided
use(c) : splat
```

should resolve to

```
c : splat
use(c) : strided
use(c) : splat
```

and incur a relayout in the `strided` use of `c`. This direction of propagation
is left as a `TODO` for now, to limit the amount of changes in a single commit.

PiperOrigin-RevId: 714056648
2025-01-10 08:10:57 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
Dougal
ba9b2ca5f6 Avoid creating float0 JVPTracers 2025-01-10 10:43:54 -05:00
jax authors
1fe72ee880 Merge pull request #25771 from dfm:custom-partition-ffi-tutorial
PiperOrigin-RevId: 714048619
2025-01-10 07:43:10 -08:00
Adam Paszke
d2a5e8d072 [Mosaic TPU] Add support for integer truncation from packed types
PiperOrigin-RevId: 714048232
2025-01-10 07:40:55 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Vladimir Belitskiy
86643a1b3e Skip RnnTest.test_struct_encoding_determinism.
PiperOrigin-RevId: 714027519
2025-01-10 06:20:01 -08:00
Sergei Lebedev
18018d9cc9 [pallas:mosaic_gpu] Tests now pass with x64 enabled
PiperOrigin-RevId: 714005603
2025-01-10 04:48:36 -08:00
Adam Paszke
74cf67df9d [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
2025-01-10 04:35:38 -08:00
Chris Jones
a27566cc7b Reverts dbe9ccd6dccd83c365021677c7e17e843d4559c4
PiperOrigin-RevId: 713989952
2025-01-10 03:40:18 -08:00
jax authors
8c23689852 Merge pull request #25800 from gnecula:improve_error_switch
PiperOrigin-RevId: 713962512
2025-01-10 01:52:21 -08:00
jax authors
228d3cef0b Merge pull request #25797 from gnecula:print_debug_info
PiperOrigin-RevId: 713958983
2025-01-10 01:39:34 -08:00
George Necula
c2adfbf1c2 [better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140

Previously, the following code:
```
def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change the error message is more specific where the
difference is in the pytree structure:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
2025-01-10 08:03:33 +02:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Yash Katariya
6319126e2d Remove extraneous print statement in a test
PiperOrigin-RevId: 713830757
2025-01-09 16:20:21 -08:00
jax authors
061408aca3 Merge pull request #25803 from sergachev:fix_rnn_desc
PiperOrigin-RevId: 713789106
2025-01-09 14:05:30 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Gunhyun Park
93ef0f13fe Clarify documentation of composites.
There were some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
2025-01-09 10:54:54 -08:00
jax authors
6c8b02df01 Merge pull request #25753 from dfm:shp-harm-y
PiperOrigin-RevId: 713717495
2025-01-09 10:28:32 -08:00
Dan Foreman-Mackey
5f3e0d9e5e Add sph_harm_y to jax.scipy.special and deprecate sph_harm. 2025-01-09 12:53:00 -05:00
jax authors
7718ac539a Merge pull request #25626 from hawkinsp:warnings2
PiperOrigin-RevId: 713705552
2025-01-09 09:51:15 -08:00
jax authors
9cbd70b864 Merge pull request #25801 from cherrywoods:reduce_window_docstring
PiperOrigin-RevId: 713704872
2025-01-09 09:49:19 -08:00
jax authors
e571694f08 Update XLA dependency to use revision
2f6eabb5a1.

PiperOrigin-RevId: 713704549
2025-01-09 09:47:38 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Dan Foreman-Mackey
729418094e Add a discussion of sharding to the FFI tutorial. 2025-01-09 11:24:40 -05:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Adam Paszke
f23979e2fa [NFC] Refactor conversion lowering for Mosaic TPU
PiperOrigin-RevId: 713673355
2025-01-09 08:00:04 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
2025-01-09 12:57:09 +00:00
David Boetius
6e9a34f791
Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
fb832afc00 Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes https://github.com/jax-ml/jax/issues/25769
PiperOrigin-RevId: 713446871
2025-01-08 16:03:03 -08:00
Matthew Johnson
f0392a1535 fix grad(logsumexp) to produce 0s where where is False 2025-01-08 23:38:06 +00:00
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
jax authors
196eec8296 Merge pull request #25786 from vfdev-5:add-313-ft-configs
PiperOrigin-RevId: 713415451
2025-01-08 14:21:40 -08:00