`jax.experimental.host_callback` has been deprecated since March 2024
(JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.
If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.
See https://github.com/google/jax/issues/20385 for a discussion.
PiperOrigin-RevId: 681004255
Fix a type error that arises when we try to run the host callback tests with JAX_HOST_CALLBACK_LEGACY=False (in the process of deprecating jax.experimental.host_callback).
PiperOrigin-RevId: 671825020
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.
See issue #20385 for more details.
We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117
Currently JAX callbacks on TPU raise errors when the called function takes empty arguments or returns empty results. It seems that the send_to_host function works
even with empty arrays, but recv_from_host crashes (crash log below).
Here we work around this issue, by ensuring that only the non-empty results of the Python callback are sent to the device computation and the empty results are replaced with empty constants in the device computation.
This is part of the work to replace uses of host_callback with io_callback.
PiperOrigin-RevId: 559061336
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.
This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.
PiperOrigin-RevId: 557520668
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
Support for calling the xla_client compiler() API with classic HLO input is being removed, but we should just use the public JAX API anyway.
PiperOrigin-RevId: 496655326
Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.