Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.

`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
This commit is contained in:
Yash Katariya 2022-11-02 19:12:32 -07:00 committed by jax authors
parent 5448ea6c10
commit cc5af7ed98
6 changed files with 25 additions and 15 deletions

View File

@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
require inputs to be array-like: i.e. lists and tuples cannot be used in place
of arrays. Part of {jax-issue}`#7737`.
* Deprecations
* `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
`jax.sharding.MeshPspecSharding` name will be removed in 3 months.
## jaxlib 0.3.24
* Changes

View File

@ -194,7 +194,7 @@ class MeshPspecSharding(XLACompatibleSharding):
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
def __repr__(self):
return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
def __hash__(self):
if not hasattr(self, '_hash'):
@ -253,6 +253,10 @@ class MeshPspecSharding(XLACompatibleSharding):
return sharding_spec.sharding_proto(special_axes=special_axes)
# New name of MeshPspecSharding to match with PositionalSharding below.
NamedSharding = MeshPspecSharding
@functools.lru_cache()
def _get_replicated_op_sharding():
proto = xc.OpSharding()
@ -383,7 +387,7 @@ class PmapSharding(XLACompatibleSharding):
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
class ReshapeableDevicesSharding(XLACompatibleSharding):
class PositionalSharding(XLACompatibleSharding):
_devices: List[xc.Device]
_ids: np.ndarray # dtype DeviceIdSet
@ -423,8 +427,8 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
return self.remake(self._devices, new_ids)
@classmethod
def remake(cls, devices: List[xc.Device], ids: np.ndarray
) -> ReshapeableDevicesSharding:
def remake(
cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
self = cls.__new__(cls)
self._devices = devices
self._ids = ids
@ -436,7 +440,7 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
return id(self._devices)
def __eq__(self, other) -> bool:
return (isinstance(other, ReshapeableDevicesSharding) and
return (isinstance(other, PositionalSharding) and
id(self._devices) == id(other._devices) and
bool(np.all(self._ids == other._ids)))

View File

@ -15,9 +15,12 @@
from jax._src.sharding import (
Sharding as Sharding,
XLACompatibleSharding as XLACompatibleSharding,
# TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
MeshPspecSharding as MeshPspecSharding,
# New name of MeshPspecSharding to match PositionalSharding below.
NamedSharding as NamedSharding,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
OpShardingSharding as OpShardingSharding,
ReshapeableDevicesSharding as ReshapeableDevicesSharding,
PositionalSharding as PositionalSharding,
)

View File

@ -1577,8 +1577,8 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"device_put device specification must be a tree prefix of the "
r"corresponding value, got specification \(\(MeshPspecSharding\(.*\), "
r"MeshPspecSharding\(.*\)\), MeshPspecSharding\(.*\)\) for value tree "
r"corresponding value, got specification \(\(NamedSharding\(.*\), "
r"NamedSharding\(.*\)\), NamedSharding\(.*\)\) for value tree "
r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
):
jax.device_put((x, (y, z)), device=((s1, s2), s2))
@ -1599,8 +1599,8 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"device_put device specification must be a tree prefix of the "
r"corresponding value, got specification \(MeshPspecSharding\(.*\), "
r"MeshPspecSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
r"corresponding value, got specification \(NamedSharding\(.*\), "
r"NamedSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
):
jax.device_put((x, y, z), device=(s1, s2))

View File

@ -692,7 +692,7 @@ class ShardingTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
"valid for values of rank at least 4, but was applied to a value of rank 2"):
new_mps.is_compatible_aval(shape)
@ -733,9 +733,9 @@ class ShardingTest(jtu.JaxTestCase):
value_shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, pspec)
mps = sharding.NamedSharding(mesh, pspec)
devices_sharding = sharding.ReshapeableDevicesSharding(jax.devices())
devices_sharding = sharding.PositionalSharding(jax.devices())
devices_sharding = devices_sharding.reshape(shape).replicate(axes)
if transpose:
devices_sharding = devices_sharding.T
@ -753,7 +753,7 @@ class ShardingTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
devices_sharding = sharding.ReshapeableDevicesSharding(mesh.devices)
devices_sharding = sharding.PositionalSharding(mesh.devices)
op1 = mps._to_xla_op_sharding(len(value_shape))
op2 = devices_sharding._to_xla_op_sharding(len(value_shape))

View File

@ -1084,7 +1084,7 @@ class PJitTest(jtu.BufferDonationTestCase):
with self.assertRaisesRegex(
ValueError,
r"One of with_sharding_constraint.*Sharding "
r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
r"NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
"valid for values of rank at least 4, but was applied to a value of rank 1"):
pjit_f(jnp.array([1, 2, 3]))