1625 Commits

Author SHA1 Message Date
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace the `with set_mesh(mesh):` context manager with `with use_mesh(mesh):`.

Also expose `AxisTypes` and `use_mesh` in the public API as `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
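
A minimal sketch of the renamed context manager (assumes at least one device; the axis name `x` is arbitrary):

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import use_mesh  # newly public, per this change

# Build a trivial 1-axis mesh and activate it as the context mesh.
mesh = jax.sharding.Mesh(np.array(jax.devices()[:1]), ('x',))
with use_mesh(mesh):  # previously: with set_mesh(mesh):
    y = jnp.arange(4) * 2
```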

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation lowers to gather_p, error out and direct the user to `.at[...].get(out_spec=...)` instead.
This drops the gather operation into full auto mode and adds a sharding constraint on the output, as given by the user via `out_spec`.
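
A hypothetical sketch of the suggested rewrite (assumes sharding-in-types with a mesh axis `x` in context; `P('x')` is an arbitrary spec):

```
from jax.sharding import PartitionSpec as P

def take_rows(x, idx):
    # Plain `x[idx]` would hit gather_p and raise under sharding-in-types;
    # the explicit form runs the gather in full auto mode and constrains
    # the output sharding to the given spec.
    return x.at[idx].get(out_spec=P('x'))
```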

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Splitting one dimension only, where the dimension being split is unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into one dimension only, where all the dimensions being merged are unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Splitting into singleton dimensions, i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merging singleton dimensions, i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide `out_sharding`.
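
For instance, the first (split) case might look like the following sketch, assuming at least two devices and a mesh axis `x` sharding the leading dimension:

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
x = jax.device_put(jnp.zeros((4, 6, 8)), NamedSharding(mesh, P('x')))
# Unambiguous: only the unsharded trailing dim is split,
# (4@x, 6, 8) -> (4@x, 6, 2, 2, 2).
y = jnp.reshape(x, (4, 6, 2, 2, 2))
```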

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by properly registering the sharding on the aval created for SDS in its pytype_aval_mapping.
Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.
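
A sketch of such an input (assumes at least two devices and a mesh axis `x`):

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
sds = jax.ShapeDtypeStruct((8, 4), jnp.float32,
                           sharding=NamedSharding(mesh, P('x')))
# The aval built from sds now carries the sharding.
out = jax.eval_shape(lambda a: a.sum(axis=0), sds)
```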

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
George Necula
3faff78ca8 [better_errors] Ensure that tracer errors in for_loop point to user code
Fixes: #23637
2025-01-13 15:33:30 +00:00
Dan Foreman-Mackey
167a48f677 Add a JVP rule for lax.linalg.tridiagonal_solve + various fixes. 2025-01-10 12:57:37 -05:00
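A sketch of differentiating through the solve (diagonal conventions per `jax.lax.linalg.tridiagonal_solve`: `dl[0]` and `du[-1]` are zero):

```
import jax
import jax.numpy as jnp
from jax import lax

dl = jnp.array([0.0, 1.0, 1.0])   # sub-diagonal; first entry zero by convention
d  = jnp.array([4.0, 4.0, 4.0])   # main diagonal
du = jnp.array([1.0, 1.0, 0.0])   # super-diagonal; last entry zero by convention
b  = jnp.ones((3, 1))             # right-hand side

solve = lambda d: lax.linalg.tridiagonal_solve(dl, d, du, b)
x, x_dot = jax.jvp(solve, (d,), (jnp.ones_like(d),))  # exercises the new JVP rule
```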
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
jax authors
8c23689852 Merge pull request #25800 from gnecula:improve_error_switch
PiperOrigin-RevId: 713962512
2025-01-10 01:52:21 -08:00
George Necula
c2adfbf1c2 [better_errors] Improve the error message for lax.switch branch output structure mismatches
Fixes: #25140

Previously, the following code:
```
from jax import lax

def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change, the error message is more specific about where the
pytree structures differ:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
2025-01-10 08:03:33 +02:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.
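
A sketch of the new option (CPU-only per this change):

```
import jax.numpy as jnp
from jax.scipy import linalg

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
# Column-pivoted QR via LAPACK's geqp3; p indexes columns so that
# a[:, p] is (approximately) q @ r.
q, r, p = linalg.qr(a, pivoting=True)
```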

Providing a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` - see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future. 2025-01-09 20:44:45 +00:00
Gunhyun Park
93ef0f13fe Clarify documentation of composites.
There was some confusion regarding how to properly add attributes to the op in https://github.com/jax-ml/jax/issues/25767.

PiperOrigin-RevId: 713726697
2025-01-09 10:54:54 -08:00
David Boetius
6e9a34f791 Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
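A short usage sketch of the now-documented public function:

```
import jax.numpy as jnp
from jax import lax

# 1D sum-pooling: window of 2, stride 2, no padding.
x = jnp.arange(8.0)
y = lax.reduce_window(x, 0.0, lax.add,
                      window_dimensions=(2,), window_strides=(2,),
                      padding='VALID')  # -> [1., 5., 9., 13.]
```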
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like `einsum`, `reshape`, `broadcast_in_dim`, `broadcasted_iota`, `convert_element_type`, and `sharding_cast` that take `out_sharding` as an argument in their signature now also accept a `PartitionSpec`, not just a `NamedSharding`.
If a `PartitionSpec` is passed, the mesh is read from the context. The primitives themselves still take only `NamedSharding`; the conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

For the above functions, we also raise an error if the `PartitionSpec` contains mesh axis names of type Auto or Collective.
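
A sketch (assumes a mesh with axis `x` already set via the surrounding context, e.g. `use_mesh`):

```
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def matmul(a, b):
    # A bare PartitionSpec now works; it is converted to a NamedSharding
    # using the context mesh before the primitive's .bind.
    return jnp.einsum('ij,jk->ik', a, b, out_sharding=P('x', None))
```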

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Yunlong Liu
3ff000ee3e fix the degenerate case 2025-01-06 16:08:07 +00:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
jax authors
1719986aaa [Jax][Pallas][Mosaic] Implement platform-dependent diag, with branch selection driven by constant prop in Mosaic lowering.
This CL builds out a simple sketch of constant prop by construction in Mosaic: we walk the graph up from the cond, collecting values and either const-propping or failing out of const prop. Failing out of const prop is not a bug, but hitting an unimplemented const-prop function is, for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable Mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a Mosaic-specific one that avoids gather.

PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
jax authors
44d67e1379 Merge pull request #25648 from hawkinsp:warnings3
PiperOrigin-RevId: 708415848
2024-12-20 13:41:43 -08:00
Jake VanderPlas
beee98ab4a Add int4/uint4 support to bitcast_convert_type 2024-12-20 12:45:24 -08:00
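A sketch (assumes the int4/uint4 dtypes from ml_dtypes are available in your build); bitcasting to a narrower type adds a trailing dimension:

```
import jax.numpy as jnp
from jax import lax

x = jnp.array([0x12, 0x34], dtype=jnp.int8)
# Each 8-bit element is reinterpreted as two 4-bit elements:
# shape (2,) -> (2, 2), dtype uint4.
y = lax.bitcast_convert_type(x, jnp.uint4)
```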
Peter Hawkins
59e5ce22d3 Avoid calls to warnings.catch_warnings in JAX core code.
warnings.catch_warnings is not thread-safe. However, it is always used to avoid complex-to-real conversion warnings, which we can avoid in other ways.
2024-12-20 15:43:03 -05:00
Oleg Shyshkov
db464b3f0a Clarify documentation for output_offsets operand of ragged_all_to_all.
PiperOrigin-RevId: 708321802
2024-12-20 07:52:11 -08:00
Matthew Johnson
b6482f126e add mutable array ref error checks to cond and custom_vjp 2024-12-20 01:44:50 +00:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Gunhyun Park
d206cc3b50 Add lax.composite primitive
A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op are implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.
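
A sketch of the Python side (argument names per this change; `my.tangent` is an arbitrary composite name):

```
import jax
import jax.numpy as jnp
from jax import lax

# Wrap the decomposition in a composite op; the name identifies it in StableHLO.
my_tangent = lax.composite(lambda x: jnp.sin(x) / jnp.cos(x), name='my.tangent')
y = jax.jit(my_tangent)(jnp.arange(4.0))
```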

This is what the HLO looks like for a tangent composite:
```
module @jit_my_tangent_composite {
  func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
    %0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
    return %0 : tensor<4xf64>
  }
  func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.sine %arg0 : tensor<4xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4xf64>
    %2 = stablehlo.divide %0, %1 : tensor<4xf64>
    return %2 : tensor<4xf64>
  }
}
```

Similarly, this can scale to something like Attention. By preserving such an abstraction, it greatly simplifies pattern matching. Instead of matching the set of ops that represent Attention, the matcher can simply look for a uniquely identifying composite op like "MyAttention".

This is useful for preserving a high-level abstraction that would otherwise be lost during lowering. A hardware-aware compiler can recognize the single composite op and emit efficient code for it, rather than pattern-matching a generic lowering and replacing it with an efficient one; the `decomposition` function can then be DCE'd away. If the hardware does not have an efficient lowering, the compiler can instead inline the `decomposition`, which implements the semantics of the abstraction.

For more details on the API, refer to the documentation.

PiperOrigin-RevId: 707750633
2024-12-18 19:38:37 -08:00
jax authors
f65ecedde7 Merge pull request #25593 from mattjj:ref-errors-4
PiperOrigin-RevId: 707733777
2024-12-18 18:23:33 -08:00
Matthew Johnson
e52856261f add mutable array ref error checks to scan 2024-12-19 01:33:39 +00:00
Yash Katariya
af63e443ef [sharding_in_types] Check out_avals with the mesh context too. This is because users can pass their own shardings to functions like `einsum`, `reshape`, `broadcast`, etc.
PiperOrigin-RevId: 707672801
2024-12-18 14:42:40 -08:00
Christos Perivolaropoulos
aaabb9752f Partial discharge for scan_p ops.
PiperOrigin-RevId: 707558502
2024-12-18 08:23:06 -08:00
Pearu Peterson
f592173c6c Use StableHLO acos and update complex acos accuracy tests. 2024-12-18 15:19:38 +02:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
Oleg Shyshkov
6d82a6fc90 Allow lax.ragged_all_to_all input and output operands to have different ragged dimension sizes.
We need to guarantee that the outermost dimension of the output is big enough to fit all received elements, but it is not necessary for the input and output outermost dimensions to be exactly equal.

PiperOrigin-RevId: 707011916
2024-12-17 02:20:10 -08:00
Jake VanderPlas
74e9275bf2 Fix incorrect capitalization in scan error message 2024-12-16 11:37:31 -08:00
jax authors
5a3fa500b5 Merge pull request #25459 from hawkinsp:sort
PiperOrigin-RevId: 705869484
2024-12-13 06:55:32 -08:00
Peter Hawkins
0922feb2f5 Use a broadcasted gather in the sort JVP, rather than forming explicit iotas.
Use an unsigned index and promise that it is in bounds.
2024-12-13 09:23:34 -05:00
Parker Schuh
0e7f218eb0 Support axis_index inside shard_map(auto=...) by using iota and
then calling full_to_shard.

PiperOrigin-RevId: 705704369
2024-12-12 18:39:05 -08:00
Jake VanderPlas
67b3413b96 Cleanup: replace lax._abstractify with core.get_aval 2024-12-12 14:08:17 -08:00
Jake VanderPlas
40367a9eaf Cleanup: remove uses of no-op raise_to_shaped 2024-12-12 09:49:06 -08:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
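A sketch of the case in question (under sharding-in-types, the scalar carries no sharding of its own, so the output spec is broadcast from `x` rather than erroring out):

```
import jax.numpy as jnp

def relu(x):
    # x may be sharded; the scalar 0.0 is broadcast to x's sharding.
    return jnp.where(x > 0, x, 0.0)
```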
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Jake VanderPlas
c40780b957 internal: dedupe lax broadcasting logic 2024-12-11 15:03:39 -08:00
jax authors
e55bbc778a Merge pull request #25422 from jakevdp:broadcast-rank
PiperOrigin-RevId: 705245013
2024-12-11 14:38:24 -08:00
Jake VanderPlas
76d8b9c5a4 internal: simplify broadcast_shapes logic 2024-12-11 13:50:20 -08:00
Jake VanderPlas
65d2ca632c jax.lax: raise TypeError for mismatched dtypes 2024-12-11 11:59:10 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
jax authors
263d4d1462 Merge pull request #25369 from jax-ml:mutable-arrays-ad
PiperOrigin-RevId: 704685653
2024-12-10 06:36:02 -08:00