Add deprecation warnings for several top-level jax imports

This commit is contained in:
Jake VanderPlas 2023-03-28 12:40:59 -07:00
parent 97c8ce31ed
commit fc47137ca8
2 changed files with 73 additions and 8 deletions

View File

@ -74,13 +74,13 @@ from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc
from jax._src.api import effects_barrier
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.api import checkpoint as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.util import curry # TODO(phawkins): update users to avoid this.
from jax._src.util import curry as _deprecated_curry
from jax._src.custom_derivatives import custom_gradient as custom_gradient
from jax._src.custom_derivatives import custom_jvp as custom_jvp
from jax._src.custom_derivatives import custom_vjp as custom_vjp
@ -93,7 +93,7 @@ from jax._src.api import device_put_replicated as device_put_replicated
from jax._src.xla_bridge import devices as devices
from jax._src.api import disable_jit as disable_jit
from jax._src.api import eval_shape as eval_shape
from jax._src.api_util import flatten_fun_nokwargs # TODO(phawkins): update users to avoid this.
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
from jax._src.dtypes import float0 as float0
from jax._src.api import grad as grad
from jax._src.api import hessian as hessian
@ -118,17 +118,17 @@ from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
from jax._src.api import pure_callback as pure_callback
from jax._src.api import remat as remat
from jax._src.core import ShapedArray # TODO(jakevdp): update users to avoid this.
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation
from jax.interpreters import ad # TODO(phawkins): update users to avoid this.
from jax.interpreters import pxla # TODO(phawkins): update users to avoid this.
from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this.
from jax.interpreters import xla # TODO(phawkins): update users to avoid this.
from jax.interpreters import ad as _deprecated_ad
from jax.interpreters import partial_eval as _deprecated_partial_eval
from jax.interpreters import pxla as _deprecated_pxla
from jax.interpreters import xla as _deprecated_xla
from jax._src.array import (
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
@ -175,6 +175,54 @@ from jax import util as util
from jax._src.array import Shard as Shard
_deprecations = {
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
_deprecated_ShapedArray,
),
"ad": (
"jax.ad is deprecated. Use jax.interpreters.ad",
_deprecated_ad,
),
"partial_eval": (
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
_deprecated_partial_eval,
),
"pxla": (
"jax.pxla is deprecated. Use jax.interpreters.pxla",
_deprecated_pxla,
),
"xla": (
"jax.xla is deprecated. Use jax.interpreters.xla",
_deprecated_xla,
),
"curry": (
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
_deprecated_curry,
),
"flatten_fun_nokwargs": (
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
_deprecated_flatten_fun_nokwargs,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
from jax.interpreters import pxla as pxla
from jax.interpreters import xla as xla
from jax._src.util import curry as curry
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
# TODO(yashkatariya): Remove after 2 jax releases from 0.4.6
if not config.jax_jit_pjit_api_merge:
raise ValueError(

View File

@ -9648,5 +9648,22 @@ class GarbageCollectionTest(jtu.JaxTestCase):
assert x_np_weakref() is None
class DeprecationsTest(jtu.JaxTestCase):
def test_jax_deprecations(self):
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.ShapedArray, jax.core.ShapedArray)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.ad, jax.interpreters.ad)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.partial_eval, jax.interpreters.partial_eval)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.pxla, jax.interpreters.pxla)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.xla, jax.interpreters.xla)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.curry, jax._src.util.curry)
with self.assertWarns(DeprecationWarning):
self.assertIs(jax.flatten_fun_nokwargs, jax.api_util.flatten_fun_nokwargs)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())