mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add deprecation warnings for several top-level jax imports
This commit is contained in:
parent
97c8ce31ed
commit
fc47137ca8
@ -74,13 +74,13 @@ from jax._src.lib import xla_client as _xc
|
||||
Device = _xc.Device
|
||||
del _xc
|
||||
|
||||
from jax._src.api import effects_barrier
|
||||
from jax._src.api import effects_barrier as effects_barrier
|
||||
from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.api import checkpoint as checkpoint
|
||||
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
||||
from jax._src.api import clear_backends as clear_backends
|
||||
from jax._src.custom_derivatives import closure_convert as closure_convert
|
||||
from jax._src.util import curry # TODO(phawkins): update users to avoid this.
|
||||
from jax._src.util import curry as _deprecated_curry
|
||||
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
||||
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
||||
from jax._src.custom_derivatives import custom_vjp as custom_vjp
|
||||
@ -93,7 +93,7 @@ from jax._src.api import device_put_replicated as device_put_replicated
|
||||
from jax._src.xla_bridge import devices as devices
|
||||
from jax._src.api import disable_jit as disable_jit
|
||||
from jax._src.api import eval_shape as eval_shape
|
||||
from jax._src.api_util import flatten_fun_nokwargs # TODO(phawkins): update users to avoid this.
|
||||
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
|
||||
from jax._src.dtypes import float0 as float0
|
||||
from jax._src.api import grad as grad
|
||||
from jax._src.api import hessian as hessian
|
||||
@ -118,17 +118,17 @@ from jax._src.xla_bridge import process_count as process_count
|
||||
from jax._src.xla_bridge import process_index as process_index
|
||||
from jax._src.api import pure_callback as pure_callback
|
||||
from jax._src.api import remat as remat
|
||||
from jax._src.core import ShapedArray # TODO(jakevdp): update users to avoid this.
|
||||
from jax._src.core import ShapedArray as _deprecated_ShapedArray
|
||||
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from jax._src.api import value_and_grad as value_and_grad
|
||||
from jax._src.api import vjp as vjp
|
||||
from jax._src.api import vmap as vmap
|
||||
from jax._src.api import xla_computation as xla_computation
|
||||
|
||||
from jax.interpreters import ad # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import pxla # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import xla # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import ad as _deprecated_ad
|
||||
from jax.interpreters import partial_eval as _deprecated_partial_eval
|
||||
from jax.interpreters import pxla as _deprecated_pxla
|
||||
from jax.interpreters import xla as _deprecated_xla
|
||||
|
||||
from jax._src.array import (
|
||||
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
|
||||
@ -175,6 +175,54 @@ from jax import util as util
|
||||
from jax._src.array import Shard as Shard
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Added 28 March 2023
|
||||
"ShapedArray": (
|
||||
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
|
||||
_deprecated_ShapedArray,
|
||||
),
|
||||
"ad": (
|
||||
"jax.ad is deprecated. Use jax.interpreters.ad",
|
||||
_deprecated_ad,
|
||||
),
|
||||
"partial_eval": (
|
||||
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
|
||||
_deprecated_partial_eval,
|
||||
),
|
||||
"pxla": (
|
||||
"jax.pxla is deprecated. Use jax.interpreters.pxla",
|
||||
_deprecated_pxla,
|
||||
),
|
||||
"xla": (
|
||||
"jax.xla is deprecated. Use jax.interpreters.xla",
|
||||
_deprecated_xla,
|
||||
),
|
||||
"curry": (
|
||||
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
|
||||
_deprecated_curry,
|
||||
),
|
||||
"flatten_fun_nokwargs": (
|
||||
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
|
||||
_deprecated_flatten_fun_nokwargs,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from jax._src.core import ShapedArray as ShapedArray
|
||||
from jax.interpreters import ad as ad
|
||||
from jax.interpreters import partial_eval as partial_eval
|
||||
from jax.interpreters import pxla as pxla
|
||||
from jax.interpreters import xla as xla
|
||||
from jax._src.util import curry as curry
|
||||
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
||||
|
||||
# TODO(yashkatariya): Remove after 2 jax releases from 0.4.6
|
||||
if not config.jax_jit_pjit_api_merge:
|
||||
raise ValueError(
|
||||
|
@ -9648,5 +9648,22 @@ class GarbageCollectionTest(jtu.JaxTestCase):
|
||||
|
||||
assert x_np_weakref() is None
|
||||
|
||||
class DeprecationsTest(jtu.JaxTestCase):
|
||||
def test_jax_deprecations(self):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.ShapedArray, jax.core.ShapedArray)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.ad, jax.interpreters.ad)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.partial_eval, jax.interpreters.partial_eval)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.pxla, jax.interpreters.pxla)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.xla, jax.interpreters.xla)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.curry, jax._src.util.curry)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.assertIs(jax.flatten_fun_nokwargs, jax.api_util.flatten_fun_nokwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user