The idea is that each deprecated behavior would have an associated ID by which it could be referred to globally, so that we could call `deprecations.accelerate(module_name, ID)` in order to accelerate the deprecation period and run code with the post-deprecation behavior.
For now, these deprecation accelerations will be private APIs, but we could think about how to expose these to the user, perhaps via a config flag that finalizes all deprecations in the library.
PiperOrigin-RevId: 605064227
This argument is a carry-over from NumPy, and has never had any effect (all jax.numpy
sorts were stable by default). Now that the new stable parameter is supported, it will
be clearer if we explicitly deprecate and eventually remove this argument.
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 592266215
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
Currently, we do not support shape polymorphism when we index with a
slice, e.g., `x[a🅱️c]`, and insted we direct the user to use to
`lax.dynamic_slice`. This is only because so far we have not tried
to ensure that the index and bounds checking computations in gather
are compatible with shape polymorphism. The problem was that there
were a lot of conditionals, e.g., `if start >= stop` that cannot be
handled in general in presence of symbolic shapes.
Here we introduce a new helper function `_preprocess_slice` to contain
all the computations for the start and the size of the slice.
To test that this does not break the JAX index computations, I ran
the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`.