762 Commits

Author SHA1 Message Date
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Skye Wanderman-Milne
f07243a73a Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
This enables CPU collectives by default, making multi-process CPU
communication work without extra configuration.

PiperOrigin-RevId: 724076284
2025-02-06 14:30:36 -08:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
Peter Hawkins
b1a2c27aa0 Remove libtpu-nightly dependency from jax[tpu].
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].
2025-02-04 20:59:30 -05:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
Skye Wanderman-Milne
2aa810fe60 Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
Dan Foreman-Mackey
782138fb6f Add custom_dce to changelogs and API docs. 2025-01-27 13:03:34 -05:00
Peter Hawkins
9fa2912254 Update version numbers after 0.5.0 release 2025-01-17 13:30:59 -05:00
Peter Hawkins
c25fb92c44 Release JAX 0.5.0 2025-01-17 10:28:03 -05:00
Peter Hawkins
3a8f31aa83 Update the JAX version to 0.5.0.
This is because of the breaking change to PRNG key semantics, and the version follows JAX's new effver versioning scheme (https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
2025-01-15 14:08:15 -05:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
392a851769 Increase the minimum SciPy version to 1.11.1.
(1.11.0 was yanked from PyPi because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Jake VanderPlas
40fe4b8797 Finalize deprecation of some symbols from jax.lib.xla_client 2024-12-23 10:14:16 -08:00
Jake VanderPlas
c206ae7fe8 changelog: link to api compatibility & python version docs 2024-12-23 09:39:45 -08:00
Dan Foreman-Mackey
c6131ee527 Add support for N-D FFTs with D>3. 2024-12-19 15:23:30 +00:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
Yash Katariya
8b734808e8 Remove jax_enable_memories config flag. It defaulted to True for a very long time and it's time to remove the flag.
PiperOrigin-RevId: 707590263
2024-12-18 10:15:45 -08:00
Peter Hawkins
ee45718457 Increase the minimum NumPy version to v1.25.
Per SPEC 0, we drop NumPy v1.24 support on Dec 18, 2024.
2024-12-18 08:18:57 -05:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Peter Hawkins
ff52aedf67 Update version numbers after release. 2024-12-17 18:16:25 -05:00
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
Jake VanderPlas
c73f306099 Finalize deprecation of jnp.round_
PiperOrigin-RevId: 705998500
2024-12-13 14:13:44 -08:00
jax authors
87b66f3c35 Merge pull request #25451 from jakevdp:undep-core
PiperOrigin-RevId: 705910242
2024-12-13 09:36:32 -08:00
Ivy Zheng
26c40fadfd Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
PiperOrigin-RevId: 705645570
2024-12-12 15:13:32 -08:00
Jake VanderPlas
d3406768f0 temporarily un-deprecate several jax.core APIs.
These were causing excessive log-spam for some users; I'll work to migrate
them to jax.extend before re-deprecating these.
2024-12-12 13:15:58 -08:00
Jake VanderPlas
f858a71461 Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client. 2024-12-11 09:50:33 -08:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
Peter Hawkins
820f51dc53 Merge branch 'release/0.4.37' into main. 2024-12-09 20:21:43 -05:00
Peter Hawkins
ffb07cdadb Update versions for v0.4.37 release. 2024-12-09 15:39:59 -05:00
IvyZX
65b6088411 Avoid index out of range error in carry structure check 2024-12-09 15:36:32 -05:00
IvyZX
bd77a703fd Avoid index out of range error in carry structure check 2024-12-09 10:44:28 -08:00
Peter Hawkins
ba626fa650 Bump JAX version after release.
PiperOrigin-RevId: 703472753
2024-12-06 05:58:09 -08:00
George Necula
3f5f3e1c47 [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
2024-12-05 08:42:41 -08:00
George Necula
5fe5206b6a [shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
2024-12-05 08:02:38 -08:00
labs-code-app[bot]
762301fc5d Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
2024-11-26 13:57:47 +00:00
Nitin Srinivasan
6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions.

There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00
Peter Hawkins
dfe27a1682 Mention stackless in the release notes. 2024-11-20 14:53:52 -05:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00