The `vectorized` argument to `pure_callback` and `ffi_call` was deprecated in JAX v0.4.34 (released Oct 4 2024), then added to the CHANGELOG in v0.4.35 (doh! released Oct 22). The JAX compatibility policy requires 3 months of compatible releases before a deprecation is finalized, so it is time to remove this parameter from the public API. The `vmap_method` parameter can be used instead, and the docs for [`pure_callback`](https://docs.jax.dev/en/latest/_autosummary/jax.pure_callback.html) provide more details.
This change has one other (non-obvious!) affect on the user facing APIs. (Note that this change in behavior has also been protected by a deprecation warning since the `vectorized` parameter was deprecated.) The default behavior of `pure_callback` and `ffi_call` under `vmap` is to now raise an exception, rather than silently producing a loop. To opt in to the previous default behavior, use `vmap_method="sequential"`.
PiperOrigin-RevId: 747413383
These APIs are already broken on GPU and TPU by virtue of not being implemented in the PJRT C API, so it seems unlikely that they have any users.
PiperOrigin-RevId: 746595857
Previously, jax.jit returned a function with extra attributes, e.g., `trace`, and `lower`, such that we can use:
```
jax.jit(f).trace(...)
```
The new attributes create problems when `jax.jit` is used along `functools.wraps`.
Essentially, `functools.wraps(jax.jit(f))(wrapper)` is supposed to result in a
function that when invoked will invoke `wrapper` and then presumably `jax.jit(f)`.
This works as expected if you just call the result, but if you try to use it with
`lower` and `trace`, the `wrapper` is bypassed. This is because `wraps` copies the
attributes `trace` and `lower` from `jax.jit(f)` onto the resulting function,
so when `trace` is invoked the `wrapper` is bypassed entirely.
See #27829 and #27825.
The solution proposed here is to make the `trace` and `lower` be class attributes,
so that they are not copied by `functools.wraps`.
Thus, if you try to use `lower` or `trace` on the result of
`functools.wraps(jax.jit(f))()` you will get an error.
That is better than silently ignoring the wrapper.
The workaround is to apply `jax.jit` last among your wrappers.
Fixes: #27829
These are available via jax.extend.mlir.dialects.
No deprecation period because jax.interpreters.mlir is not a stable API.
PiperOrigin-RevId: 744712537
We have no usages of it neither in JAX nor internally, but we still have to
go through the deprecation cycle, becuase `jax.tree_util` is public API.
PiperOrigin-RevId: 739196514
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
PiperOrigin-RevId: 731812827
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistence compilation cache size is
set. Writing this file can cause network filesystem performace and
other issues, so only write it if users are opted-in.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
For several releases, libtpu-nightly has been a transitional empty package that does nothing. We remove the dependency in preparation for depending on libtpu from pypi instead of a GCS bucket in jax[tpu].