74 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Jake VanderPlas
2fb750e0ab doc: improve docs for jax.lax trig functions 2025-02-06 11:09:55 -08:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
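
For completeness, a minimal usage sketch of the new primitive (a sketch, assuming equal-sized pieces along the leading axis):
```
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(3, 2)
# Three pieces of size 1 along axis 0; each piece keeps the split axis.
pieces = lax.split(x, sizes=(1, 1, 1), axis=0)  # three arrays of shape (1, 2)
```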
2024-11-19 15:25:47 -05:00
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that, so there's no reason to couple our type to XLA's.
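
A minimal sketch of the new public name in use (assuming the usual `lax.fft` signature):
```
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8.0)
# Real-to-complex FFT selected via the newly public enum.
y = lax.fft(x, fft_type=lax.FftType.RFFT, fft_lengths=(8,))
```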

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile-time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
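
A sketch of the updated interface (assuming the `DotAlgorithmPreset` names mirror the StableHLO presets and that the backend supports the chosen algorithm):
```
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.float32)
y = jnp.ones((8, 4), dtype=jnp.float32)

# The algorithm is passed via `precision`; inputs are cast to the
# algorithm's storage types, and the result is cast back to f32 here
# because no preferred_element_type is given.
z = lax.dot(x, y, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)
```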

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Ayaka
ad78147183 [Docs] Add docstring for RoundingMethod
Currently, the rendered documentation shows only "An enumeration." as the class docstring, which is unhelpful for users. This PR documents the class members and adds detailed descriptions and cross-references to make the docstring informative and well-formatted.
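
For example, the two methods differ on exact halves (a sketch, assuming the documented member names):
```
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5])
lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO)   # [1., 2., 3.]
lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)  # [0., 2., 2.]
```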

PiperOrigin-RevId: 681866947
2024-10-03 07:23:22 -07:00
Dan Foreman-Mackey
bc1e1a0220 Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm used to execute a matrix multiplication, and it can tune the trade-off between performance and numerical accuracy. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to support the new explicit API.

This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the `transpose_algorithm` argument.
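
A sketch of constructing the general data structure (field names assumed to mirror the StableHLO dot algorithm attributes):
```
import numpy as np
from jax import lax

# An explicit f32-in/f32-accumulate algorithm, spelled out in full.
alg = lax.DotAlgorithm(
    lhs_precision_type=np.float32,
    rhs_precision_type=np.float32,
    accumulation_type=np.float32,
    lhs_component_count=1,
    rhs_component_count=1,
    num_primitive_operations=1,
    allow_imprecise_accumulation=False,
)
```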

PiperOrigin-RevId: 678672686
2024-09-25 06:17:09 -07:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
Dan Foreman-Mackey
0b4800a193 Add ffi_call tutorial
Building on #21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a toolchain that
open source users would use in practice.
2024-08-01 15:36:32 -04:00
Yash Katariya
2eb1888c98 Make the error for vmap(jit) or vmap(wsc) with a concrete layout more informative
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Neil Girdhar
9f85beb56b Expose PrecisionLike
This is used in client code like:
https://github.com/search?q=repo%3Agoogle%2Fflax%20%20PrecisionLike&type=code
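
A typical annotation in client code, sketched:
```
from jax import lax
from jax.lax import PrecisionLike

# PrecisionLike covers None, a lax.Precision, strings like "highest",
# and per-operand pairs of these.
def matmul(x, y, precision: PrecisionLike = None):
    return lax.dot(x, y, precision=precision)
```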
2023-12-06 14:41:22 -05:00
Matthew Johnson
96af01654f add psum_scatter to docs index
fixes #18524
2023-11-14 15:09:08 -08:00
Jake VanderPlas
8330072373 DOC: add missing jax.lax API docs 2023-10-06 10:26:36 -07:00
jax authors
209b6b02f4 Merge pull request #17144 from jakevdp:zeta
PiperOrigin-RevId: 558193896
2023-08-18 11:04:43 -07:00
Jake VanderPlas
6cd467fd57 Create lax.zeta with native HLO lowering 2023-08-16 13:43:41 -07:00
Jake VanderPlas
0ad6196ff0 Create lax.polygamma with native HLO lowering 2023-08-16 11:57:05 -07:00
Jake VanderPlas
8dc06ed2ce Document jax.lax.with_sharding_constraint 2023-04-26 10:19:04 -07:00
Jake VanderPlas
81e627d5bd DOC: make API doc titles more uniform 2023-01-18 10:59:42 -08:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it, we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
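
A sketch of the updated contract (return layout assumed from the description above):
```
import jax.numpy as jnp
from jax import lax

a = jnp.eye(4) + 0.1 * jnp.ones((4, 4))  # symmetric input
# Packed reduction plus the diagonal d, off-diagonal e, and the
# Householder taus.
arr, d, e, taus = lax.linalg.tridiagonal(a, lower=True)
```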

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.
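
A sketch of the new entry points (CPU-only at the time of this commit, per the notes above; return layouts assumed):
```
import jax.numpy as jnp
from jax import lax

a = jnp.arange(16.0).reshape(4, 4) + jnp.eye(4)
# Packed upper Hessenberg matrix plus Householder taus.
h, taus = lax.linalg.hessenberg(a)
# householder_product (formerly the private orgqr) assembles the
# orthogonal/unitary matrix from packed reflectors, e.g. after geqrf.
packed, qtaus = lax.linalg.geqrf(a)
q = lax.linalg.householder_product(packed, qtaus)
```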

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
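
For reference, the primitive computes the standard sigmoid:
```
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(-3.0, 3.0, 7)
y = lax.logistic(x)  # equivalent to 1 / (1 + exp(-x))
```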

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of "Upgrade logistic (sigmoid) function into a lax primitive".
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Jake VanderPlas
c66f5dda60 DOC: add missing linalg functionality to docs 2022-03-15 09:55:59 -07:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.
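
A minimal sketch of the relocated API:
```
import jax.numpy as jnp
from jax import lax

scores = jnp.array([0.1, 0.9, 0.4, 0.8, 0.2])
# Approximate top-k; on some backends without a specialized
# implementation this falls back to an exact top-k.
values, indices = lax.approx_max_k(scores, k=2)
```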

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Peter Hawkins
f3aa5fa92f Document lax.GatherScatterMode.
Recommend the .at[...] property in the docstrings for lax.scatter_ operators.

Add several missing lax.scatter_ operators to the index.
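
A sketch of the recommended style:
```
import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[1].add(5.0)                # preferred over calling lax.scatter_add directly
z = x.at[10].set(1.0, mode="drop")  # out-of-bounds handling per GatherScatterMode
```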
2021-11-22 15:43:02 -05:00
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Jake VanderPlas
94169b96a8 DOC: add conv_dimension_numbers and ConvGeneralDilatedDimensionNumbers to docs 2021-10-19 17:18:15 -07:00
Peter Hawkins
278ff13b66 Improve implementation of cbrt() in JAX.
Lower to the XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.

Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.

Add testing for `sqrt`/`rsqrt` for more types.

[XLA:Python] Add cbrt to XLA:Python bindings.
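
A sketch of the behavior described above:
```
import jax.numpy as jnp
from jax import lax

# Real cube root, defined for negative inputs; complex inputs are not
# supported (see the note above about jnp.cbrt).
lax.cbrt(jnp.array([-8.0, 27.0]))  # [-2., 3.]
```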

PiperOrigin-RevId: 386316949
2021-07-22 14:01:28 -07:00
Roman Novak
bc84c9fe8f Add lax.conv_general_dilated_local 2021-05-13 12:20:35 -07:00
Jake VanderPlas
33fde77bb1 Add lax.reduce_precision() 2021-04-05 09:54:14 -07:00
Jake VanderPlas
749ad95514 DOC: add transformations doc to HTML & reorganize contents 2021-03-08 16:25:04 -08:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Jake VanderPlas
a0b12bba25 DOC: fix minor formatting issues 2021-01-20 14:38:19 -08:00
Benjamin Chetioui
9c56277878 Add "Argument classes" section to jax.lax.rst. 2020-12-01 20:30:09 +01:00
Peter Hawkins
2cf2c719f2 Add documentation to several functions in jax.lax.linalg. 2020-11-05 18:53:47 -05:00
Roman Novak
da0bff2fa8 Add lax.conv_general_dilated_patches 2020-10-20 22:58:53 -07:00
Peter Hawkins
d3db7bd4be Optimize lax.associative_scan, reimplement cumsum, etc. on top of associative_scan.
Add support for an axis= parameter to associative_scan.

We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.

lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.

The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.

We also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary and can be removed.
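
A sketch of the new axis= parameter:
```
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(3, 4)
# Cumulative sum along the second axis, without transposing first.
row_scan = lax.associative_scan(jnp.add, x, axis=1)
```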

Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s

After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s

Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s

After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s
2020-10-15 20:51:55 -04:00
Peter Hawkins
db43e21b1d Improve documentation for a number of lax functions. 2020-10-14 21:18:09 -04:00
Jake Vanderplas
e0ebb144f9
Add switch and associative_scan to lax docs (#3946) 2020-08-03 12:32:32 -07:00
Chase Roberts
2b7a39f92b
Add pshuffle to docs (#3742) 2020-07-14 09:05:45 -04:00