From bb5787da0996fc32ac0af76572afae57406e4af0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 May 2024 10:39:43 -0700 Subject: [PATCH] Finalize deprecations of several APIs PiperOrigin-RevId: 633634215 --- CHANGELOG.md | 16 ++++++++++++---- docs/jax.lax.rst | 1 - jax/core.py | 8 +++----- jax/interpreters/xla.py | 41 +++++++++++------------------------------ jax/lax/__init__.py | 16 +++++----------- jax/nn/__init__.py | 15 +++++---------- 6 files changed, 36 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65e2dd1ef..a13a8fd5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`. +* Deprecations + * Removed a number of previously-deprecated APIs: + * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape` + * from {mod}`jax.lax`: `tie_in` + * from {mod}`jax.nn`: `normalize` + * from {mod}`jax.interpreters.xla`: `backend_specific_translations`, + `translations`, `register_translation`, `xla_destructure`, + `TranslationRule`, `TranslationContext`, `XlaOp`. ## jaxlib 0.4.29 @@ -114,7 +122,7 @@ Remember to align the itemized text with the first line of an item within a list * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. - The default value is set to `copy=False` to preserve backwards compatability. + The default value is set to `copy=False` to preserve backwards compatibility. ## jaxlib 0.4.27 (May 7, 2024) @@ -219,7 +227,7 @@ Remember to align the itemized text with the first line of an item within a list * Changes * JAX lowering to StableHLO does not depend on physical devices anymore. - If your primitive wraps custom_paritioning or JAX callbacks in the lowering + If your primitive wraps custom_partitioning or JAX callbacks in the lowering rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set. This is needed because custom_partitioning and JAX callbacks need physical @@ -1276,7 +1284,7 @@ Changes: * Added {func}`jax.random.ball`. * Added {func}`jax.default_device`. * Added a `python -m jax.collect_profile` script to manually capture program - traces as an alternative to the Tensorboard UI. + traces as an alternative to the TensorBoard UI. * Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`). * In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit @@ -2442,7 +2450,7 @@ Changes: * Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`. * Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided. * Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`. -* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`. +* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`. * Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`. * Add docstring for `test_util.check_grads` {jax-issue}`#2656`. * Add `callback_transform` {jax-issue}`#2665`. diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index b916458a3..32db1ba77 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -158,7 +158,6 @@ Operators sub tan tanh - tie_in top_k transpose zeros_like_array diff --git a/jax/core.py b/jax/core.py index cd0b3c8f2..edc31778f 100644 --- a/jax/core.py +++ b/jax/core.py @@ -162,14 +162,14 @@ from jax._src.core import ( from jax._src import core as _src_core _deprecations = { - # Added Oct 11, 2023: + # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", - _src_core.DimSize, + None, ), "Shape": ( "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].", - _src_core.Shape, + None, ), # Added Dec 15, 2023 "canonicalize_shape": ( @@ -192,8 +192,6 @@ _deprecations = { import typing if typing.TYPE_CHECKING: - DimSize = _src_core.DimSize - Shape = _src_core.Shape canonicalize_shape = _deprecated_canonicalize_shape dimension_as_value = _deprecated_dimension_as_value definitely_equal = _deprecated_definitely_equal diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 3b042dcea..1a25a22c7 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -17,14 +17,6 @@ from jax._src.interpreters.xla import ( canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, pytype_aval_mappings as pytype_aval_mappings, - - # Deprecations - backend_specific_translations as _deprecated_backend_specific_translations, - register_translation as _deprecated_register_translation, - translations as _deprecated_translations, - xla_destructure as _deprecated_xla_destructure, - TranslationContext as _deprecated_TranslationContext, - TranslationRule as _deprecated_TranslationRule, ) from jax._src.dispatch import ( @@ -39,55 +31,44 @@ Backend = xe.Client # Deprecations _deprecations = { - # Added Aug 29, 2023: + # Finalized 2024-05-13; remove after 2024-08-13 "backend_specific_translations": ( "jax.interpreters.xla.backend_specific_translations is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_backend_specific_translations, + None, ), "translations": ( "jax.interpreters.xla.translations is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_translations, + None, ), "register_translation": ( "jax.interpreters.xla.register_translation is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_register_translation, + None, ), "xla_destructure": ( "jax.interpreters.xla.xla_destructure is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_xla_destructure, + None, ), "TranslationRule": ( "jax.interpreters.xla.TranslationRule is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_TranslationRule, + None, ), "TranslationContext": ( "jax.interpreters.xla.TranslationContext is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - _deprecated_TranslationContext, + None, ), "XlaOp": ( "jax.interpreters.xla.XlaOp is deprecated. " "Register custom primitives via jax.interpreters.mlir instead.", - xc.XlaOp, + None, ), } -import typing -if typing.TYPE_CHECKING: - backend_specific_translations = _deprecated_backend_specific_translations - translations = _deprecated_translations - register_translation = _deprecated_register_translation - xla_destructure = _deprecated_xla_destructure - TranslationRule = _deprecated_TranslationRule - TranslationContext = _deprecated_TranslationContext - XlaOp = xc.XlaOp -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 1aa0c986c..040786c22 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -211,7 +211,6 @@ from jax._src.lax.lax import ( tan_p as tan_p, tanh as tanh, tanh_p as tanh_p, - tie_in as _deprecated_tie_in, top_k as top_k, top_k_p as top_k_p, transpose as transpose, @@ -377,18 +376,13 @@ from jax._src.dispatch import device_put_p as device_put_p _deprecations = { - # Added January 18 2023 + # Finalized 2024-05-13; remove after 2024-08-13 "tie_in": ( "jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. " - "Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in, + "Replace z = tie_in(x, y) with z = y.", None, ), } -import typing as _typing -if _typing.TYPE_CHECKING: - tie_in = _deprecated_tie_in -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del _typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 318416a3a..9eebebd7b 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -52,18 +52,13 @@ from jax._src.nn.functions import ( # Deprecations _deprecations = { - # Added Nov 8, 2023: + # Finalized 2024-05-13; remove after 2024-08-13 "normalize": ( "jax.nn.normalize is deprecated. Use jax.nn.standardize instead.", - standardize, + None, ), } -import typing -if typing.TYPE_CHECKING: - normalize = standardize -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing +from jax._src.deprecations import deprecation_getattr as _deprecation_getattr +__getattr__ = _deprecation_getattr(__name__, _deprecations) +del _deprecation_getattr