575 Commits

Author SHA1 Message Date
jax authors
493698e6e0 Merge pull request #20195 from Micky774:array_api_astype
PiperOrigin-RevId: 627232885
2024-04-22 19:30:51 -07:00
jax authors
d20a2f1070 Merge pull request #20317 from inailuig:mpi_collectives
PiperOrigin-RevId: 627208382
2024-04-22 17:41:44 -07:00
Meekail Zain
30cd3b88fd Add support for copy kwarg in astype to match Array API 2024-04-22 16:25:37 +00:00
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Clemens Giuliani
60d4c43fcb Add a common flag for the collectives implementations on cpu. 2024-04-19 20:55:35 +02:00
Jake VanderPlas
41fa67c2dc Finalize deprecation of zero-dimensional inputs to jnp.nonzero
PiperOrigin-RevId: 626299531
2024-04-19 02:19:10 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Meekail Zain
ceeb975735 Add new cumulative_sum function to numpy and array_api 2024-04-16 19:57:55 +00:00
Meekail Zain
6bdc83c680 Add new unstack function to numpy/array_api namespaces 2024-04-15 21:03:26 +00:00
Yue Sheng
64775d02a3 Async dispatch expensive computations on the JAX CPU backend.
Before the change, on CPU backend we always run computations inline unless there are multiple CPU devices with potential collectives. Now, we will use `HloCostAnalysis` to estimate the cost of the computation and do async dispatch if it is expensive.

Add a JAX flag for users to opt-out by adding `jax.config.update('jax_cpu_enable_async_dispatch', False)` in their programs.

PiperOrigin-RevId: 625064815
2024-04-15 13:29:44 -07:00
Meekail Zain
2899213efb Fixed hypot bug on nan/inf pairings, began deprecation of non-real values 2024-04-15 17:56:16 +00:00
Sergei Lebedev
754fab91f7 Bumped the minimum jaxlib to 0.4.23
jaxlib 0.4.23 has xla_extension_version 223 and mlir_api_version 54.
2024-04-13 08:18:33 +01:00
Jake VanderPlas
462e5c603a Finalize deprecation of invalid JIT argument names & numbers
Invalid static_argnames/static_argnums have been resulting in a warning since JAX v0.3.17, released in June 2022. After this change, they will result in an error.

PiperOrigin-RevId: 624270701
2024-04-12 13:09:17 -07:00
jax authors
4331abecff Merge pull request #20603 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 624221601
2024-04-12 10:30:01 -07:00
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Jake VanderPlas
e07325a672 Make complex_arr.astype(bool) follow NumPy's semantics 2024-04-09 16:15:59 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Sergei Lebedev
9616900cc9 jax.pure_callback and jax.experimental.io_callback now use jax.Arrays
The motivation for this change is two-fold

* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
  of always copying it to the host. Note that the version here still always
  copies to the host.

If this change breaks you, you can recover the old behavior by changing

    jax.pure_callback(
        f,
        result_shape_dtypes,
        *args,
        **kwargs,
    )

to

    jax.pure_callback(
        lambda *args: f(*jax.tree.map(np.asarray, args)),
        result_shape_dtypes,
        *args,
        **kwargs,
    )

so that the callback function is called with NumPy arrays as before.

I will update the "External callbacks" tutorial in a follow up.

PiperOrigin-RevId: 622457378
2024-04-06 09:30:08 -07:00
rajasekharporeddy
0d68a1a82d Fix doc typos 2024-04-05 14:21:33 +05:30
George Necula
a510f03ef8 [callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.

See issue #20385 for more details.

We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
2024-04-05 08:51:30 +01:00
jax authors
2512843a56 Merge pull request #20550 from Micky774:api_clip
PiperOrigin-RevId: 622045823
2024-04-04 19:58:06 -07:00
Roy Frostig
f247822977 changelog: note doc change to use jax.random.key over PRNGKey 2024-04-04 16:38:08 -07:00
Roy Frostig
2a36d75285 changelog: batching rule change for rng_bit_generator 2024-04-04 16:34:10 -07:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
Sergei Lebedev
498e81ab10 Pallas now exclusively uses XLA for compiling kernels on GPU
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.

PiperOrigin-RevId: 621857046
2024-04-04 07:47:26 -07:00
Yash Katariya
24517ca3e0 Finish jax and jaxlib 0.4.26 release
PiperOrigin-RevId: 621658207
2024-04-03 15:40:24 -07:00
Peter Hawkins
61493263a9 Prepare for 0.4.26 release. 2024-04-03 14:38:58 -04:00
Jake VanderPlas
fd7c85b349 jnp.geomspace: make complex behavior consistent with NumPy 2.0 2024-04-02 16:12:49 -07:00
Jake VanderPlas
9e01afe7af Add jax.numpy.trapezoid
This function has been added to NumPy in version 2.0, as a replacement
for the already deprecated trapz function.
2024-04-01 13:05:20 -07:00
George Necula
c0c918aa8b [export] Increase minimum serialization version to 9.
Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024.

This change could break clients that set a specific JAX serialization version lower than 9.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

PiperOrigin-RevId: 619588685
2024-03-27 11:06:23 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
George Necula
ca59971bef [host_callback] Deprecate the jax.experimental.host_callback module. 2024-03-21 09:11:17 +02:00
Yue Sheng
147c363ea6 Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
154403c03d Finalize deprecations of jax.interpreters.ad config & source_info_util
These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError.

PiperOrigin-RevId: 616931120
2024-03-18 13:33:17 -07:00
Peter Hawkins
ee2631e4da Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
2024-03-18 11:46:06 -07:00
rajasekharporeddy
e94299c946
Fix Typos in CHANGELOG.md
This PR fixes the typos in Change log documentation
2024-03-12 13:57:07 +05:30
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Jake VanderPlas
c2d07a6623 Finalize deprecation of non-array arguments to array_equal/array_equiv 2024-02-29 05:31:37 -08:00
Jake VanderPlas
236275ebe1 Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3

PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
e0fd29082d Finish jax and jaxlib 0.4.25 release
PiperOrigin-RevId: 610413312
2024-02-26 08:19:05 -08:00
Yash Katariya
f4045dceb2 Remove the deprecation of jax.tree_map for the release of 0.4.25
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
jax authors
be002b5f1c Merge pull request #19930 from jakevdp:dep-tree_map
PiperOrigin-RevId: 609508069
2024-02-22 15:01:35 -08:00
Jake VanderPlas
a5abe4568d Mention re-instated xla APIs in the CHANGELOG 2024-02-22 12:19:29 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Sergei Lebedev
0bf8dddace Compile Triton kernels via XLA by default
PiperOrigin-RevId: 609299269
2024-02-22 02:32:26 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00