595 Commits

Author SHA1 Message Date
Yue Sheng
66a92c41f6 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678
PiperOrigin-RevId: 633816901
2024-05-14 22:41:44 -07:00
Meekail Zain
5cc255b755 Rename rcond/tol to rtol in linalg.matrix_rank and linalg.pinv 2024-05-14 19:53:54 +00:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
9e7830df2d Async dispatch expensive computations on the JAX CPU backend.
By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.

PiperOrigin-RevId: 633264117
2024-05-13 10:53:09 -07:00
Jake VanderPlas
9ac1d38226 Finish jax and jaxlib 0.4.28 release
PiperOrigin-RevId: 632653310
2024-05-10 18:06:52 -07:00
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
Peter Hawkins
038dfeec15 Prepare 0.4.28 release. 2024-05-09 19:25:33 +00:00
Peter Hawkins
168f40ee3d [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier.
This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11.

PiperOrigin-RevId: 631984256
2024-05-08 18:05:28 -07:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Jake VanderPlas
c18851b65d CHANGELOG: move change from 0.4.27 to 0.4.28 2024-05-07 11:16:11 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
9b79f6520a Remove deprecated kind argument from jnp.sort and jnp.argsort.
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
Yash Katariya
70b4477296 Start jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Roy Frostig
3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00
Jake VanderPlas
eced12d89b Finalize deprecation of lax.linalg positional args
PiperOrigin-RevId: 629581163
2024-04-30 17:56:18 -07:00
Jake VanderPlas
ba540ca735 Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
fad2c0e315 Merge pull request #20858 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 628061707
2024-04-25 06:58:27 -07:00
Jake VanderPlas
cbe48cad1e Finalize deprecation of arr.device_buffer and arr.device_buffers
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
jax authors
493698e6e0 Merge pull request #20195 from Micky774:array_api_astype
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070 Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
Meekail Zain
30cd3b88fd Add support for copy kwarg in astype to match Array API 2024-04-22 16:25:37 +00:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb Add a common flag for the collectives implementations on cpu. 2024-04-19 20:55:35 +02:00
Jake VanderPlas
41fa67c2dc Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3 Async dispatch expensive computations on the JAX CPU backend.
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.

Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.

PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Jake VanderPlas
462e5c603a Finalize deprecation of invalid JIT argument names & numbers
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.

PiperOrigin-RevId: 624270701
2024-04-12 13:09:17 -07:00
jax authors
4331abecff Merge pull request #20603 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 624221601
2024-04-12 10:30:01 -07:00
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Jake VanderPlas
e07325a672 Make complex_arr.astype(bool) follow NumPy's semantics 2024-04-09 16:15:59 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Sergei Lebedev
9616900cc9 jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold

* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
  of always copying it to the host. Note that the version here still always
  copies to the host.

If this change breaks you, you can recover the old behavior by changing

    jax.pure_callback(
        f,
        result_shape_dtypes,
        *args,
        **kwargs,
    )

to

    jax.pure_callback(
        lambda *args: f(*jax.tree.map(np.asarray, args)),
        result_shape_dtypes,
        *args,
        **kwargs,
    )

so that the callback function is called with NumPy arrays as before.

I will update the "External callbacks" tutorial in a follow up.

PiperOrigin-RevId: 622457378
2024-04-06 09:30:08 -07:00
rajasekharporeddy
0d68a1a82d Fix doc typos 2024-04-05 14:21:33 +05:30
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Roy Frostig
f247822977 changelog: note doc change to use jax.random.key over PRNGKey 2024-04-04 16:38:08 -07:00
Roy Frostig
2a36d75285 changelog: batching rule change for rng_bit_generator 2024-04-04 16:34:10 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
George Necula
c0c918aa8b [export] Increase minimum serialization version to 9.
Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024.

This change could break clients that set a specific JAX serialization version lower than 9.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 619588685
2024-03-27 11:06:23 -07:00