Fixes: #25140
Previously, the following code:
```
def f(i, x):
return lax.switch(i, [lambda x: dict(a=x),
lambda x: dict(a=(x, x))], x)
f(0, 42)
```
resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```
With this change the error message is more specific where the
difference is in the pytree structure:
```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
* at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.
PiperOrigin-RevId: 713780181
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.
Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.
To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` - see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.
PiperOrigin-RevId: 713411171
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.
Extend custom_partitioner tests in pjit_test.py for Shardy sharding rule.
PiperOrigin-RevId: 713399604
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.
We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.
PiperOrigin-RevId: 713352542
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.
This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.
We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.
We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.
PiperOrigin-RevId: 713312937
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.
If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.
PiperOrigin-RevId: 713296722
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.
PiperOrigin-RevId: 713270421