778 Commits

Author SHA1 Message Date
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement. 2025-03-17 13:41:46 -04:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
jax authors
d55879723e Merge pull request #26840 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 735513976
2025-03-10 14:33:14 -07:00
Skye Wanderman-Milne
a6c858f04b Merge branch 'release/0.5.2' into main 2025-03-04 18:47:20 -08:00
Skye Wanderman-Milne
ce224293b1 Prepare for JAX release 0.5.2 (patch release over 0.5.1) 2025-03-04 12:59:24 -08:00
Jake VanderPlas
8cec6e636a jax.numpy ndim/shape/size: deprecate non-array input 2025-03-04 10:42:32 -08:00
Anton Osokin
1f3176636d Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00
Dan Foreman-Mackey
bb9aed5eec Reimplement custom_vjp.optimize_remat using custom_dce. 2025-02-28 10:00:28 -05:00
rajasekharporeddy
9c18e8dcc1 Remove duplicate JAX version 0.4.37 heading in changelog 2025-02-28 12:32:00 +05:30
Peter Hawkins
1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
Peter Hawkins
c8c4cfa04e Update version numbers after 0.5.1 release. 2025-02-24 16:18:25 -05:00
Yash Katariya
07440f4afa Prepare for JAX release 0.5.1 2025-02-24 10:59:04 -05:00
Skye Wanderman-Milne
d5d43fc46e Don't write atime file if JAX_COMPILATIION_CACHE_MAX_SIZE == -1
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistence compilation cache size is
set. Writing this file can cause network filesystem performace and
other issues, so only write it if users are opted-in.
2025-02-14 12:01:55 -08:00
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Skye Wanderman-Milne
f07243a73a Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 724076284
2025-02-06 14:30:36 -08:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Peter Hawkins
b1a2c27aa0 Remove libtpu-nightly dependency from jax[tpu].
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].
2025-02-04 20:59:30 -05:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
Skye Wanderman-Milne
2aa810fe60 Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
Dan Foreman-Mackey
782138fb6f Add custom_dce to changelogs and API docs. 2025-01-27 13:03:34 -05:00
Peter Hawkins
9fa2912254 Update version numbers after 0.5.0 release 2025-01-17 13:30:59 -05:00
Peter Hawkins
c25fb92c44 Release JAX 0.5.0 2025-01-17 10:28:03 -05:00
Peter Hawkins
3a8f31aa83 Update the JAX version to 0.5.0.
This is because of the breaking change to PRNG key semantics, and the version follows JAX's new effver versioning scheme (https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
2025-01-15 14:08:15 -05:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
392a851769 Increase the minimum SciPy version to 1.11.1.
(1.11.0 was yanked from PyPi because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Jake VanderPlas
40fe4b8797 Finalize deprecation of some symbols from jax.lib.xla_client 2024-12-23 10:14:16 -08:00
Jake VanderPlas
c206ae7fe8 changelog: link to api compatibility & python version docs 2024-12-23 09:39:45 -08:00
Dan Foreman-Mackey
c6131ee527 Add support for N-D FFTs with D>3. 2024-12-19 15:23:30 +00:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
Yash Katariya
8b734808e8 Remove jax_enable_memories config flag. It defaulted to True for a very long time and it's time to remove the flag.
PiperOrigin-RevId: 707590263
2024-12-18 10:15:45 -08:00
Peter Hawkins
ee45718457 Increase the minimum NumPy version to v1.25.
Per SPEC 0, we drop NumPy v1.24 support on Dec 18, 2024.
2024-12-18 08:18:57 -05:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Peter Hawkins
ff52aedf67 Update version numbers after release. 2024-12-17 18:16:25 -05:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
Jake VanderPlas
c73f306099 Finalize deprecation of jnp.round_
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00