mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
This commit is contained in:
parent
1d6ffdedc5
commit
bb5787da09
16
CHANGELOG.md
16
CHANGELOG.md
@ -13,6 +13,14 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
more cases. Previously non-parallel computations were always dispatched
|
||||
synchronously. You can recover the old behavior by setting
|
||||
`jax.config.update('jax_cpu_enable_async_dispatch', False)`.
|
||||
* Deprecations
|
||||
* Removed a number of previously-deprecated APIs:
|
||||
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
|
||||
* from {mod}`jax.lax`: `tie_in`
|
||||
* from {mod}`jax.nn`: `normalize`
|
||||
* from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
|
||||
`translations`, `register_translation`, `xla_destructure`,
|
||||
`TranslationRule`, `TranslationContext`, `XlaOp`.
|
||||
|
||||
## jaxlib 0.4.29
|
||||
|
||||
@ -114,7 +122,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
|
||||
Previously, no copy would be made when the output array would have the same
|
||||
dtype as the input array. This may result in some increased memory usage.
|
||||
The default value is set to `copy=False` to preserve backwards compatability.
|
||||
The default value is set to `copy=False` to preserve backwards compatibility.
|
||||
|
||||
## jaxlib 0.4.27 (May 7, 2024)
|
||||
|
||||
@ -219,7 +227,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Changes
|
||||
|
||||
* JAX lowering to StableHLO does not depend on physical devices anymore.
|
||||
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
|
||||
If your primitive wraps custom_partitioning or JAX callbacks in the lowering
|
||||
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
|
||||
primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
|
||||
This is needed because custom_partitioning and JAX callbacks need physical
|
||||
@ -1276,7 +1284,7 @@ Changes:
|
||||
* Added {func}`jax.random.ball`.
|
||||
* Added {func}`jax.default_device`.
|
||||
* Added a `python -m jax.collect_profile` script to manually capture program
|
||||
traces as an alternative to the Tensorboard UI.
|
||||
traces as an alternative to the TensorBoard UI.
|
||||
* Added a `jax.named_scope` context manager that adds profiler metadata to
|
||||
Python programs (similar to `jax.named_call`).
|
||||
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
|
||||
@ -2442,7 +2450,7 @@ Changes:
|
||||
* Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
|
||||
* Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
|
||||
* Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
|
||||
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitray order {jax-issue}`#2597`.
|
||||
* Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
|
||||
* Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
|
||||
* Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
|
||||
* Add `callback_transform` {jax-issue}`#2665`.
|
||||
|
@ -158,7 +158,6 @@ Operators
|
||||
sub
|
||||
tan
|
||||
tanh
|
||||
tie_in
|
||||
top_k
|
||||
transpose
|
||||
zeros_like_array
|
||||
|
@ -162,14 +162,14 @@ from jax._src.core import (
|
||||
|
||||
from jax._src import core as _src_core
|
||||
_deprecations = {
|
||||
# Added Oct 11, 2023:
|
||||
# Finalized 2024-05-13; remove after 2024-08-13
|
||||
"DimSize": (
|
||||
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
|
||||
_src_core.DimSize,
|
||||
None,
|
||||
),
|
||||
"Shape": (
|
||||
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
|
||||
_src_core.Shape,
|
||||
None,
|
||||
),
|
||||
# Added Dec 15, 2023
|
||||
"canonicalize_shape": (
|
||||
@ -192,8 +192,6 @@ _deprecations = {
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
DimSize = _src_core.DimSize
|
||||
Shape = _src_core.Shape
|
||||
canonicalize_shape = _deprecated_canonicalize_shape
|
||||
dimension_as_value = _deprecated_dimension_as_value
|
||||
definitely_equal = _deprecated_definitely_equal
|
||||
|
@ -17,14 +17,6 @@ from jax._src.interpreters.xla import (
|
||||
canonicalize_dtype as canonicalize_dtype,
|
||||
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
|
||||
pytype_aval_mappings as pytype_aval_mappings,
|
||||
|
||||
# Deprecations
|
||||
backend_specific_translations as _deprecated_backend_specific_translations,
|
||||
register_translation as _deprecated_register_translation,
|
||||
translations as _deprecated_translations,
|
||||
xla_destructure as _deprecated_xla_destructure,
|
||||
TranslationContext as _deprecated_TranslationContext,
|
||||
TranslationRule as _deprecated_TranslationRule,
|
||||
)
|
||||
|
||||
from jax._src.dispatch import (
|
||||
@ -39,55 +31,44 @@ Backend = xe.Client
|
||||
|
||||
# Deprecations
|
||||
_deprecations = {
|
||||
# Added Aug 29, 2023:
|
||||
# Finalized 2024-05-13; remove after 2024-08-13
|
||||
"backend_specific_translations": (
|
||||
"jax.interpreters.xla.backend_specific_translations is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_backend_specific_translations,
|
||||
None,
|
||||
),
|
||||
"translations": (
|
||||
"jax.interpreters.xla.translations is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_translations,
|
||||
None,
|
||||
),
|
||||
"register_translation": (
|
||||
"jax.interpreters.xla.register_translation is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_register_translation,
|
||||
None,
|
||||
),
|
||||
"xla_destructure": (
|
||||
"jax.interpreters.xla.xla_destructure is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_xla_destructure,
|
||||
None,
|
||||
),
|
||||
"TranslationRule": (
|
||||
"jax.interpreters.xla.TranslationRule is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_TranslationRule,
|
||||
None,
|
||||
),
|
||||
"TranslationContext": (
|
||||
"jax.interpreters.xla.TranslationContext is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
_deprecated_TranslationContext,
|
||||
None,
|
||||
),
|
||||
"XlaOp": (
|
||||
"jax.interpreters.xla.XlaOp is deprecated. "
|
||||
"Register custom primitives via jax.interpreters.mlir instead.",
|
||||
xc.XlaOp,
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
backend_specific_translations = _deprecated_backend_specific_translations
|
||||
translations = _deprecated_translations
|
||||
register_translation = _deprecated_register_translation
|
||||
xla_destructure = _deprecated_xla_destructure
|
||||
TranslationRule = _deprecated_TranslationRule
|
||||
TranslationContext = _deprecated_TranslationContext
|
||||
XlaOp = xc.XlaOp
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -211,7 +211,6 @@ from jax._src.lax.lax import (
|
||||
tan_p as tan_p,
|
||||
tanh as tanh,
|
||||
tanh_p as tanh_p,
|
||||
tie_in as _deprecated_tie_in,
|
||||
top_k as top_k,
|
||||
top_k_p as top_k_p,
|
||||
transpose as transpose,
|
||||
@ -377,18 +376,13 @@ from jax._src.dispatch import device_put_p as device_put_p
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Added January 18 2023
|
||||
# Finalized 2024-05-13; remove after 2024-08-13
|
||||
"tie_in": (
|
||||
"jax.lax.tie_in is deprecated: it has been a no-op since JAX v0.2.0. "
|
||||
"Replace z = tie_in(x, y) with z = y.", _deprecated_tie_in,
|
||||
"Replace z = tie_in(x, y) with z = y.", None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
tie_in = _deprecated_tie_in
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -52,18 +52,13 @@ from jax._src.nn.functions import (
|
||||
# Deprecations
|
||||
|
||||
_deprecations = {
|
||||
# Added Nov 8, 2023:
|
||||
# Finalized 2024-05-13; remove after 2024-08-13
|
||||
"normalize": (
|
||||
"jax.nn.normalize is deprecated. Use jax.nn.standardize instead.",
|
||||
standardize,
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
normalize = standardize
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
Loading…
x
Reference in New Issue
Block a user