25204 Commits

Author SHA1 Message Date
Sergei Lebedev
18018d9cc9 [pallas:mosaic_gpu] Tests now pass with x64 enabled
PiperOrigin-RevId: 714005603
2025-01-10 04:48:36 -08:00
Adam Paszke
74cf67df9d [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
2025-01-10 04:35:38 -08:00
Chris Jones
a27566cc7b Reverts dbe9ccd6dccd83c365021677c7e17e843d4559c4
PiperOrigin-RevId: 713989952
2025-01-10 03:40:18 -08:00
jax authors
8c23689852 Merge pull request #25800 from gnecula:improve_error_switch
PiperOrigin-RevId: 713962512
2025-01-10 01:52:21 -08:00
jax authors
228d3cef0b Merge pull request #25797 from gnecula:print_debug_info
PiperOrigin-RevId: 713958983
2025-01-10 01:39:34 -08:00
George Necula
c2adfbf1c2 [better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140

Previously, the following code:
```
from jax import lax

def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change, the error message says where in the pytree
structure the difference is:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
2025-01-10 08:03:33 +02:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Yash Katariya
6319126e2d Remove extraneous print statement in a test
PiperOrigin-RevId: 713830757
2025-01-09 16:20:21 -08:00
jax authors
061408aca3 Merge pull request #25803 from sergachev:fix_rnn_desc
PiperOrigin-RevId: 713789106
2025-01-09 14:05:30 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

Providing a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver`. See
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Gunhyun Park
93ef0f13fe Clarify documentation of composites.
There was some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
2025-01-09 10:54:54 -08:00
jax authors
6c8b02df01 Merge pull request #25753 from dfm:shp-harm-y
PiperOrigin-RevId: 713717495
2025-01-09 10:28:32 -08:00
Dan Foreman-Mackey
5f3e0d9e5e Add sph_harm_y to jax.scipy.special and deprecate sph_harm. 2025-01-09 12:53:00 -05:00
jax authors
7718ac539a Merge pull request #25626 from hawkinsp:warnings2
PiperOrigin-RevId: 713705552
2025-01-09 09:51:15 -08:00
jax authors
9cbd70b864 Merge pull request #25801 from cherrywoods:reduce_window_docstring
PiperOrigin-RevId: 713704872
2025-01-09 09:49:19 -08:00
jax authors
e571694f08 Update XLA dependency to use revision
2f6eabb5a1.

PiperOrigin-RevId: 713704549
2025-01-09 09:47:38 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Dan Foreman-Mackey
729418094e Add a discussion of sharding to the FFI tutorial. 2025-01-09 11:24:40 -05:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Adam Paszke
f23979e2fa [NFC] Refactor conversion lowering for Mosaic TPU
PiperOrigin-RevId: 713673355
2025-01-09 08:00:04 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random
bytes into the string representation of the struct and caused the HLO to
vary from run to run.
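
The effect can be illustrated in Python with `ctypes`. This is only an analogy, not the code changed by the commit: the `Descriptor` layout and field names are hypothetical, and the padding size assumes a typical platform where a 32-bit integer is 4-byte aligned.

```python
import ctypes
import struct

class Descriptor(ctypes.Structure):
  # Hypothetical layout: a 1-byte bool followed by a 4-byte int, so the
  # C-style layout inserts padding bytes after the bool.
  _fields_ = [("bidirectional", ctypes.c_bool),
              ("hidden_size", ctypes.c_int32)]

def raw_encoding(fill):
  # Simulate uninitialized memory by pre-filling the buffer with `fill`.
  buf = bytearray([fill]) * ctypes.sizeof(Descriptor)
  desc = Descriptor.from_buffer(buf)
  desc.bidirectional = True
  desc.hidden_size = 128
  # The padding bytes still hold whatever was in memory before.
  return bytes(buf)

def field_encoding(bidirectional, hidden_size):
  # Serialize field by field with no padding ("<" disables alignment).
  return struct.pack("<?i", bidirectional, hidden_size)
```

Two structs with identical field values can differ in their raw bytes, whereas the field-by-field encoding depends only on the values.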
2025-01-09 12:57:09 +00:00
David Boetius
6e9a34f791 Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
fb832afc00 Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes https://github.com/jax-ml/jax/issues/25769
PiperOrigin-RevId: 713446871
2025-01-08 16:03:03 -08:00
Matthew Johnson
f0392a1535 fix grad(logsumexp) to produce 0s where `where` is False 2025-01-08 23:38:06 +00:00
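
A short sketch of the behavior this commit describes, assuming a JAX version that includes the fix. The input values and mask are arbitrary example data.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])
mask = jnp.array([True, False, True])

# Gradient of the masked logsumexp with respect to every input element.
# The masked-out element (index 1) should receive a zero gradient.
g = jax.grad(lambda v: logsumexp(v, where=mask))(x)
```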
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
jax authors
196eec8296 Merge pull request #25786 from vfdev-5:add-313-ft-configs
PiperOrigin-RevId: 713415451
2025-01-08 14:21:40 -08:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
vfdev-5
00806ddaf5 Added 3.13 ft requirements lock file and updated WORKSPACE 2025-01-08 22:47:29 +01:00
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in pjit_test.py to cover the Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Yash Katariya
b3833dc705 Match the behavior on single host wrt multi-host if tiled=False. Fixes https://github.com/jax-ml/jax/issues/25783
PiperOrigin-RevId: 713398173
2025-01-08 13:31:35 -08:00
jax authors
6e1f060ad3 Merge pull request #25527 from vfdev-5:single-python-version-build-py
PiperOrigin-RevId: 713365267
2025-01-08 11:49:59 -08:00
Justin Fu
d99a637d8b [Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
2025-01-08 11:21:19 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contains mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Sharad Vikram
c1a60c676a [Pallas] Add empty/empty_like helper functions
PiperOrigin-RevId: 713344151
2025-01-08 10:49:11 -08:00
jax authors
55119490eb Update XLA dependency to use revision
1b99693486.

PiperOrigin-RevId: 713320368
2025-01-08 09:37:59 -08:00
Bart Chrzaszcz
5c097c8f62 #sdy Move Shardy mesh lift inlining pass after verification.
Previously, if something went wrong during JAX lowering, the pass would fail before verification could catch the problem, making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example:
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
2025-01-08 09:17:54 -08:00
Peter Hawkins
0389d617c8 Add a unittest test extension that runs test cases in parallel using threads.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.

This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.

We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.

We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.
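
To show why a reader-writer lock fits here, the sketch below uses a pure-Python lock in place of the absl::Mutex bindings. The decorator names `thread_hostile` and `thread_friendly` are hypothetical stand-ins, not the `jtu` API.

```python
import contextlib
import functools
import threading

class ReaderWriterLock:
  """Many concurrent readers, or exactly one writer."""

  def __init__(self):
    self._cond = threading.Condition()
    self._readers = 0
    self._writer = False

  @contextlib.contextmanager
  def read(self):
    with self._cond:
      while self._writer:
        self._cond.wait()
      self._readers += 1
    try:
      yield
    finally:
      with self._cond:
        self._readers -= 1
        self._cond.notify_all()

  @contextlib.contextmanager
  def write(self):
    with self._cond:
      while self._writer or self._readers:
        self._cond.wait()
      self._writer = True
    try:
      yield
    finally:
      with self._cond:
        self._writer = False
        self._cond.notify_all()

_test_lock = ReaderWriterLock()

def thread_hostile(fn):
  # Hypothetical decorator: the test runs alone, holding the write side.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    with _test_lock.write():
      return fn(*args, **kwargs)
  return wrapper

def thread_friendly(fn):
  # Ordinary tests share the read side and may overlap with each other.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    with _test_lock.read():
      return fn(*args, **kwargs)
  return wrapper
```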

PiperOrigin-RevId: 713312937
2025-01-08 09:11:47 -08:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or setUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
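
As a generic illustration of this kind of port (not taken from the JAX test suite; the test class and its data are made up), class-level setup moves into `setUp` like this:

```python
import unittest

class LookupTest(unittest.TestCase):
  # Before the port, shared state was built once per class:
  #   @classmethod
  #   def setUpClass(cls):
  #     cls.table = {"a": 1, "b": 2}

  def setUp(self):
    super().setUp()
    # After the port, every test builds what it needs, so tests can run
    # in any order and on any thread.
    self.table = {"a": 1, "b": 2}

  def test_lookup(self):
    self.assertEqual(self.table["a"], 1)

  def test_mutation_is_isolated(self):
    self.table["a"] = 99
    self.assertEqual(self.table["a"], 99)
```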

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Sergei Lebedev
f1f98afee8 [pallas:mosaic_gpu] Fix the tests following the changes to pl.core_map
PiperOrigin-RevId: 713283207
2025-01-08 07:24:08 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Adam Paszke
f96339be1e [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
2025-01-08 06:30:36 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Adam Paszke
e954930eaf [Mosaic TPU] Add support for true divide in bf16 on TPUv6
PiperOrigin-RevId: 713247480
2025-01-08 04:49:22 -08:00
Tzu-Wei Sung
bf94389b08 [Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
It was cmp + iota before.

PiperOrigin-RevId: 713240888
2025-01-08 04:18:53 -08:00
jax authors
4718121efe Merge pull request #25754 from andportnoy:patch-4
PiperOrigin-RevId: 713222111
2025-01-08 02:57:20 -08:00
Sergei Lebedev
90201ce2b7 Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
2025-01-08 01:39:38 -08:00
jax authors
81db3219b7 Merge pull request #25594 from zhenying-liu:activation-offloading-doc
PiperOrigin-RevId: 713170813
2025-01-07 23:26:21 -08:00