780 Commits

Author SHA1 Message Date
Praveen Narayanan
b6d4fe5387 Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
PiperOrigin-RevId: 735471245
2025-03-10 12:25:22 -07:00
Dan Foreman-Mackey
a981e1c4b9 Start adding primitive registration helper functions to lax.linalg.
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
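To make the pattern concrete, here is a purely hypothetical sketch (helper name and signature invented for illustration, not the ones in the diff) of the kind of shared registration boilerplate being factored out, using JAX's public extension points:

```python
from jax.extend import core
from jax.interpreters import mlir

def _standard_linalg_primitive(name, impl, abstract_eval):
  """Bundle the impl/abstract_eval/lowering registrations shared by primitives."""
  prim = core.Primitive(name)
  prim.def_impl(impl)                     # eager evaluation
  prim.def_abstract_eval(abstract_eval)   # shape/dtype rule
  # Lower by tracing the reference implementation.
  mlir.register_lowering(prim, mlir.lower_fun(impl, multiple_results=False))
  return prim
```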

PiperOrigin-RevId: 729471970
2025-02-21 04:05:34 -08:00
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Gunhyun Park
d206cc3b50 Add lax.composite primitive
A composite function can encapsulate an operation made up of other JAX functions. The semantics of the op are implemented by the `decomposition` function. For example, a `tangent` operation can be implemented as `sin(x) / cos(x)`.

This is what the HLO looks like for a tangent composite:
```
module @jit_my_tangent_composite {
  func.func public @main(%arg0: tensor<4xf64>) -> (tensor<4xf64>) {
    %0 = stablehlo.composite "my.tangent" %arg0 {decomposition = @my.tangent} : (tensor<4xf64>) -> tensor<4xf64>
    return %0 : tensor<4xf64>
  }
  func.func private @my.tangent(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.sine %arg0 : tensor<4xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4xf64>
    %2 = stablehlo.divide %0, %1 : tensor<4xf64>
    return %2 : tensor<4xf64>
  }
}
```

Similarly, this can scale to something like attention. Preserving such an abstraction greatly simplifies pattern matching: instead of matching the set of ops that makes up attention, a matcher can simply look for a uniquely identifying composite op like "MyAttention".

This is useful for preserving a high-level abstraction that would otherwise be lost during lowering. A hardware-aware compiler can recognize the single composite op and emit efficient code for it directly, rather than pattern-matching a generic lowering and replacing it with its own efficient version; the `decomposition` function can then be DCE'd away. If the hardware does not have an efficient lowering, the compiler can instead inline the `decomposition`, which implements the semantics of the abstraction.

For more details on the API, refer to the documentation.
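As a usage sketch (hedged: see the documentation for the exact `lax.composite` signature), the tangent composite above might be written as:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(lax.composite, name="my.tangent")
def my_tangent(x):
  return lax.sin(x) / lax.cos(x)

x = jnp.linspace(0.0, 1.0, 4)
print(jax.jit(my_tangent).lower(x).as_text())  # contains stablehlo.composite
```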

PiperOrigin-RevId: 707750633
2024-12-18 19:38:37 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Gunhyun Park
12c30578b2 Introduce lax.ragged_all_to_all primitive
This version emits a StableHLO custom call. The test outputs the following MLIR module:
```
module @jit_ragged_all_to_all {
  func.func public @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32>, %arg3: tensor<3xi32>, %arg4: tensor<3xi32>, %arg5: tensor<3xi32>) -> (tensor<6xf32>) {
    %0 = stablehlo.custom_call @ragged_all_to_all(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {api_version = 4 : i32, backend_config = {replica_groups = dense<[[0, 1, 2]]> : tensor<1x3xi64>}} : (tensor<6xf32>, tensor<6xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<6xf32>
    return %0 : tensor<6xf32>
  }
}
```

For now, the API assumes the `split_axis` and `concat_axis` of `all_to_all` are the outermost (ragged) dimension, and `axis_index_groups` defaults to all replicas (i.e., there is a single group covering all axis indices, iota-like, as in the example above).

The current API is inspired by https://www.mpich.org/static/docs/v3.1/www3/MPI_Alltoallv.html, which essentially also performs a ragged all-to-all.

PiperOrigin-RevId: 704550890
2024-12-09 22:19:40 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that operations like `jnp.unstack` can be differentiated more efficiently.
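A minimal sketch of the new API (assuming the `sizes`/`axis` signature in the docs):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(4, 3)
# Split the leading axis into pieces of sizes 1, 2, and 1.
a, b, c = lax.split(x, sizes=(1, 2, 1), axis=0)  # shapes (1, 3), (2, 3), (1, 3)
```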

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
Pearu Peterson
4d0a007d57 Add square_p 2024-11-13 20:14:37 +02:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default, we use heuristics based on the matrix sizes to select the algorithm, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.
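A hedged sketch of the explicit algorithm selection (enum and parameter names as exposed in `jax.lax.linalg`; assumed here, check the docs):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 6), jnp.float32)
# Bypass the size-based heuristics and force the QR-based kernel.
u, s, vh = lax.linalg.svd(x, full_matrices=False,
                          algorithm=lax.linalg.SvdAlgorithm.QR)
```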

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that, so there's no reason to couple our type to XLA's.
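A small usage sketch with the new public name:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8,), jnp.complex64)
y = lax.fft(x, fft_type=lax.FftType.FFT, fft_lengths=(8,))
```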

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
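A hedged example of the updated interface (preset name assumed from the StableHLO dot-algorithm list; unsupported algorithms still fail at compile time, as noted above):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((128, 64), jnp.float32)
b = jnp.ones((64, 32), jnp.float32)
# Inputs are cast to bf16 storage as needed; the f32 accumulation is cast
# back so the output dtype matches the inputs.
c = lax.dot(a, b, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)
```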

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Sergei Lebedev
4cf33c0239 Added scatter_sub_p
The new primitive is used for in-place subtraction updates.

Closes #23933

PiperOrigin-RevId: 681754037
2024-10-03 00:27:31 -07:00
Dan Foreman-Mackey
bc1e1a0220 Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm that is used to execute a matrix multiplication, and it can tune the trade-off between performance and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.

This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the `transpose_algorithm` argument.

PiperOrigin-RevId: 678672686
2024-09-25 06:17:09 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
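A usage sketch for the entry above (semantics per the XLA optimization-barrier op: the compiler may not fuse or reorder computations across the barrier):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  a = x + 1.0
  a = lax.optimization_barrier(a)  # keeps x + 1.0 from fusing with the multiply
  return a * 2.0

print(jax.jit(f)(jnp.float32(1.0)))
```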
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
piotrfilipiuk
93dfe05aec Implements Ragged Dot API 2024-05-11 06:40:18 -07:00
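A hedged sketch of the ragged-dot semantics for the entry above (shapes assumed: `lhs` is `(m, k)`, `rhs` is `(g, k, n)`, and rows of `lhs` are grouped contiguously by `group_sizes`):

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((6, 4), jnp.float32)
rhs = jnp.ones((2, 4, 5), jnp.float32)
group_sizes = jnp.array([2, 4], jnp.int32)  # rows 0-1 use rhs[0], rows 2-5 use rhs[1]
out = lax.ragged_dot(lhs, rhs, group_sizes)  # shape (6, 5)
```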
Chase Roberts
01412f7645 pbroadcast 2024-03-18 15:12:33 -07:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Neil Girdhar
9f85beb56b Expose PrecisionLike
This is used in client code like:
https://github.com/search?q=repo%3Agoogle%2Fflax%20%20PrecisionLike&type=code
2023-12-06 14:41:22 -05:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
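A minimal sketch of the API (per-platform branches passed as keyword arguments, with `default` as the fallback):

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  return lax.platform_dependent(
      x,
      cpu=lambda x: x + 1.0,      # branch lowered when running on CPU
      default=lambda x: x * 2.0,  # any other platform
  )

print(jax.jit(f)(jnp.float32(1.0)))
```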
2023-11-02 14:31:38 +01:00
David Majnemer
8fe4fcc0ef Use totalorder comparisons for sort
PiperOrigin-RevId: 573289718
2023-10-13 12:21:07 -07:00
Jake VanderPlas
ce6a0c43ad jax.lax: deprecate inadvertent exports & internal utilities 2023-10-06 11:26:03 -07:00
Jake VanderPlas
665b176c2c remove deprecated jax.lax.prod function
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
jax authors
209b6b02f4 Merge pull request #17144 from jakevdp:zeta
PiperOrigin-RevId: 558193896
2023-08-18 11:04:43 -07:00
Sharad Vikram
caee3120fd Expose exp2_p in jax.lax
PiperOrigin-RevId: 557642106
2023-08-16 16:49:15 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
0ad6196ff0 Create lax.polygamma with native HLO lowering 2023-08-16 11:57:05 -07:00
Matthew Johnson
560ede0ff1 add an exp2 primitive and lax.exp2
part of fixing https://github.com/jax-ml/jax-triton/issues/204
2023-07-28 12:33:49 -07:00
Jake VanderPlas
4cfa96ef8f deprecate jax.lax.prod 2023-05-23 17:33:50 -07:00
Jake VanderPlas
7f7f995bf4 Export jax.lax.sharding_constraint_p
PiperOrigin-RevId: 534566582
2023-05-23 14:50:46 -07:00
Jake VanderPlas
8dc06ed2ce Document jax.lax.with_sharding_constraint 2023-04-26 10:19:04 -07:00
Anish Tondwalkar
adbdaa47a3 Refactor special functions into their own module.
We're going to want to decompose these using series and continued-fraction
representations, and for that we'll need control flow.

PiperOrigin-RevId: 518977008
2023-03-23 15:21:15 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
8a2765ad2f Export device_put_p as jax.lax.device_put_p.
lax seems like the most natural place to export the primitive for now.
2023-02-22 14:44:15 -05:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.
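A small usage sketch (CPU-only for now, per the note above; the `calc_q` flag is assumed to mirror `scipy.linalg.hessenberg`):

```python
import jax.numpy as jnp
from jax.scipy.linalg import hessenberg

a = jnp.arange(16.0).reshape(4, 4)
H, Q = hessenberg(a, calc_q=True)  # Q @ H @ Q.conj().T reconstructs a
```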

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Srinivas Vasudevan
5adfb08986 Add lax.cumlogsumexp for cumulative logsumexp operations.
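A tiny sketch of the semantics for the entry above (a numerically stable cumulative `logaddexp`):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.log(jnp.arange(1.0, 5.0))
print(lax.cumlogsumexp(x))  # == jnp.log(jnp.cumsum(jnp.arange(1.0, 5.0)))
```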
PiperOrigin-RevId: 485158935
2022-10-31 15:08:52 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after a rollback caused by relatively trivial jax.numpy shape-validation changes failing in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00