Finalize deprecation of several previously-deprecated jax.core functions:

- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
This commit is contained in:
Jake VanderPlas 2024-06-28 07:27:23 -07:00 committed by jax authors
parent 648b9519cf
commit fbcb157ad3
2 changed files with 9 additions and 12 deletions

View File

@ -18,6 +18,10 @@ Remember to align the itemized text with the first line of an item within a list
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
`dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
## jaxlib 0.4.31

View File

@ -68,7 +68,6 @@ from jax._src.core import (
call_bind_with_continuation as call_bind_with_continuation,
call_impl as call_impl,
call_p as call_p,
canonicalize_shape as _deprecated_canonicalize_shape,
check_eqn as check_eqn,
check_jaxpr as check_jaxpr,
check_type as check_type,
@ -80,8 +79,6 @@ from jax._src.core import (
cur_sublevel as cur_sublevel,
custom_typechecks as custom_typechecks,
dedup_referents as dedup_referents,
definitely_equal as _deprecated_definitely_equal,
dimension_as_value as _deprecated_dimension_as_value,
do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
ensure_compile_time_eval as ensure_compile_time_eval,
escaped_tracer_error as escaped_tracer_error,
@ -172,18 +169,18 @@ _deprecations = {
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
None,
),
# Added Dec 15, 2023
# Finalized 2024-06-24; remove after 2024-09-24
"canonicalize_shape": (
"jax.core.canonicalize_shape is deprecated.", _deprecated_canonicalize_shape,
"jax.core.canonicalize_shape is deprecated.", None,
),
"dimension_as_value": (
"jax.core.dimension_as_value is deprecated. Use jnp.array.", _deprecated_dimension_as_value,
"jax.core.dimension_as_value is deprecated. Use jnp.array.", None,
),
"definitely_equal": (
"jax.core.definitely_equal is deprecated. Use ==.", _deprecated_definitely_equal,
"jax.core.definitely_equal is deprecated. Use ==.", None,
),
"symbolic_equal_dim": (
"jax.core.symbolic_equal_dim is deprecated. Use ==.", _deprecated_definitely_equal,
"jax.core.symbolic_equal_dim is deprecated. Use ==.", None,
),
# Added Jan 8, 2024
"non_negative_dim": (
@ -193,9 +190,6 @@ _deprecations = {
import typing
if typing.TYPE_CHECKING:
canonicalize_shape = _deprecated_canonicalize_shape
dimension_as_value = _deprecated_dimension_as_value
definitely_equal = _deprecated_definitely_equal
non_negative_dim = _deprecated_non_negative_dim
pp_aval = _src_core.pp_aval
pp_eqn = _src_core.pp_eqn
@ -209,7 +203,6 @@ if typing.TYPE_CHECKING:
pp_kv_pairs = _src_core.pp_kv_pairs
pp_var = _src_core.pp_var
pp_vars = _src_core.pp_vars
symbolic_equal_dim = _deprecated_definitely_equal
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)