Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.

PiperOrigin-RevId: 518241326
This commit is contained in:
Peter Hawkins 2023-03-21 05:13:21 -07:00 committed by jax authors
parent 9387713e04
commit e0453add22
5 changed files with 29 additions and 17 deletions

View File

@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
for which it is an alias.
* The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use
`jax.Array` instead.
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.

View File

@ -28,7 +28,6 @@ import numpy as np
import jax
from jax import tree_util
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.tree_util import tree_map
@ -52,6 +51,7 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters.batching import ConcatAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (

View File

@ -44,7 +44,7 @@ import jax
from jax import jit
from jax import errors
from jax import lax
from jax.interpreters import pxla
from jax._src.interpreters import pxla
from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import api_util

View File

@ -44,7 +44,6 @@ from jax._src.interpreters.pxla import (
SPMDBatchTrace as SPMDBatchTrace,
ShardInfo as ShardInfo,
ShardedAxis as ShardedAxis,
ShardedDeviceArray as ShardedDeviceArray,
ShardingSpec as ShardingSpec,
TileManual as TileManual,
TileVectorize as TileVectorize,
@ -122,25 +121,17 @@ from jax._src.mesh import (
from jax._src.mesh import Mesh as _deprecated_Mesh
from jax._src.interpreters.pxla import (
PartitionSpec as _deprecated_PartitionSpec,
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
make_sharded_device_array as _deprecated_make_sharded_device_array,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.mesh import Mesh as Mesh
from jax._src.interpreters.pxla import (
PartitionSpec as PartitionSpec,
make_sharded_device_array as make_sharded_device_array,
device_put as device_put,
)
del typing
_deprecations = {
# Added Feb 8, 2023:
"Mesh": (
"jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.",
_deprecated_Mesh,
),
# Added Feb 8, 2023:
"PartitionSpec": (
(
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
@ -148,6 +139,14 @@ _deprecations = {
),
_deprecated_PartitionSpec,
),
# Added March 15, 2023:
"ShardedDeviceArray": (
(
"jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use "
"jax.Array."
),
_deprecated_ShardedDeviceArray,
),
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
# is the default since November 2022.
"make_sharded_device_array": (
@ -166,6 +165,17 @@ _deprecations = {
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
import typing
if typing.TYPE_CHECKING:
from jax._src.mesh import Mesh as Mesh
from jax._src.interpreters.pxla import (
PartitionSpec as PartitionSpec,
ShardedDeviceArray as ShardedDeviceArray,
device_put as device_put,
make_sharded_device_array as make_sharded_device_array,
)
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

View File

@ -45,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
from jax._src import config as jax_config
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip
from jax.interpreters import pxla
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import array
from jax._src.sharding_impls import PmapSharding