649 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Yue Sheng
f255fb700a Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Jake VanderPlas
3551fcc077 Deprecate several APIs in jax.lib.xla_bridge
PiperOrigin-RevId: 658274719
2024-07-31 23:00:35 -07:00
Jake VanderPlas
8bcd288621 Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.

PiperOrigin-RevId: 657717875
2024-07-30 13:49:37 -07:00
jax authors
256956ad58 Merge pull request #22704 from gnecula:pallas_better_errors
PiperOrigin-RevId: 657571604
2024-07-30 06:39:44 -07:00
Peter Hawkins
c1cd7f9e2d Drop support for mhlo in JAX's public API.
PiperOrigin-RevId: 657551590
2024-07-30 05:29:52 -07:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
2106a25977 Finish jax and jaxlib v0.4.31 release
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Peter Hawkins
fd23b8733d Bump minimum SciPy version to 1.10.
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago

PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
Peter Hawkins
47e6da3332 Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices.
The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes https://github.com/google/jax/issues/3589
Fixes https://github.com/google/jax/issues/15429

PiperOrigin-RevId: 653562611
2024-07-18 04:09:44 -07:00
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
Gleb Pobudzey
46103f6ff3 Updated the repr of GPU devices to be more consistent with TPUs/CPUs.
For example, `cuda(id=0)` will now be `CudaDevice(id=0)`

PiperOrigin-RevId: 651393690
2024-07-11 06:54:20 -07:00
Peter Hawkins
262a4f482c Deprecate support for custom lowering rules that return tuple-wrapped ir.Values.
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.

PiperOrigin-RevId: 650345051
2024-07-08 12:54:44 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Jake VanderPlas
fbcb157ad3 Finalize deprecation of several previously-deprecated jax.core functions:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
2024-06-28 07:28:28 -07:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
vfdev-5
70b4823348 Updated jnp.ceil/floor/trunc to preserve int dtypes
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
2024-06-25 20:26:53 +02:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Yash Katariya
045ea944c8 Finish jax and jaxlib 0.4.30 release
PiperOrigin-RevId: 644402635
2024-06-18 08:56:39 -07:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
9d5932a190 Deprecate passing of arrays in place of dtypes. 2024-06-14 05:40:04 -07:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
8f5f8df112 Merge pull request #21863 from jakevdp:dep-tracer-hash
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
jax authors
5aedafc214 Merge pull request #21847 from gnecula:export_deprecate
PiperOrigin-RevId: 643099957
2024-06-13 13:17:29 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Yash Katariya
023bc7856b Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
jax authors
27140fe6de Merge pull request #21772 from jakevdp:beta-dep
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release. 2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77 jax.scipy.special.beta: deprecate x,y in favor of a,b 2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5 Update release notes for 0.4.29 release. 2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00