555 Commits

Author SHA1 Message Date
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Roy Frostig
f247822977 changelog: note doc change to use jax.random.key over PRNGKey 2024-04-04 16:38:08 -07:00
Roy Frostig
2a36d75285 changelog: batching rule change for rng_bit_generator 2024-04-04 16:34:10 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
George Necula
c0c918aa8b [export] Increase minimum serialization version to 9.
Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024.

This change could break clients that set a specific JAX serialization version lower than 9.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 619588685
2024-03-27 11:06:23 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
George Necula
ca59971bef [host_callback] Deprecate the jax.experimental.host_callback module. 2024-03-21 09:11:17 +02:00
Yue Sheng
147c363ea6 Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
154403c03d Finalize deprecations of jax.interpreters.ad config & source_info_util
These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError.

PiperOrigin-RevId: 616931120
2024-03-18 13:33:17 -07:00
Peter Hawkins
ee2631e4da Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
2024-03-18 11:46:06 -07:00
rajasekharporeddy
e94299c946
Fix Typos in CHANGELOG.md
This PR fixes the typos in Change log documentation
2024-03-12 13:57:07 +05:30
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Jake VanderPlas
c2d07a6623 Finalize deprecation of non-array arguments to array_equal/array_equiv 2024-02-29 05:31:37 -08:00
Jake VanderPlas
236275ebe1 Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3

PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
e0fd29082d Finish jax and jaxlib 0.4.25 release
PiperOrigin-RevId: 610413312
2024-02-26 08:19:05 -08:00
Yash Katariya
f4045dceb2 Remove the deprecation of jax.tree_map for the release of 0.4.25
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
jax authors
be002b5f1c Merge pull request #19930 from jakevdp:dep-tree_map
PiperOrigin-RevId: 609508069
2024-02-22 15:01:35 -08:00
Jake VanderPlas
a5abe4568d Mention re-instated xla APIs in the CHANGELOG 2024-02-22 12:19:29 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Sergei Lebedev
0bf8dddace Compile Triton kernels via XLA by default
PiperOrigin-RevId: 609299269
2024-02-22 02:32:26 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays 2024-02-16 11:26:30 -08:00
Jake VanderPlas
6ffea0ba1f tree_transpose: optionally infer inner_treedef 2024-02-15 12:01:21 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
1b08d220f6 Fix jax 0.4.24 changelog 2024-02-09 09:51:41 -08:00
Yash Katariya
73e3dedf9d Update changelog and install doc to mention cuda 12.3 switch
PiperOrigin-RevId: 605473026
2024-02-08 17:21:56 -08:00
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
jax authors
136ab066b3 Merge pull request #19681 from skye:version
PiperOrigin-RevId: 604809179
2024-02-06 17:04:53 -08:00
Skye Wanderman-Milne
b93772fc24 Update version numbers post-0.4.24 release 2024-02-06 16:28:23 -08:00
Jake VanderPlas
35c0f64836 jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 2024-02-06 11:37:42 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
George Necula
af306680d0 [shape_poly] Forgot to update CHANGELOG for #19235. 2024-01-23 17:32:59 +01:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
jax authors
aac996c4db Merge pull request #19390 from jakevdp:jnp-sign
PiperOrigin-RevId: 599203136
2024-01-17 09:48:17 -08:00
Peter Hawkins
c4368351d2 Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
2024-01-17 09:30:42 -08:00
Jake VanderPlas
fb56224ae0 jnp.sign: use x/abs(x) for complex arguments 2024-01-17 08:59:40 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00