609 Commits

Author SHA1 Message Date
jax authors
27140fe6de Merge pull request #21772 from jakevdp:beta-dep
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e Merge pull request #21781 from hawkinsp:release
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4 Update version numbers after v0.4.29 release. 2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77 jax.scipy.special.beta: deprecate x,y in favor of a,b 2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

In a future release of JAX, this behavior will become an error.

PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d [export] Remove old deprecated APIs for jax.experimental.export.
See CHANGELOG.md.
The deprecation period has passed.

Also replace deprecated .call_exported with .call in tests.

PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73 [export] Rename in_shardings and out_shardings fields.
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5 Update release notes for 0.4.29 release. 2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Peter Hawkins
441ab58a58 Add note to release notes about #21403.
Fixes #21403
2024-05-24 10:09:13 -04:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0 2024-05-23 21:51:00 +01:00
Jake VanderPlas
568987af23 Finalize deprecation of batched keys to PRNG functions
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Jake VanderPlas
4bac10e750 Finalize deprecation of the config module.
To configure JAX, use `import jax` and reference the config object via `jax.config`.

PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Yue Sheng
66a92c41f6 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755 Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv 2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d Async dispatch expensive computations on the JAX CPU backend.
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226 Finish jax and jaxlib 0.4.28 release
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15 Prepare 0.4.28 release. 2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d CHANGELOG: move change from 0.4.27 to 0.4.28 2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
9b79f6520a Remove deprecated kind argument from jnp.sort and jnp.argsort.
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Yash Katariya
70b4477296 Start jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
eced12d89b Finalize deprecation of lax.linalg positional args
PiperOrigin-RevId: 629581163
2024-04-30 17:56:18 -07:00
Jake VanderPlas
ba540ca735 Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
fad2c0e315 Merge pull request #20858 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Jake VanderPlas
cbe48cad1e Finalize deprecation of arr.device_buffer and arr.device_buffers
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
493698e6e0 Merge pull request #20195 from Micky774:array_api_astype
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070 Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
Meekail Zain
30cd3b88fd Add support for copy kwarg in astype to match Array API 2024-04-22 16:25:37 +00:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb Add a common flag for the collectives implementations on cpu. 2024-04-19 20:55:35 +02:00
Jake VanderPlas
41fa67c2dc Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3 Async dispatch expensive computations on the JAX CPU backend.
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.

Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.

PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Jake VanderPlas
462e5c603a Finalize deprecation of invalid JIT argument names & numbers
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.

PiperOrigin-RevId: 624270701
2024-04-12 13:09:17 -07:00
jax authors
4331abecff Merge pull request #20603 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 624221601
2024-04-12 10:30:01 -07:00
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Jake VanderPlas
e07325a672 Make complex_arr.astype(bool) follow NumPy's semantics 2024-04-09 16:15:59 -07:00