carlosgmartin
3f59fa6888
Add replace option to random.categorical to enable sampling without replacement.
2025-03-17 13:41:46 -04:00
jax authors
bf829ff612
Merge pull request #26524 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
carlosgmartin
6b69a136aa
Add jax.random.multinomial.
2025-03-12 18:15:14 -04:00
jax authors
d55879723e
Merge pull request #26840 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 735513976
2025-03-10 14:33:14 -07:00
Skye Wanderman-Milne
a6c858f04b
Merge branch 'release/0.5.2' into main
2025-03-04 18:47:20 -08:00
Skye Wanderman-Milne
ce224293b1
Prepare for JAX release 0.5.2 (patch release over 0.5.1)
2025-03-04 12:59:24 -08:00
Jake VanderPlas
8cec6e636a
jax.numpy ndim/shape/size: deprecate non-array input
2025-03-04 10:42:32 -08:00
Anton Osokin
1f3176636d
Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
...
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00
Dan Foreman-Mackey
bb9aed5eec
Reimplement custom_vjp.optimize_remat using custom_dce.
2025-02-28 10:00:28 -05:00
rajasekharporeddy
9c18e8dcc1
Remove duplicate JAX version 0.4.37 heading in changelog
2025-02-28 12:32:00 +05:30
Peter Hawkins
1e5d9a9158
Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
...
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
Peter Hawkins
c8c4cfa04e
Update version numbers after 0.5.1 release.
2025-02-24 16:18:25 -05:00
Yash Katariya
07440f4afa
Prepare for JAX release 0.5.1
2025-02-24 10:59:04 -05:00
Skye Wanderman-Milne
d5d43fc46e
Don't write atime file if JAX_COMPILATIION_CACHE_MAX_SIZE == -1
...
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistence compilation cache size is
set. Writing this file can cause network filesystem performace and
other issues, so only write it if users are opted-in.
2025-02-14 12:01:55 -08:00
George Necula
a0812cd57e
[better_errors] Make it explicit that debug_info is not None.
...
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
tttc3
b1b56ea0b0
Enable pivoted QR on GPU via MAGMA.
...
Originally noted in #20282 , this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Jake VanderPlas
e389b707ba
Add public APIs for jax.lax monoidal reductions
2025-02-11 16:00:03 -08:00
Skye Wanderman-Milne
f07243a73a
Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
...
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.
PiperOrigin-RevId: 724076284
2025-02-06 14:30:36 -08:00
Jake VanderPlas
e4dac395a5
Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
...
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653
Reverts 95535df13b422284043623ca3a6d2a5962116fb1
PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Peter Hawkins
b1a2c27aa0
Remove libtpu-nightly dependency from jax[tpu].
...
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].
2025-02-04 20:59:30 -05:00
jax authors
95535df13b
Merge pull request #25688 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
carlosgmartin
32411a430f
Add jax.random.multinomial.
2025-01-31 18:45:55 -05:00
Skye Wanderman-Milne
2aa810fe60
Make JAX_CPU_COLLECTIVES_IMPLEMENTATION
and JAX_NUM_CPU_DEVICES
env vars
...
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
Dan Foreman-Mackey
782138fb6f
Add custom_dce to changelogs and API docs.
2025-01-27 13:03:34 -05:00
Peter Hawkins
9fa2912254
Update version numbers after 0.5.0 release
2025-01-17 13:30:59 -05:00
Peter Hawkins
c25fb92c44
Release JAX 0.5.0
2025-01-17 10:28:03 -05:00
Peter Hawkins
3a8f31aa83
Update the JAX version to 0.5.0.
...
This is because of the breaking change to PRNG key semantics, and the version follows JAX's new effver versioning scheme (https://jax.readthedocs.io/en/latest/jep/25516-effver.html ).
2025-01-15 14:08:15 -05:00
Zac Mustin
2d72e8de84
Jax: Stop returning a list of cost-analyses.
...
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.
This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available )) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.
PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
Roy Frostig
a60ead6fd1
enable partitionable threefry by default
...
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Jake VanderPlas
051abafd6d
jnp.linalg.solve: finalize deprecation of batched 1D solves
2025-01-10 10:42:32 -08:00
George Necula
dd0447a7c6
[aot] Add support for as_text(debug_info=True).
...
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
392a851769
Increase the minimum SciPy version to 1.11.1.
...
(1.11.0 was yanked from PyPi because of licensing problems, so 1.11.1 is the oldest 1.11 release.)
PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Dan Foreman-Mackey
a7f384cc6e
Add a register_custom_type_id function to the GPU plugins.
...
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.
PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
56f0f9534d
Merge pull request #25633 from dfm:move-ffi
...
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd
Remove deprecated jax.experimental.array_api
2025-01-06 15:19:02 -08:00
Jake VanderPlas
2f7204fff6
jnp.einsum: default to optimize='auto'
2025-01-06 11:02:31 -08:00
Jake VanderPlas
245a13a329
Deprecate scipy.special.lpmn & lpmn_values
2025-01-06 09:31:15 -08:00
Dan Foreman-Mackey
cb4d97aa1f
Move jex.ffi to jax.ffi.
2024-12-29 13:06:19 +00:00
Jake VanderPlas
40fe4b8797
Finalize deprecation of some symbols from jax.lib.xla_client
2024-12-23 10:14:16 -08:00
Jake VanderPlas
c206ae7fe8
changelog: link to api compatibility & python version docs
2024-12-23 09:39:45 -08:00
Dan Foreman-Mackey
c6131ee527
Add support for N-D FFTs with D>3.
2024-12-19 15:23:30 +00:00
Jake VanderPlas
89a54a9e85
Re-land changes from https://github.com/jax-ml/jax/pull/25555
...
Reverts 25524abc67d82281e8a4093480637785c03a0150
PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
Yash Katariya
8b734808e8
Remove jax_enable_memories config flag. It defaulted to True for a very long time and it's time to remove the flag.
...
PiperOrigin-RevId: 707590263
2024-12-18 10:15:45 -08:00
Peter Hawkins
ee45718457
Increase the minimum NumPy version to v1.25.
...
Per SPEC 0, we drop NumPy v1.24 support on Dec 18, 2024.
2024-12-18 08:18:57 -05:00
jax authors
25524abc67
Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
...
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2
Remove core.concrete_aval and replace with abstractify
2024-12-17 18:18:25 -08:00
Peter Hawkins
ff52aedf67
Update version numbers after release.
2024-12-17 18:16:25 -05:00
Peter Hawkins
7de9eb20df
Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
...
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
George Necula
afcb62ea20
[export] Expand exporting to work with AbstractMesh.
...
This is a follow up from #25640 that enabled lowering with
AbstractMesh.
This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
Jake VanderPlas
c73f306099
Finalize deprecation of jnp.round_
...
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00