718 Commits

Author SHA1 Message Date
George Necula
3f5f3e1c47 [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
2024-12-05 08:42:41 -08:00
George Necula
5fe5206b6a [shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
2024-12-05 08:02:38 -08:00
labs-code-app[bot]
762301fc5d Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
2024-11-26 13:57:47 +00:00
Nitin Srinivasan
6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which gives a cleaner separation and improves readability when the CLI subcommands are run with `--help`. It also makes clear which parts of the build they affect, e.g., CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
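
The subcommand-plus-argument-group layout described above can be sketched with the standard library's `argparse`. This is a minimal illustration, not the code in `build/build.py`: the function name `make_parser`, the group titles, and the default values are assumptions chosen for the example. Only the subcommand and flag names are taken from the usage shown in this message.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    """Build a two-subcommand CLI; an illustrative sketch, not build/build.py."""
    parser = argparse.ArgumentParser(prog="build.py")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "build" subcommand: builds one or more wheel packages.
    build = subparsers.add_parser("build", help="Build wheel packages")
    build.add_argument("--wheels", default="jaxlib",
                       help="Comma-separated list of wheels to build")
    build.add_argument("--python_version", default="3.10")

    # Argument groups appear as separate sections in `build --help`.
    cuda = build.add_argument_group("CUDA options")
    cuda.add_argument("--cuda_version")
    cuda.add_argument("--cudnn_version")
    rocm = build.add_argument_group("ROCm options")
    rocm.add_argument("--rocm_version")
    rocm.add_argument("--rocm_path")

    # "requirements_update" subcommand: refreshes the lock files.
    req = subparsers.add_parser("requirements_update",
                                help="Update requirements_lock.txt files")
    req.add_argument("--python_version", default="3.10")
    return parser


args = make_parser().parse_args(
    ["build", "--wheels=jaxlib,jax-cuda-plugin", "--cuda_version=12.3.2"])
print(args.command, args.wheels.split(","), args.cuda_version)
```

Because each subcommand owns its parser, a flag such as `--cuda_version` is rejected under `requirements_update` instead of being silently ignored.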

There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.
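
To illustrate why `asyncio.create_subprocess_shell` allows real-time progress, here is a minimal sketch that prints each output line as it arrives, whereas `subprocess.check_output` returns only after the process exits. The helper name `run_and_stream` and the demo command are assumptions for this example and are not taken from the build script. The sketch assumes a POSIX shell.

```python
import asyncio
import shlex
import sys


async def run_and_stream(cmd: str):
    """Run `cmd` in a shell, echoing each output line as soon as it arrives."""
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # merge stderr into the same stream
    )
    lines = []
    async for raw in proc.stdout:  # the stream reader yields one line at a time
        line = raw.decode(errors="replace").rstrip()
        print(line, flush=True)
        lines.append(line)
    returncode = await proc.wait()
    return returncode, lines


# Stand-in for a long-running build command: a tiny Python script.
demo_cmd = f"{shlex.quote(sys.executable)} -c " + shlex.quote(
    "print('step 1'); print('step 2')")
returncode, lines = asyncio.run(run_and_stream(demo_cmd))
```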

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00
Peter Hawkins
dfe27a1682 Mention stackless in the release notes. 2024-11-20 14:53:52 -05:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
Peter Hawkins
c5e8ae80f9 Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs.
Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827.

Fixes #24875
2024-11-18 20:33:27 -05:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
4fe9164548 Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
Jake VanderPlas
f401c97967 finalize deprecation of jax.clear_backends 2024-11-14 09:22:09 -08:00
jax authors
ed9fdbbf0a Merge pull request #24842 from jakevdp:batched-toeplitz
PiperOrigin-RevId: 695917476
2024-11-12 16:52:25 -08:00
Dan Foreman-Mackey
5808170a10 Add GPU overflow bugfix (#24846) to changelog. 2024-11-12 08:57:52 -08:00
Jake VanderPlas
3f98c57f7b jax.scipy.linalg.toeplitz: support implicit batching 2024-11-11 15:32:43 -08:00
jax authors
c8f5b2bb13 Merge pull request #24481 from jakevdp:key-array-error
PiperOrigin-RevId: 694626415
2024-11-08 13:47:05 -08:00
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Jake VanderPlas
1af3b01c1c register_dataclass: allow marking static fields via field(static=True) 2024-11-06 11:18:11 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Jake VanderPlas
e9acaa8484 Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
2024-11-04 10:55:21 -08:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Matthew Johnson
0f3ba4250d support exec_time_optimization_effort and memory_fitting_effort xla compilation options

PiperOrigin-RevId: 692322944
2024-11-01 16:25:50 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Jake VanderPlas
d4c46825d6 Finalize deprecation of xb, xc, & xe symbols in jax.interpreters.xla
PiperOrigin-RevId: 689792265
2024-10-25 08:12:44 -07:00
George Necula
c62b19883f Fix copy and paste error in CHANGELOG. 2024-10-25 16:11:35 +03:00
George Necula
9088adda68 [jax2tf] Disable jax2tf with non-native serialization.
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
Peter Hawkins
2aeda17829 Merge branch 'release/0.4.35' 2024-10-23 08:50:31 -04:00
Peter Hawkins
e4f3f8f064 Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
Peter Hawkins
e9c7ff0b7d Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
2024-10-11 11:42:40 -07:00
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Yuxuan Jiang
757a77ede0 Fix wrong date in changelog 2024-10-06 23:16:30 +08:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/surprising than it needs to be, and the weak type dependence is rarely useful in practice.
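
A small sketch of the cases involved. Under the new semantics described above, the first two calls should return True and the last False. The zero-dimensional array case is the one whose result depends on the JAX version.

```python
import jax.numpy as jnp

# Python scalar: zero dimensions.
python_scalar = jnp.isscalar(1.0)

# Zero-dimensional array: True under the new semantics.
zero_dim_array = jnp.isscalar(jnp.array(1.0))

# One-dimensional array: has a dimension, so not a scalar.
one_dim_array = jnp.isscalar(jnp.array([1.0]))

print(python_scalar, zero_dim_array, one_dim_array)
```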

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
Peter Hawkins
b0b7a60e63 Merge branch 'release/0.4.34' 2024-10-04 10:56:18 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Jake VanderPlas
49ad220e57 Finalize deprecation of XLACompatibleSharding
PiperOrigin-RevId: 681156145
2024-10-01 14:02:34 -07:00
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `False`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.
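
For code that transitions to the new callback APIs, a minimal sketch of calling `jax.experimental.io_callback` directly might look like the following. The function names `host_double` and `f` are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_double(x):
    # Runs on the host in ordinary Python/NumPy, outside the traced computation.
    return np.asarray(x) * 2


@jax.jit
def f(x):
    # The second argument declares the shape and dtype of the callback result.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return io_callback(host_double, result_spec, x)


out = f(jnp.arange(3.0))
```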

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
Peter Hawkins
0e082f978b Deprecate jax.lib.xla_client.Device.
jax.Device is a longstanding public name for this class.

PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Jake VanderPlas
e05c37c667 Finalize deprecation of pretty-printing utils in jax.core.pp_*
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
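
The difference between the two conversions can be illustrated in plain Python, without JAX. This is only a model of the semantics, not the code changed by this commit: a boolean cumulative sum is modelled here as a running logical OR, and the input values are an arbitrary example.

```python
from itertools import accumulate
from operator import or_

values = [0.5, 0.0, 0.0]

# Casting to integer first truncates 0.5 to 0, so the running result stays False.
via_int_cast = list(accumulate((bool(int(v)) for v in values), or_))

# Testing for non-equality with zero treats 0.5 as True.
via_nonzero = list(accumulate((v != 0 for v in values), or_))

print(via_int_cast)
print(via_nonzero)
```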
2024-09-24 14:46:44 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
6a3736a1d7 Add a note to the changelog about the new CPU thunks backend, enabled in 0.4.32. 2024-09-19 15:38:52 -04:00
Peter Hawkins
bef36c431d Add Python 3.13 wheels to changelog. 2024-09-18 18:57:03 +00:00
rajasekharporeddy
2714469397 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros 2024-09-18 17:06:28 +05:30