This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.
Before:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
p:f32[5,3] = add_any m o
q:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
s:f32[5,3] = add_any p r
t:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
v:f32[5,3] = add_any s u
w:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
y:f32[5,3] = add_any v x
in (y,) }
] a b c d e
in (f,) }
```
Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.
After:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
o:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
p:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
q:f32[5,3] = concatenate[dimension=0] p o n m l
in (q,) }
] a b c d e
in (f,) }
```
Change in preparation for removing HLO ops from the XLA Python bindings.
In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.
PiperOrigin-RevId: 684892102
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.
The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.
Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.
To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)
With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.
Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.
One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
PiperOrigin-RevId: 683302687
Currently, the class only has "An enumeration." as the docstring when viewing the documentation, which is unhelpful for users. This PR adds class members, detailed descriptions and cross-references to the docstring to make it beautiful and informative.
PiperOrigin-RevId: 681866947
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm that is used to execute a matrix multiplication, and it can tune the trade-off between performance and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.
This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.
Transposition is supported using the `transpose_algorithm` argument.
PiperOrigin-RevId: 678672686
Building on #21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.
As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.
Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.
PiperOrigin-RevId: 487621469
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
PiperOrigin-RevId: 472705623
Lower to XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.
Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.
Add testing for `sqrt`/`rsqrt` for more types.
[XLA:Python] Add cbrt to XLA:Python bindings.
PiperOrigin-RevId: 386316949
Add support for an axis= parameter to associative_scan.
We previously had two associative scan implementations, namely lax.associative_scan, and the implementations of cumsum, cumprod, etc.
lax.associative_scan was more efficient in some ways because unlike the cumsum implementation it did not pad the input array to the nearest power of two size. This appears to have been a significant cause of https://github.com/google/jax/issues/4135.
The cumsum/cummax implementation used slightly more efficient code to slice and
interleave arrays, which this change adds to associative_scan as well. Since we
are now using lax primitives that make it easy to select an axis, add support
for user-chosen scan axes as well.
We can also simplify the implementation of associative_scan: one of the
recursive base cases seems unnecessary, and we can simplify the code by removing
it.
Benchmarks from #4135 on my workstation:
Before:
bench_cumsum: 0.900s
bench_associative_scan: 0.597s
bench_scan: 0.359s
bench_np: 1.619s
After:
bench_cumsum: 0.435s
bench_associative_scan: 0.435s
bench_scan: 0.362s
bench_np: 1.669s
Before, with taskset -c 0:
bench_cumsum: 1.989s
bench_associative_scan: 1.556s
bench_scan: 0.428s
bench_np: 1.670s
After, with taskset -c 0:
bench_cumsum: 1.271s
bench_associative_scan: 1.275s
bench_scan: 0.438s
bench_np: 1.673s