The motivation for this change is two-fold
* JAX APIs should use jax.Arrays.
* Using jax.Arrays potentially allows keeping the data on device, instead
of always copying it to the host. Note that the version here still always
copies to the host.
If this change breaks you, you can recover the old behavior by changing
jax.pure_callback(
f,
result_shape_dtypes,
*args,
**kwargs,
)
to
jax.pure_callback(
lambda *args: f(*jax.tree.map(np.asarray, args)),
result_shape_dtypes,
*args,
**kwargs,
)
so that the callback function is called with NumPy arrays as before.
I will update the "External callbacks" tutorial in a follow up.
PiperOrigin-RevId: 622457378
`jax.jit` now accepts `Layout` instances to the `in_shardings` and `out_shardings` argument. Major changes are just plumbing `in_layouts` and `out_layouts` everywhere.
Note that public api is `Layout(device_local_layout, sharding)` which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.
Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).
PiperOrigin-RevId: 622352537
The host_callbacks APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that when set to `False` will use `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.
See issue #20385 for more details.
We change the tests to accomodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
It is not well-defined, and the implementation is currently untested. Both Python and NumPy raise an error when attempting complex-valued floor division.
PiperOrigin-RevId: 621918604
The old lowering pass via Triton Python APIs has been removed and the
JAX_TRITON_COMPILE_VIA_XLA environment variable no longer has any effect.
PiperOrigin-RevId: 621857046
Also add a method for the plugin to return an xla_version plugin attribute.
Currently jaxlib pins a TPU/GPU backend, and uses `xla_extension_version` for backend version. As we want to stop pinning TPU/GPU backend and allow pip install different backend separately, we need this `xla_version` for features that are not capture by PJRT C API version. `xla_extension_version` will still be used for API changes such as xla_client.py, or any XLA changes in jaxlib that are not part of plugins.
PiperOrigin-RevId: 621672421
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247