From fbcb157ad35f60868def2e71ea2725013861e6a4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 28 Jun 2024 07:27:23 -0700 Subject: [PATCH] Finalize deprecation of several previously-deprecated jax.core functions: - `jax.core.canonicalize_shape` - `jax.core.dimension_as_value` - `jax.core.definitely_equal` - `jax.core.symbolic_equal_dim` These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024. PiperOrigin-RevId: 647671122 --- CHANGELOG.md | 4 ++++ jax/core.py | 17 +++++------------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be8f6f8d1..4f8e1a084 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,10 @@ Remember to align the itemized text with the first line of an item within a list * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be installed either as a part of local CUDA installation, or via NVIDIA's CUDA pip wheels. +* Deprecations + * Removed a number of previously-deprecated internal APIs related to + polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`, + `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`. ## jaxlib 0.4.31 diff --git a/jax/core.py b/jax/core.py index c23d37123..b023d2daf 100644 --- a/jax/core.py +++ b/jax/core.py @@ -68,7 +68,6 @@ from jax._src.core import ( call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, - canonicalize_shape as _deprecated_canonicalize_shape, check_eqn as check_eqn, check_jaxpr as check_jaxpr, check_type as check_type, @@ -80,8 +79,6 @@ from jax._src.core import ( cur_sublevel as cur_sublevel, custom_typechecks as custom_typechecks, dedup_referents as dedup_referents, - definitely_equal as _deprecated_definitely_equal, - dimension_as_value as _deprecated_dimension_as_value, do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr, ensure_compile_time_eval as ensure_compile_time_eval, escaped_tracer_error as escaped_tracer_error, @@ -172,18 +169,18 @@ _deprecations = { "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", None, ), - # Added Dec 15, 2023 + # Finalized 2024-06-24; remove after 2024-09-24 "canonicalize_shape": ( - "jax.core.canonicalize_shape is deprecated.", _deprecated_canonicalize_shape, + "jax.core.canonicalize_shape is deprecated.", None, ), "dimension_as_value": ( - "jax.core.dimension_as_value is deprecated. Use jnp.array.", _deprecated_dimension_as_value, + "jax.core.dimension_as_value is deprecated. Use jnp.array.", None, ), "definitely_equal": ( - "jax.core.definitely_equal is deprecated. Use ==.", _deprecated_definitely_equal, + "jax.core.definitely_equal is deprecated. Use ==.", None, ), "symbolic_equal_dim": ( - "jax.core.symbolic_equal_dim is deprecated. Use ==.", _deprecated_definitely_equal, + "jax.core.symbolic_equal_dim is deprecated. Use ==.", None, ), # Added Jan 8, 2024 "non_negative_dim": ( @@ -193,9 +190,6 @@ _deprecations = { import typing if typing.TYPE_CHECKING: - canonicalize_shape = _deprecated_canonicalize_shape - dimension_as_value = _deprecated_dimension_as_value - definitely_equal = _deprecated_definitely_equal non_negative_dim = _deprecated_non_negative_dim pp_aval = _src_core.pp_aval pp_eqn = _src_core.pp_eqn @@ -209,7 +203,6 @@ if typing.TYPE_CHECKING: pp_kv_pairs = _src_core.pp_kv_pairs pp_var = _src_core.pp_var pp_vars = _src_core.pp_vars - symbolic_equal_dim = _deprecated_definitely_equal else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations)