jax authors
27140fe6de
Merge pull request #21772 from jakevdp:beta-dep
...
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e
Merge pull request #21781 from hawkinsp:release
...
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4
Update version numbers after v0.4.29 release.
2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b
tree_all: add support for is_leaf
2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77
jax.scipy.special.beta: deprecate x,y in favor of a,b
2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f
Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
...
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.
In a future release of JAX, this behavior will become an error.
PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d
[export] Remove old deprecated APIs for jax.experimental.export.
...
See CHANGELOG.md.
The deprecation period has passed.
Also replace deprecated .call_exported with .call in tests.
PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73
[export] Rename in_shardings and out_shardings fields.
...
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5
Update release notes for 0.4.29 release.
2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4
Deprecate XLACompatibleSharding
in favor of jax.sharding.Sharding
.
...
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Peter Hawkins
441ab58a58
Add note to release notes about #21403 .
...
Fixes #21403
2024-05-24 10:09:13 -04:00
Sergei Lebedev
0a694a1b42
Bumped the minimum ml_dtypes version to 0.4.0
2024-05-23 21:51:00 +01:00
Jake VanderPlas
568987af23
Finalize deprecation of batched keys to PRNG functions
...
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Jake VanderPlas
4bac10e750
Finalize deprecation of the config module.
...
To configure JAX, use `import jax` and reference the config object via `jax.config`.
PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Yue Sheng
66a92c41f6
Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
...
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755
Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv
2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09
Finalize deprecations of several APIs
...
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d
Async dispatch expensive computations on the JAX CPU backend.
...
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226
Finish jax and jaxlib 0.4.28 release
...
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69
Deprecate newshape argument of jnp.reshape
2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15
Prepare 0.4.28 release.
2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d
[XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
...
This works around https://github.com/python/cpython/issues/89478 , which was fixed in Python 3.11.
PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0
Removed get_compute_capability from jax.experimental.pallas.gpu
...
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d
CHANGELOG: move change from 0.4.27 to 0.4.28
2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4
Finish jax and jaxlib 0.4.27 release
...
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
9b79f6520a
Remove deprecated kind
argument from jnp.sort
and jnp.argsort
.
...
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Yash Katariya
70b4477296
Start jax and jaxlib 0.4.27 release
...
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Jake VanderPlas
e95173a4d3
Require arraylike input for several jax.numpy functions
...
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Roy Frostig
3f9540761e
reintroduce the Threefry GPU kernel lowering, under a flag
...
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
eced12d89b
Finalize deprecation of lax.linalg positional args
...
PiperOrigin-RevId: 629581163
2024-04-30 17:56:18 -07:00
Jake VanderPlas
ba540ca735
Finalize deprecation of jnp.where keyword arguments
...
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
fad2c0e315
Merge pull request #20858 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Jake VanderPlas
cbe48cad1e
Finalize deprecation of arr.device_buffer
and arr.device_buffers
...
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
493698e6e0
Merge pull request #20195 from Micky774:array_api_astype
...
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070
Merge pull request #20317 from inailuig:mpi_collectives
...
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
Meekail Zain
30cd3b88fd
Add support for copy kwarg in astype to match Array API
2024-04-22 16:25:37 +00:00
rajasekharporeddy
aaddba0c20
Fix doc Typos
2024-04-22 10:32:51 +05:30
Sergei Lebedev
6e23c14f85
jax.debug.callback now passes arguments as jax.Arrays
...
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes #20627 .
PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb
Add a common flag for the collectives implementations on cpu.
2024-04-19 20:55:35 +02:00
Jake VanderPlas
41fa67c2dc
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
...
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yue Sheng
c2d4373535
Make core.Token
a non-trivial class which wraps a jax.Array
. Currently, we use a singleton and empty core.token
object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
...
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).
PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Meekail Zain
ceeb975735
Add new cumulative_sum function to numpy and array_api
2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680
Add new unstack function to numpy/array_api namespaces
2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3
Async dispatch expensive computations on the JAX CPU backend.
...
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.
Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.
PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Meekail Zain
2899213efb
Fixed hypot bug on nan/inf pairings, began deprecation of non-real values
2024-04-15 17:56:16 +00:00
Sergei Lebedev
754fab91f7
Bumped the minimum jaxlib to 0.4.23
...
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Jake VanderPlas
462e5c603a
Finalize deprecation of invalid JIT argument names & numbers
...
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.
PiperOrigin-RevId: 624270701
2024-04-12 13:09:17 -07:00
jax authors
4331abecff
Merge pull request #20603 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 624221601
2024-04-12 10:30:01 -07:00
Jake VanderPlas
1ea205be1c
softmax: deprecate initial argument & always set to -inf internally
2024-04-10 10:23:21 -07:00
Jake VanderPlas
e07325a672
Make complex_arr.astype(bool) follow NumPy's semantics
2024-04-09 16:15:59 -07:00