624 Commits

Author SHA1 Message Date
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Yash Katariya
045ea944c8 Finish jax and jaxlib 0.4.30 release
PiperOrigin-RevId: 644402635
2024-06-18 08:56:39 -07:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
9d5932a190 Deprecate passing of arrays in place of dtypes. 2024-06-14 05:40:04 -07:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
8f5f8df112 Merge pull request #21863 from jakevdp:dep-tracer-hash
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
jax authors
5aedafc214 Merge pull request #21847 from gnecula:export_deprecate
PiperOrigin-RevId: 643099957
2024-06-13 13:17:29 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Yash Katariya
023bc7856b Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
jax authors
27140fe6de Merge pull request #21772 from jakevdp:beta-dep
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release. 2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77 jax.scipy.special.beta: deprecate x,y in favor of a,b 2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5 Update release notes for 0.4.29 release. 2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Peter Hawkins
441ab58a58 Add note to release notes about #21403.
Fixes #21403
2024-05-24 10:09:13 -04:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0 2024-05-23 21:51:00 +01:00
Jake VanderPlas
568987af23 Finalize deprecation of batched keys to PRNG functions
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Jake VanderPlas
4bac10e750 Finalize deprecation of the config module.
To configure JAX, use `import jax` and reference the config object via `jax.config`.

PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Yue Sheng
66a92c41f6 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755 Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv 2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d Async dispatch expensive computations on the JAX CPU backend.
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226 Finish jax and jaxlib 0.4.28 release
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15 Prepare 0.4.28 release. 2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d CHANGELOG: move change from 0.4.27 to 0.4.28 2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
9b79f6520a Remove deprecated kind argument from jnp.sort and jnp.argsort.
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Yash Katariya
70b4477296 Start jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
eced12d89b Finalize deprecation of lax.linalg positional args
PiperOrigin-RevId: 629581163
2024-04-30 17:56:18 -07:00
Jake VanderPlas
ba540ca735 Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
fad2c0e315 Merge pull request #20858 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Jake VanderPlas
cbe48cad1e Finalize deprecation of arr.device_buffer and arr.device_buffers
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
493698e6e0 Merge pull request #20195 from Micky774:array_api_astype
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070 Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00