mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Mark jax.abstract_arrays as deprecated
This commit is contained in:
parent
0ec9f3c2df
commit
47ae5bddd7
20
CHANGELOG.md
20
CHANGELOG.md
@ -9,16 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
## jax 0.4.12
|
||||
|
||||
* Deprecations
|
||||
* The following APIs have been removed after a 3 month deprecation period, in
|
||||
accordance with the {ref}`api-compatibility` policy:
|
||||
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
|
||||
of `numpy.alltrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
|
||||
of `numpy.sometrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
|
||||
of `numpy.product` in NumPy version 1.25.0.
|
||||
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
|
||||
of `numpy.cumproduct` in NumPy version 1.25.0.
|
||||
* `jax.abstract_arrays` and its contents are now deprecated. See related
|
||||
functionality in :mod:`jax.core`.
|
||||
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
|
||||
of `numpy.alltrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
|
||||
of `numpy.sometrue` in NumPy version 1.25.0.
|
||||
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
|
||||
of `numpy.product` in NumPy version 1.25.0.
|
||||
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
|
||||
of `numpy.cumproduct` in NumPy version 1.25.0.
|
||||
|
||||
## jaxlib 0.4.12
|
||||
|
||||
|
@ -519,7 +519,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from jax._src import abstract_arrays\n",
|
||||
"from jax import core\n",
|
||||
"@trace(\"multiply_add_abstract_eval\")\n",
|
||||
"def multiply_add_abstract_eval(xs, ys, zs):\n",
|
||||
" \"\"\"Abstract evaluation of the primitive.\n",
|
||||
@ -533,7 +533,7 @@
|
||||
" \"\"\"\n",
|
||||
" assert xs.shape == ys.shape\n",
|
||||
" assert xs.shape == zs.shape\n",
|
||||
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
|
||||
" return core.ShapedArray(xs.shape, xs.dtype)\n",
|
||||
"\n",
|
||||
"# Now we register the abstract evaluation with JAX\n",
|
||||
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"
|
||||
|
@ -308,7 +308,7 @@ In the latter case, JAX uses the actual concrete value wrapped as an abstract va
|
||||
:id: ctQmEeckIbdo
|
||||
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax import core
|
||||
@trace("multiply_add_abstract_eval")
|
||||
def multiply_add_abstract_eval(xs, ys, zs):
|
||||
"""Abstract evaluation of the primitive.
|
||||
@ -322,7 +322,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
|
||||
"""
|
||||
assert xs.shape == ys.shape
|
||||
assert xs.shape == zs.shape
|
||||
return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
|
||||
return core.ShapedArray(xs.shape, xs.dtype)
|
||||
|
||||
# Now we register the abstract evaluation with JAX
|
||||
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
|
||||
|
@ -155,7 +155,7 @@ from jax._src.tree_util import (
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
# jax and rely on the names imported above.
|
||||
from jax import abstract_arrays as abstract_arrays
|
||||
from jax import abstract_arrays as _deprecated_abstract_arrays
|
||||
from jax import custom_derivatives as custom_derivatives
|
||||
from jax import custom_batching as custom_batching
|
||||
from jax import custom_transpose as custom_transpose
|
||||
@ -186,6 +186,11 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache
|
||||
del _ccache
|
||||
|
||||
_deprecations = {
|
||||
# Added 06 June 2023
|
||||
"abstract_arrays": (
|
||||
"jax.abstract_arrays is deprecated. Refer to jax.core.",
|
||||
_deprecated_abstract_arrays
|
||||
),
|
||||
# Added 28 March 2023
|
||||
"ShapedArray": (
|
||||
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
|
||||
@ -219,6 +224,7 @@ _deprecations = {
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from jax._src import abstract_arrays as abstract_arrays
|
||||
from jax._src.core import ShapedArray as ShapedArray
|
||||
from jax.interpreters import ad as ad
|
||||
from jax.interpreters import partial_eval as partial_eval
|
||||
|
@ -14,8 +14,35 @@
|
||||
|
||||
# TODO(phawkins): fix users of these aliases and delete this file.
|
||||
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.abstract_arrays import array_types as _deprecated_array_types
|
||||
from jax._src.core import (
|
||||
ShapedArray,
|
||||
raise_to_shaped,
|
||||
ShapedArray as _deprecated_ShapedArray,
|
||||
raise_to_shaped as _deprecated_raise_to_shaped,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added 06 June 2023
|
||||
"array_types": (
|
||||
"jax.abstract_arrays.array_types is deprecated.",
|
||||
_deprecated_array_types,
|
||||
),
|
||||
"ShapedArray": (
|
||||
"jax.abstract_arrays.ShapedArray is deprecated. Use jax.core.ShapedArray.",
|
||||
_deprecated_ShapedArray,
|
||||
),
|
||||
"raise_to_shaped": (
|
||||
"jax.abstract_arrays.raise_to_shaped is deprecated. Use jax.core.raise_to_shaped.",
|
||||
_deprecated_raise_to_shaped,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.abstract_arrays import array_types as array_types
|
||||
from jax._src.core import ShapedArray as ShapedArray
|
||||
from jax._src.core import raise_to_shaped as raise_to_shaped
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user