The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.
See issue #20385 for more details.
We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
It is not well-defined, and the implementation is currently untested. Both Python and NumPy raise an error when attempting complex-valued floor division.
PiperOrigin-RevId: 621918604
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.
PiperOrigin-RevId: 621857046
Also add a method for the plugin to return an xla_version plugin attribute.
Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins.
PiperOrigin-RevId: 621672421
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
Previously, prior to #20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:
```
def f_host():
return 42.
io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```
However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.
In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.
In some sense this is replacing the change in #20433 to add a canonicalization
step instead of an error.