Finalize deprecations of several APIs

PiperOrigin-RevId: 633634215
This commit is contained in:
Jake VanderPlas 2024-05-14 10:39:43 -07:00 committed by jax authors
parent 1d6ffdedc5
commit bb5787da09
6 changed files with 36 additions and 61 deletions

View File

@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list
more cases. Previously non-parallel computations were always dispatched
synchronously. You can recover the old behavior by setting
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* Deprecations
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
* from {mod}`jax.lax`: `tie_in`
* from {mod}`jax.nn`: `normalize`
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
`translations`, `register_translation`, `xla_destructure`,
`TranslationRule`, `TranslationContext`, `XlaOp`.
## jaxlib 0.4.29
@ -114,7 +122,7 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
Previously, no copy would be made when the output array would have the same
dtype as the input array. This may result in some increased memory usage.
The default value is set to `copy=False` to preserve backwards compatability.
The default value is set to `copy=False` to preserve backwards compatibility.
## jaxlib 0.4.27 (May 7, 2024)
@ -219,7 +227,7 @@ Remember to align the itemized text with the first line of an item within a list
* Changes
* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
If your primitive wraps custom_partitioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
This is needed because custom_partitioning and JAX callbacks need physical
@ -1276,7 +1284,7 @@ Changes:
* Added {func}`jax.random.ball`.
* Added {func}`jax.default_device`.
* Added a `python -m jax.collect_profile` script to manually capture program
traces as an alternative to the Tensorboard UI.
traces as an alternative to the TensorBoard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
@ -2442,7 +2450,7 @@ Changes:
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`.
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
* Add `callback_transform` {jax-issue}`#2665`.

View File

@ -158,7 +158,6 @@ Operators
sub
tan
tanh
tie_in
top_k
transpose
zeros_like_array

View File

@ -162,14 +162,14 @@ from jax._src.core import (
from jax._src import core as _src_core
_deprecations = {
# Added Oct 11, 2023:
# Finalized 2024-05-13; remove after 2024-08-13
"DimSize": (
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
_src_core.DimSize,
None,
),
"Shape": (
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
_src_core.Shape,
None,
),
# Added Dec 15, 2023
"canonicalize_shape": (
@ -192,8 +192,6 @@ _deprecations = {
import typing
if typing.TYPE_CHECKING:
DimSize = _src_core.DimSize
Shape = _src_core.Shape
canonicalize_shape = _deprecated_canonicalize_shape
dimension_as_value = _deprecated_dimension_as_value
definitely_equal = _deprecated_definitely_equal

View File

@ -17,14 +17,6 @@ from jax._src.interpreters.xla import (
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
pytype_aval_mappings as pytype_aval_mappings,
# Deprecations
backend_specific_translations as _deprecated_backend_specific_translations,
register_translation as _deprecated_register_translation,
translations as _deprecated_translations,
xla_destructure as _deprecated_xla_destructure,
TranslationContext as _deprecated_TranslationContext,
TranslationRule as _deprecated_TranslationRule,
)
from jax._src.dispatch import (
@ -39,55 +31,44 @@ Backend = xe.Client
# Deprecations
_deprecations = {
# Added Aug 29, 2023:
# Finalized 2024-05-13; remove after 2024-08-13
"backend_specific_translations": (
"jax.interpreters.xla.backend_specific_translations is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_backend_specific_translations,
None,
),
"translations": (
"jax.interpreters.xla.translations is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_translations,
None,
),
"register_translation": (
"jax.interpreters.xla.register_translation is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_register_translation,
None,
),
"xla_destructure": (
"jax.interpreters.xla.xla_destructure is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_xla_destructure,
None,
),
"TranslationRule": (
"jax.interpreters.xla.TranslationRule is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_TranslationRule,
None,
),
"TranslationContext": (
"jax.interpreters.xla.TranslationContext is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
_deprecated_TranslationContext,
None,
),
"XlaOp": (
"jax.interpreters.xla.XlaOp is deprecated. "
"Register custom primitives via jax.interpreters.mlir instead.",
xc.XlaOp,
None,
),
}
import typing
if typing.TYPE_CHECKING:
backend_specific_translations = _deprecated_backend_specific_translations
translations = _deprecated_translations
register_translation = _deprecated_register_translation
xla_destructure = _deprecated_xla_destructure
TranslationRule = _deprecated_TranslationRule
TranslationContext = _deprecated_TranslationContext
XlaOp = xc.XlaOp
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -211,7 +211,6 @@ from jax._src.lax.lax import (
tan_p as tan_p,
tanh as tanh,
tanh_p as tanh_p,
tie_in as _deprecated_tie_in,
top_k as top_k,
top_k_p as top_k_p,
transpose as transpose,
@ -377,18 +376,13 @@ from jax._src.dispatch import device_put_p as device_put_p
_deprecations = {
# Added January 18 2023
# Finalized 2024-05-13; remove after 2024-08-13
"tie_in": (
"jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
"Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in,
"Replace z = tie_in(x, y) with z = y.", None,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
tie_in = _deprecated_tie_in
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -52,18 +52,13 @@ from jax._src.nn.functions import (
# Deprecations
_deprecations = {
# Added Nov 8, 2023:
# Finalized 2024-05-13; remove after 2024-08-13
"normalize": (
"jax.nn.normalize is deprecated. Use jax.nn.standardize instead.",
standardize,
None,
),
}
import typing
if typing.TYPE_CHECKING:
normalize = standardize
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr