diff --git a/CHANGELOG.md b/CHANGELOG.md index 05e1d7f55..3a4c88ed7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, for which it is an alias. + * The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use + `jax.Array` instead. * Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` * `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 67a9c14cc..b95c8b5f9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -28,7 +28,6 @@ import numpy as np import jax from jax import tree_util from jax.interpreters import partial_eval as pe -from jax.interpreters import pxla from jax.interpreters import xla from jax.tree_util import tree_map @@ -52,6 +51,7 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import pxla from jax._src.interpreters.batching import ConcatAxis from jax._src.lax import slicing from jax._src.lax.utils import ( diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a698a9342..71f10ba07 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -44,7 +44,7 @@ import jax from jax import jit from jax import errors from jax import lax -from jax.interpreters import pxla +from jax._src.interpreters import pxla from jax.tree_util import tree_leaves, tree_flatten, tree_map from jax._src import api_util diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 437da3e1a..b13507b46 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -44,7 +44,6 @@ from jax._src.interpreters.pxla import ( SPMDBatchTrace as SPMDBatchTrace, ShardInfo as ShardInfo, ShardedAxis as ShardedAxis, - ShardedDeviceArray as ShardedDeviceArray, ShardingSpec as ShardingSpec, TileManual as TileManual, TileVectorize as TileVectorize, @@ -122,25 +121,17 @@ from jax._src.mesh import ( from jax._src.mesh import Mesh as _deprecated_Mesh from jax._src.interpreters.pxla import ( PartitionSpec as _deprecated_PartitionSpec, + ShardedDeviceArray as _deprecated_ShardedDeviceArray, make_sharded_device_array as _deprecated_make_sharded_device_array, ) -import typing -if typing.TYPE_CHECKING: - from jax._src.mesh import Mesh as Mesh - from jax._src.interpreters.pxla import ( - PartitionSpec as PartitionSpec, - make_sharded_device_array as make_sharded_device_array, - device_put as device_put, - ) -del typing - _deprecations = { # Added Feb 8, 2023: "Mesh": ( "jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.", _deprecated_Mesh, ), + # Added Feb 8, 2023: "PartitionSpec": ( ( "jax.interpreters.pxla.PartitionSpec is deprecated. Use " @@ -148,6 +139,14 @@ _deprecations = { ), _deprecated_PartitionSpec, ), + # Added March 15, 2023: + "ShardedDeviceArray": ( + ( + "jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use " + "jax.Array." + ), + _deprecated_ShardedDeviceArray, + ), # make_sharded_device_array is deprecated as of March 3, 2023. jax.Array # is the default since November 2022. "make_sharded_device_array": ( @@ -166,6 +165,17 @@ _deprecations = { ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + from jax._src.mesh import Mesh as Mesh + from jax._src.interpreters.pxla import ( + PartitionSpec as PartitionSpec, + ShardedDeviceArray as ShardedDeviceArray, + device_put as device_put, + make_sharded_device_array as make_sharded_device_array, + ) +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 7814a0070..4d8c5a36f 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -45,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr, from jax._src import config as jax_config from jax._src import xla_bridge from jax._src.util import safe_map, safe_zip -from jax.interpreters import pxla +from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import array from jax._src.sharding_impls import PmapSharding