634 Commits

Author SHA1 Message Date
Peter Hawkins
47e6da3332 Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices.
The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes https://github.com/google/jax/issues/3589
Fixes https://github.com/google/jax/issues/15429

PiperOrigin-RevId: 653562611
2024-07-18 04:09:44 -07:00
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
Gleb Pobudzey
46103f6ff3 Updated the repr of GPU devices to be more consistent with TPUs/CPUs.
For example, `cuda(id=0)` will now be `CudaDevice(id=0)`

PiperOrigin-RevId: 651393690
2024-07-11 06:54:20 -07:00
Peter Hawkins
262a4f482c Deprecate support for custom lowering rules that return tuple-wrapped ir.Values.
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.

PiperOrigin-RevId: 650345051
2024-07-08 12:54:44 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Jake VanderPlas
fbcb157ad3 Finalize deprecation of several previously-deprecated jax.core functions:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
2024-06-28 07:28:28 -07:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
vfdev-5
70b4823348 Updated jnp.ceil/floor/trunc to preserve int dtypes
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
2024-06-25 20:26:53 +02:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Yash Katariya
045ea944c8 Finish jax and jaxlib 0.4.30 release
PiperOrigin-RevId: 644402635
2024-06-18 08:56:39 -07:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
9d5932a190 Deprecate passing of arrays in place of dtypes. 2024-06-14 05:40:04 -07:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
8f5f8df112 Merge pull request #21863 from jakevdp:dep-tracer-hash
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
jax authors
5aedafc214 Merge pull request #21847 from gnecula:export_deprecate
PiperOrigin-RevId: 643099957
2024-06-13 13:17:29 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Yash Katariya
023bc7856b Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
jax authors
27140fe6de Merge pull request #21772 from jakevdp:beta-dep
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release. 2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77 jax.scipy.special.beta: deprecate x,y in favor of a,b 2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5 Update release notes for 0.4.29 release. 2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Peter Hawkins
441ab58a58 Add note to release notes about #21403.
Fixes #21403
2024-05-24 10:09:13 -04:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0 2024-05-23 21:51:00 +01:00
Jake VanderPlas
568987af23 Finalize deprecation of batched keys to PRNG functions
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Jake VanderPlas
4bac10e750 Finalize deprecation of the config module.
To configure JAX, use `import jax` and reference the config object via `jax.config`.

PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Yue Sheng
66a92c41f6 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755 Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv 2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d Async dispatch expensive computations on the JAX CPU backend.
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226 Finish jax and jaxlib 0.4.28 release
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15 Prepare 0.4.28 release. 2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d CHANGELOG: move change from 0.4.27 to 0.4.28 2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00