25019 Commits

Author SHA1 Message Date
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
jax authors
016fca79ca Merge pull request #25787 from dfm:tri-diag-jvp
PiperOrigin-RevId: 714109627
2025-01-10 11:03:18 -08:00
jax authors
f7b6deeb01 Merge pull request #25838 from jakevdp:simplify-nightly-reporting
PiperOrigin-RevId: 714104820
2025-01-10 10:50:42 -08:00
jax authors
1cc07dd392 Update XLA dependency to use revision
08dcaad271.

PiperOrigin-RevId: 714096988
2025-01-10 10:29:50 -08:00
Jake VanderPlas
7f55cccaea Simplify failure reporting for nightly CI job 2025-01-10 10:03:03 -08:00
Dan Foreman-Mackey
167a48f677 Add a JVP rule for lax.linalg.tridiagonal_solve + various fixes. 2025-01-10 12:57:37 -05:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
jax authors
4f106b8a27 Merge pull request #25831 from jax-ml:avoid-float0-tracers
PiperOrigin-RevId: 714058085
2025-01-10 08:16:43 -08:00
Benjamin Chetioui
1893881b5f [Mosaic GPU] Add initial layout mismatch resolution for splat/strided layouts.
When it is possible to annotate an operation using both a `strided` and a
`splat` layout, we pick the `strided` layout. This is the correct choice when
propagating layouts down from parameters to the root; e.g.

```
? = add(strided, splat)
```

becomes

```
strided = add(strided, strided)
```

and requires a re-layout for the right-hand side argument.

The logic needs to be reversed to handle propagation in the opposite direction.
For example, code like

```
c : ?
use(c) : strided
use(c) : splat
```

should resolve to

```
c : splat
use(c) : strided
use(c) : splat
```

and incur a relayout in the `strided` use of `c`. This direction of propagation
is left as a `TODO` for now, to limit the amount of changes in a single commit.

PiperOrigin-RevId: 714056648
2025-01-10 08:10:57 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
Dougal
ba9b2ca5f6 Avoid creating float0 JVPTracers 2025-01-10 10:43:54 -05:00
jax authors
1fe72ee880 Merge pull request #25771 from dfm:custom-partition-ffi-tutorial
PiperOrigin-RevId: 714048619
2025-01-10 07:43:10 -08:00
Adam Paszke
d2a5e8d072 [Mosaic TPU] Add support for integer truncation from packed types
PiperOrigin-RevId: 714048232
2025-01-10 07:40:55 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Vladimir Belitskiy
86643a1b3e Skip RnnTest.test_struct_encoding_determinism.
PiperOrigin-RevId: 714027519
2025-01-10 06:20:01 -08:00
Sergei Lebedev
18018d9cc9 [pallas:mosaic_gpu] Tests now pass with x64 enabled
PiperOrigin-RevId: 714005603
2025-01-10 04:48:36 -08:00
Adam Paszke
74cf67df9d [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
2025-01-10 04:35:38 -08:00
Chris Jones
a27566cc7b Reverts dbe9ccd6dccd83c365021677c7e17e843d4559c4
PiperOrigin-RevId: 713989952
2025-01-10 03:40:18 -08:00
jax authors
8c23689852 Merge pull request #25800 from gnecula:improve_error_switch
PiperOrigin-RevId: 713962512
2025-01-10 01:52:21 -08:00
jax authors
228d3cef0b Merge pull request #25797 from gnecula:print_debug_info
PiperOrigin-RevId: 713958983
2025-01-10 01:39:34 -08:00
George Necula
c2adfbf1c2 [better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140

Previously, the following code:
```
def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change the error message is more specific where the
difference is in the pytree structure:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
2025-01-10 08:03:33 +02:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Yash Katariya
6319126e2d Remove extraneous print statement in a test
PiperOrigin-RevId: 713830757
2025-01-09 16:20:21 -08:00
jax authors
061408aca3 Merge pull request #25803 from sergachev:fix_rnn_desc
PiperOrigin-RevId: 713789106
2025-01-09 14:05:30 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Gunhyun Park
93ef0f13fe Clarify documentation of composites.
There were some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
2025-01-09 10:54:54 -08:00
jax authors
6c8b02df01 Merge pull request #25753 from dfm:shp-harm-y
PiperOrigin-RevId: 713717495
2025-01-09 10:28:32 -08:00
Dan Foreman-Mackey
5f3e0d9e5e Add sph_harm_y to jax.scipy.special and deprecate sph_harm. 2025-01-09 12:53:00 -05:00
jax authors
7718ac539a Merge pull request #25626 from hawkinsp:warnings2
PiperOrigin-RevId: 713705552
2025-01-09 09:51:15 -08:00
jax authors
9cbd70b864 Merge pull request #25801 from cherrywoods:reduce_window_docstring
PiperOrigin-RevId: 713704872
2025-01-09 09:49:19 -08:00
jax authors
e571694f08 Update XLA dependency to use revision
2f6eabb5a1.

PiperOrigin-RevId: 713704549
2025-01-09 09:47:38 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Dan Foreman-Mackey
729418094e Add a discussion of sharding to the FFI tutorial. 2025-01-09 11:24:40 -05:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Adam Paszke
f23979e2fa [NFC] Refactor conversion lowering for Mosaic TPU
PiperOrigin-RevId: 713673355
2025-01-09 08:00:04 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
2025-01-09 12:57:09 +00:00
David Boetius
6e9a34f791
Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
fb832afc00 Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes https://github.com/jax-ml/jax/issues/25769
PiperOrigin-RevId: 713446871
2025-01-08 16:03:03 -08:00
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
jax authors
196eec8296 Merge pull request #25786 from vfdev-5:add-313-ft-configs
PiperOrigin-RevId: 713415451
2025-01-08 14:21:40 -08:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
vfdev-5
00806ddaf5 Added 3.13 ft requirements lock file and updated WORKSPACE 2025-01-08 22:47:29 +01:00
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in  pjit_test.py for Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Yash Katariya
b3833dc705 Match the behavior on single host wrt multi-host if tiled=False. Fixes https://github.com/jax-ml/jax/issues/25783
PiperOrigin-RevId: 713398173
2025-01-08 13:31:35 -08:00
jax authors
6e1f060ad3 Merge pull request #25527 from vfdev-5:single-python-version-build-py
PiperOrigin-RevId: 713365267
2025-01-08 11:49:59 -08:00
Justin Fu
d99a637d8b [Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
2025-01-08 11:21:19 -08:00