804 Commits

Author SHA1 Message Date
Peter Hawkins
57e33bcbcd Deprecate the contents of jax.util.
PiperOrigin-RevId: 747629222
2025-04-14 17:20:30 -07:00
Jake VanderPlas
ceca6ec1fc jax.jit: deprecate non-standard call signature. 2025-04-14 10:13:05 -07:00
Dan Foreman-Mackey
1b1bd071bc Finalize deprecation of vectorized argument in callbacks.
The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.

This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.

PiperOrigin-RevId: 747413383
2025-04-14 07:43:59 -07:00
Peter Hawkins
c69e61e1a9 Remove jax.lib.xla_client.{XlaComputation,Shape}.
PiperOrigin-RevId: 746803082
2025-04-12 06:18:02 -07:00
Peter Hawkins
6fc78a5a6d Deprecate jax.lax.infeed and jax.lax.outfeed.
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.

PiperOrigin-RevId: 746595857
2025-04-11 14:42:14 -07:00
Nitin Srinivasan
5cf74cc72b Use dash instead of underscore for extras.
The new `METADATA` specification disallows use of underscore and automatically converts any usage of them to dash.

https://packaging.python.org/en/latest/specifications/core-metadata/#provides-extra-multiple-use

This should fix the following error: https://github.com/jax-ml/jax/issues/27874  from appearing in future JAX releases

PiperOrigin-RevId: 746546162
2025-04-11 12:11:38 -07:00
Peter Hawkins
ab88273596 Deprecate jax.dlpack.to_dlpack.
This is not needed under the newer DLPack protocol for users, and there's an equivalent (`__dlpack__`).

PiperOrigin-RevId: 746530351
2025-04-11 11:26:20 -07:00
George Necula
7eb397d1e5 Make trace and lower class attributes for jax.jit.
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:

```
jax.jit(f).trace(...)
```

The new attributes create problems when `jax.jit` is used along `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.

See #27829 and #27825.

The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.

Fixes: #27829
2025-04-11 14:51:12 +03:00
Peter Hawkins
713ea3caa1 [JAX] Remove deprecated exports in jax.lib.xla_client.
PiperOrigin-RevId: 745742774
2025-04-09 14:49:34 -07:00
Dan Foreman-Mackey
2d44f985c3 Finalize deprecation of ffi_call with inline arguments.
PiperOrigin-RevId: 745261995
2025-04-08 13:09:42 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Sergei Lebedev
12811f08a8 Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 745067361
2025-04-08 03:30:36 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00
Sergei Lebedev
51c224c446 Removed deprecated jax.core.{full_lower,jaxpr_as_fun,lattice_join}
PiperOrigin-RevId: 744754730
2025-04-07 09:50:43 -07:00
Dan Foreman-Mackey
5a3fc606d4 Deprecate public export of mlir.custom_call.
PiperOrigin-RevId: 744722183
2025-04-07 07:58:20 -07:00
Peter Hawkins
70485e31b9 Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
These are available via jax.extend.mlir.dialects.

No deprecation period because jax.interpreters.mlir is not a stable API.

PiperOrigin-RevId: 744712537
2025-04-07 07:20:24 -07:00
Sergei Lebedev
9c58a112b3 jnp.array no longer accepts None
PiperOrigin-RevId: 743291099
2025-04-02 14:58:51 -07:00
Jake VanderPlas
10425ae6a9 jax.core: finalize a number of deprecations for JAX v0.6.0 2025-03-30 19:32:22 -07:00
Jake VanderPlas
18521fef08 Deprecate jax.tree_* aliases 2025-03-27 10:13:14 -07:00
jax authors
e342f2dd60 Update the minimum supported CuDNN version to 9.8 (previously 9.1).
Announce maximum supported CUDA version 12.8 (previously 12.3).

PiperOrigin-RevId: 741188737
2025-03-27 09:54:00 -07:00
Peter Hawkins
9932ff1f79 Deprecate the contents of jax.lib.xla_extension.
PiperOrigin-RevId: 741145943
2025-03-27 07:28:25 -07:00
Daniel Suo
3a593219d4 [jaxlib:cpu] Cleaning up after callback FFI refactor.
PiperOrigin-RevId: 740547947
2025-03-25 17:41:53 -07:00
Jake VanderPlas
a58592ebb0 Finalize some deprecations from jax.lib.xla_client 2025-03-25 06:46:19 -07:00
Sergei Lebedev
92f5d9caa3 Deprecated jax.tree_util.build_tree
We have no usages of it neither in JAX nor internally, but we still have to
go through the deprecation cycle, becuase `jax.tree_util` is public API.

PiperOrigin-RevId: 739196514
2025-03-21 08:54:38 -07:00
Peter Hawkins
9d534ad2cd Update version numbers after JAX 0.5.3 release. 2025-03-19 14:41:25 -04:00
Peter Hawkins
ed43119a86 JAX release v0.5.3 2025-03-18 21:38:14 -04:00
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement. 2025-03-17 13:41:46 -04:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
jax authors
d55879723e Merge pull request #26840 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 735513976
2025-03-10 14:33:14 -07:00
Skye Wanderman-Milne
a6c858f04b Merge branch 'release/0.5.2' into main 2025-03-04 18:47:20 -08:00
Skye Wanderman-Milne
ce224293b1 Prepare for JAX release 0.5.2 (patch release over 0.5.1) 2025-03-04 12:59:24 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
Anton Osokin
1f3176636d Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00
Dan Foreman-Mackey
bb9aed5eec Reimplement custom_vjp.optimize_remat using custom_dce. 2025-02-28 10:00:28 -05:00
rajasekharporeddy
9c18e8dcc1 Remove duplicate JAX version 0.4.37 heading in changelog 2025-02-28 12:32:00 +05:30
Peter Hawkins
1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
Peter Hawkins
c8c4cfa04e Update version numbers after 0.5.1 release. 2025-02-24 16:18:25 -05:00
Yash Katariya
07440f4afa Prepare for JAX release 0.5.1 2025-02-24 10:59:04 -05:00
Skye Wanderman-Milne
d5d43fc46e Don't write atime file if JAX_COMPILATIION_CACHE_MAX_SIZE == -1
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistence compilation cache size is
set. Writing this file can cause network filesystem performace and
other issues, so only write it if users are opted-in.
2025-02-14 12:01:55 -08:00
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Skye Wanderman-Milne
f07243a73a Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 724076284
2025-02-06 14:30:36 -08:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Peter Hawkins
b1a2c27aa0 Remove libtpu-nightly dependency from jax[tpu].
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].
2025-02-04 20:59:30 -05:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
Skye Wanderman-Milne
2aa810fe60 Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
Dan Foreman-Mackey
782138fb6f Add custom_dce to changelogs and API docs. 2025-01-27 13:03:34 -05:00