This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.
Before:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
p:f32[5,3] = add_any m o
q:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
s:f32[5,3] = add_any p r
t:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
v:f32[5,3] = add_any s u
w:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
y:f32[5,3] = add_any v x
in (y,) }
] a b c d e
in (f,) }
```
Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.
After:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jnp.ones((3,))
In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
f:f32[5,3] = pjit[
name=unstack
jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
l:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] k
m:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] j
n:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] i
o:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] h
p:f32[1,3] = broadcast_in_dim[
broadcast_dimensions=(1,)
shape=(1, 3)
sharding=None
] g
q:f32[5,3] = concatenate[dimension=0] p o n m l
in (q,) }
] a b c d e
in (f,) }
```
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
PiperOrigin-RevId: 697631402
With jax.experimental.export gone we can now do some cleanup in the export module.
In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.
PiperOrigin-RevId: 692398132
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.
This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.
PiperOrigin-RevId: 689708392
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
This type is unused by JAX, so there is no replacement.
(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)
PiperOrigin-RevId: 684451556
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.
We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.
PiperOrigin-RevId: 684447186
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 682830525
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.
PiperOrigin-RevId: 682656411
`jax.experimental.host_callback` has been deprecated since March 2024
(JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 681004255
For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.
PiperOrigin-RevId: 674019460