jax authors
8f5f8df112
Merge pull request #21863 from jakevdp:dep-tracer-hash
...
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
0bcc81ceb3
Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
...
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
jax authors
5aedafc214
Merge pull request #21847 from gnecula:export_deprecate
...
PiperOrigin-RevId: 643099957
2024-06-13 13:17:29 -07:00
Jake VanderPlas
0a86e9a929
Deprecate hashing of tracers
2024-06-13 13:14:27 -07:00
Yash Katariya
023bc7856b
Add registration handler for TPU v5e in mesh_utils.
...
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
George Necula
7af03a8fd1
[export] Deprecate jax.experimental.export
...
And announce jax.export.
While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
f63b94574a
Deprecate internal pretty-printing APIs, jax.core.pp_*
2024-06-13 09:44:56 -07:00
Peter Hawkins
b13733c13f
Update JAX dependencies, extras, and documentation for plugins.
...
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Yash Katariya
b1f7627c71
[Rollback] Bumped the minimum ml_dtypes version to 0.4.0
...
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b
PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
jax authors
27140fe6de
Merge pull request #21772 from jakevdp:beta-dep
...
PiperOrigin-RevId: 642275316
2024-06-11 08:15:58 -07:00
jax authors
27de85439e
Merge pull request #21781 from hawkinsp:release
...
PiperOrigin-RevId: 641994356
2024-06-10 12:56:31 -07:00
Peter Hawkins
6fa31e59c4
Update version numbers after v0.4.29 release.
2024-06-10 14:37:53 -04:00
Jake VanderPlas
814b32a44b
tree_all: add support for is_leaf
2024-06-10 09:46:15 -07:00
Jake VanderPlas
990b475b77
jax.scipy.special.beta: deprecate x,y in favor of a,b
2024-06-10 09:01:39 -07:00
Peter Hawkins
a8246ea67f
Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
...
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.
In a future release of JAX, this behavior will become an error.
PiperOrigin-RevId: 641690427
2024-06-09 09:18:29 -07:00
George Necula
3914cb415d
[export] Remove old deprecated APIs for jax.experimental.export.
...
See CHANGELOG.md.
The deprecation period has passed.
Also replace deprecated .call_exported with .call in tests.
PiperOrigin-RevId: 641236222
2024-06-07 06:52:10 -07:00
George Necula
01ee768f73
[export] Rename in_shardings and out_shardings fields.
...
We rename `in_shardings` to `in_shardings_hlo` to remove confusion
with JAX's use of `in_shardings`.
We also rename `xla_compatible_in_sharding` to `in_shardings_jax`
since we do not have a XLACompatibleSharding type anymore.
2024-06-06 22:00:16 +01:00
Peter Hawkins
09448384e5
Update release notes for 0.4.29 release.
2024-06-06 11:13:14 -04:00
Yash Katariya
1edd649de4
Deprecate XLACompatibleSharding
in favor of jax.sharding.Sharding
.
...
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Peter Hawkins
441ab58a58
Add note to release notes about #21403 .
...
Fixes #21403
2024-05-24 10:09:13 -04:00
Sergei Lebedev
0a694a1b42
Bumped the minimum ml_dtypes version to 0.4.0
2024-05-23 21:51:00 +01:00
Jake VanderPlas
568987af23
Finalize deprecation of batched keys to PRNG functions
...
PiperOrigin-RevId: 636196573
2024-05-22 09:40:32 -07:00
Jake VanderPlas
4bac10e750
Finalize deprecation of the config module.
...
To configure JAX, use `import jax` and reference the config object via `jax.config`.
PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Yue Sheng
66a92c41f6
Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
...
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755
Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv
2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09
Finalize deprecations of several APIs
...
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d
Async dispatch expensive computations on the JAX CPU backend.
...
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226
Finish jax and jaxlib 0.4.28 release
...
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69
Deprecate newshape argument of jnp.reshape
2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15
Prepare 0.4.28 release.
2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d
[XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
...
This works around https://github.com/python/cpython/issues/89478 , which was fixed in Python 3.11.
PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0
Removed get_compute_capability from jax.experimental.pallas.gpu
...
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d
CHANGELOG: move change from 0.4.27 to 0.4.28
2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4
Finish jax and jaxlib 0.4.27 release
...
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
9b79f6520a
Remove deprecated kind
argument from jnp.sort
and jnp.argsort
.
...
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Yash Katariya
70b4477296
Start jax and jaxlib 0.4.27 release
...
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Jake VanderPlas
e95173a4d3
Require arraylike input for several jax.numpy functions
...
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Roy Frostig
3f9540761e
reintroduce the Threefry GPU kernel lowering, under a flag
...
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:
`jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
eced12d89b
Finalize deprecation of lax.linalg positional args
...
PiperOrigin-RevId: 629581163
2024-04-30 17:56:18 -07:00
Jake VanderPlas
ba540ca735
Finalize deprecation of jnp.where keyword arguments
...
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
fad2c0e315
Merge pull request #20858 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Jake VanderPlas
cbe48cad1e
Finalize deprecation of arr.device_buffer
and arr.device_buffers
...
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
493698e6e0
Merge pull request #20195 from Micky774:array_api_astype
...
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070
Merge pull request #20317 from inailuig:mpi_collectives
...
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
Meekail Zain
30cd3b88fd
Add support for copy kwarg in astype to match Array API
2024-04-22 16:25:37 +00:00
rajasekharporeddy
aaddba0c20
Fix doc Typos
2024-04-22 10:32:51 +05:30
Sergei Lebedev
6e23c14f85
jax.debug.callback now passes arguments as jax.Arrays
...
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes #20627 .
PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb
Add a common flag for the collectives implementations on cpu.
2024-04-19 20:55:35 +02:00
Jake VanderPlas
41fa67c2dc
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
...
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yue Sheng
c2d4373535
Make core.Token
a non-trivial class which wraps a jax.Array
. Currently, we use a singleton and empty core.token
object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
...
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).
PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00