mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
PiperOrigin-RevId: 518241326
This commit is contained in:
parent
9387713e04
commit
e0453add22
@ -24,6 +24,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
|
||||
for which it is an alias.
|
||||
* The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use
|
||||
`jax.Array` instead.
|
||||
* Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
|
||||
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
|
||||
* `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
|
||||
|
@ -28,7 +28,6 @@ import numpy as np
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
@ -52,6 +51,7 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters.batching import ConcatAxis
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
|
@ -44,7 +44,7 @@ import jax
|
||||
from jax import jit
|
||||
from jax import errors
|
||||
from jax import lax
|
||||
from jax.interpreters import pxla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.tree_util import tree_leaves, tree_flatten, tree_map
|
||||
|
||||
from jax._src import api_util
|
||||
|
@ -44,7 +44,6 @@ from jax._src.interpreters.pxla import (
|
||||
SPMDBatchTrace as SPMDBatchTrace,
|
||||
ShardInfo as ShardInfo,
|
||||
ShardedAxis as ShardedAxis,
|
||||
ShardedDeviceArray as ShardedDeviceArray,
|
||||
ShardingSpec as ShardingSpec,
|
||||
TileManual as TileManual,
|
||||
TileVectorize as TileVectorize,
|
||||
@ -122,25 +121,17 @@ from jax._src.mesh import (
|
||||
from jax._src.mesh import Mesh as _deprecated_Mesh
|
||||
from jax._src.interpreters.pxla import (
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
|
||||
make_sharded_device_array as _deprecated_make_sharded_device_array,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.mesh import Mesh as Mesh
|
||||
from jax._src.interpreters.pxla import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
device_put as device_put,
|
||||
)
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
# Added Feb 8, 2023:
|
||||
"Mesh": (
|
||||
"jax.interpreters.pxla.Mesh is deprecated. Use jax.sharding.Mesh.",
|
||||
_deprecated_Mesh,
|
||||
),
|
||||
# Added Feb 8, 2023:
|
||||
"PartitionSpec": (
|
||||
(
|
||||
"jax.interpreters.pxla.PartitionSpec is deprecated. Use "
|
||||
@ -148,6 +139,14 @@ _deprecations = {
|
||||
),
|
||||
_deprecated_PartitionSpec,
|
||||
),
|
||||
# Added March 15, 2023:
|
||||
"ShardedDeviceArray": (
|
||||
(
|
||||
"jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use "
|
||||
"jax.Array."
|
||||
),
|
||||
_deprecated_ShardedDeviceArray,
|
||||
),
|
||||
# make_sharded_device_array is deprecated as of March 3, 2023. jax.Array
|
||||
# is the default since November 2022.
|
||||
"make_sharded_device_array": (
|
||||
@ -166,6 +165,17 @@ _deprecations = {
|
||||
),
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.mesh import Mesh as Mesh
|
||||
from jax._src.interpreters.pxla import (
|
||||
PartitionSpec as PartitionSpec,
|
||||
ShardedDeviceArray as ShardedDeviceArray,
|
||||
device_put as device_put,
|
||||
make_sharded_device_array as make_sharded_device_array,
|
||||
)
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
@ -45,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src import array
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
|
Loading…
x
Reference in New Issue
Block a user