551 Commits

Author SHA1 Message Date
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
George Necula
c0c918aa8b [export] Increase minimum serialization version to 9.
Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024.

This change could break clients that set a specific JAX serialization version lower than 9.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 619588685
2024-03-27 11:06:23 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
George Necula
ca59971bef [host_callback] Deprecate the jax.experimental.host_callback module. 2024-03-21 09:11:17 +02:00
Yue Sheng
147c363ea6 Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
154403c03d Finalize deprecations of jax.interpreters.ad config & source_info_util
These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError.

PiperOrigin-RevId: 616931120
2024-03-18 13:33:17 -07:00
Peter Hawkins
ee2631e4da Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
2024-03-18 11:46:06 -07:00
rajasekharporeddy
e94299c946
Fix Typos in CHANGELOG.md
This PR fixes the typos in Change log documentation
2024-03-12 13:57:07 +05:30
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Jake VanderPlas
c2d07a6623 Finalize deprecation of non-array arguments to array_equal/array_equiv 2024-02-29 05:31:37 -08:00
Jake VanderPlas
236275ebe1 Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3

PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
e0fd29082d Finish jax and jaxlib 0.4.25 release
PiperOrigin-RevId: 610413312
2024-02-26 08:19:05 -08:00
Yash Katariya
f4045dceb2 Remove the deprecation of jax.tree_map for the release of 0.4.25
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
jax authors
be002b5f1c Merge pull request #19930 from jakevdp:dep-tree_map
PiperOrigin-RevId: 609508069
2024-02-22 15:01:35 -08:00
Jake VanderPlas
a5abe4568d Mention re-instated xla APIs in the CHANGELOG 2024-02-22 12:19:29 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Sergei Lebedev
0bf8dddace Compile Triton kernels via XLA by default
PiperOrigin-RevId: 609299269
2024-02-22 02:32:26 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays 2024-02-16 11:26:30 -08:00
Jake VanderPlas
6ffea0ba1f tree_transpose: optionally infer inner_treedef 2024-02-15 12:01:21 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
1b08d220f6 Fix jax 0.4.24 changelog 2024-02-09 09:51:41 -08:00
Yash Katariya
73e3dedf9d Update changelog and install doc to mention cuda 12.3 switch
PiperOrigin-RevId: 605473026
2024-02-08 17:21:56 -08:00
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
jax authors
136ab066b3 Merge pull request #19681 from skye:version
PiperOrigin-RevId: 604809179
2024-02-06 17:04:53 -08:00
Skye Wanderman-Milne
b93772fc24 Update version numbers post-0.4.24 release 2024-02-06 16:28:23 -08:00
Jake VanderPlas
35c0f64836 jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 2024-02-06 11:37:42 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
George Necula
af306680d0 [shape_poly] Forgot to update CHANGELOG for #19235. 2024-01-23 17:32:59 +01:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
jax authors
aac996c4db Merge pull request #19390 from jakevdp:jnp-sign
PiperOrigin-RevId: 599203136
2024-01-17 09:48:17 -08:00
Peter Hawkins
c4368351d2 Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
2024-01-17 09:30:42 -08:00
Jake VanderPlas
fb56224ae0 jnp.sign: use x/abs(x) for complex arguments 2024-01-17 08:59:40 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00
Jake VanderPlas
fa6d3f26ff jnp.unique: make return_inverse shape match NumPy 2.0 2024-01-16 11:47:45 -08:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00