1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

Mark jax.abstract_arrays as deprecated

This commit is contained in:
Jake VanderPlas 2023-06-06 07:32:35 -07:00
parent 0ec9f3c2df
commit 47ae5bddd7
5 changed files with 51 additions and 18 deletions

@ -9,16 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.12
* Deprecations
* The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in :mod:`jax.core`.
* `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
* `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
* `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
## jaxlib 0.4.12

@ -519,7 +519,7 @@
}
],
"source": [
"from jax._src import abstract_arrays\n",
"from jax import core\n",
"@trace(\"multiply_add_abstract_eval\")\n",
"def multiply_add_abstract_eval(xs, ys, zs):\n",
" \"\"\"Abstract evaluation of the primitive.\n",
@ -533,7 +533,7 @@
" \"\"\"\n",
" assert xs.shape == ys.shape\n",
" assert xs.shape == zs.shape\n",
" return abstract_arrays.ShapedArray(xs.shape, xs.dtype)\n",
" return core.ShapedArray(xs.shape, xs.dtype)\n",
"\n",
"# Now we register the abstract evaluation with JAX\n",
"multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"

@ -308,7 +308,7 @@ In the latter case, JAX uses the actual concrete value wrapped as an abstract va
:id: ctQmEeckIbdo
:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
from jax._src import abstract_arrays
from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
"""Abstract evaluation of the primitive.
@ -322,7 +322,7 @@ def multiply_add_abstract_eval(xs, ys, zs):
"""
assert xs.shape == ys.shape
assert xs.shape == zs.shape
return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

@ -155,7 +155,7 @@ from jax._src.tree_util import (
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from jax import abstract_arrays as abstract_arrays
from jax import abstract_arrays as _deprecated_abstract_arrays
from jax import custom_derivatives as custom_derivatives
from jax import custom_batching as custom_batching
from jax import custom_transpose as custom_transpose
@ -186,6 +186,11 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache
del _ccache
_deprecations = {
# Added 06 June 2023
"abstract_arrays": (
"jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays
),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
@ -219,6 +224,7 @@ _deprecations = {
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval

@ -14,8 +14,35 @@
# TODO(phawkins): fix users of these aliases and delete this file.
from jax._src.abstract_arrays import array_types
from jax._src.abstract_arrays import array_types as _deprecated_array_types
from jax._src.core import (
ShapedArray,
raise_to_shaped,
ShapedArray as _deprecated_ShapedArray,
raise_to_shaped as _deprecated_raise_to_shaped,
)
_deprecations = {
# Added 06 June 2023
"array_types": (
"jax.abstract_arrays.array_types is deprecated.",
_deprecated_array_types,
),
"ShapedArray": (
"jax.abstract_arrays.ShapedArray is deprecated. Use jax.core.ShapedArray.",
_deprecated_ShapedArray,
),
"raise_to_shaped": (
"jax.abstract_arrays.raise_to_shaped is deprecated. Use jax.core.raise_to_shaped.",
_deprecated_raise_to_shaped,
),
}
import typing
if typing.TYPE_CHECKING:
from jax._src.abstract_arrays import array_types as array_types
from jax._src.core import ShapedArray as ShapedArray
from jax._src.core import raise_to_shaped as raise_to_shaped
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing