Delete OpShardingSharding export since it has been 3 months since it was deprecated. Also remove deprecation warnings for MeshPspecSharding.

PiperOrigin-RevId: 538880293
This commit is contained in:
Yash Katariya 2023-06-08 13:45:43 -07:00 committed by jax authors
parent 03b06b6aa4
commit 14451492c9
2 changed files with 2 additions and 30 deletions

View File

@ -22,6 +22,8 @@ Remember to align the itemized text with the first line of an item within a list
of `numpy.product` in NumPy version 1.25.0.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
* `jax.sharding.OpShardingSharding` has been removed since it has been 3
months since it was deprecated.
## jaxlib 0.4.12

View File

@ -22,39 +22,9 @@ from jax._src.sharding_impls import (
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
GSPMDSharding as GSPMDSharding,
# TODO(yashkatariya): Remove OpShardingSharding in 3 months from
# Feb 17, 2023.
GSPMDSharding as _deprecated_OpShardingSharding,
PositionalSharding as PositionalSharding,
)
from jax._src.partition_spec import (
PartitionSpec as PartitionSpec,
)
from jax._src.interpreters.pxla import Mesh as Mesh
_deprecations = {
"OpShardingSharding": (
(
"jax.sharding.OpShardingSharding is deprecated. Please use"
" jax.sharding.GSPMDSharding."
),
_deprecated_OpShardingSharding,
),
"MeshPspecSharding": (
(
"jax.sharding.MeshPspecSharding has been removed. Please use"
" jax.sharding.NamedSharding."
),
None,
),
}
import typing
if typing.TYPE_CHECKING:
from jax._src.sharding_impls import GSPMDSharding as OpShardingSharding
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing