Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.

The `MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
Authored by Yash Katariya on 2022-11-02 19:12:32 -07:00; committed by jax authors
parent 5448ea6c10
commit cc5af7ed98
6 changed files with 25 additions and 15 deletions

CHANGELOG.md

@@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
 * Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
   require inputs to be array-like: i.e. lists and tuples cannot be used in place
   of arrays. Part of {jax-issue}`#7737`.
+* Deprecations
+  * `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
+    `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
 
 ## jaxlib 0.3.24
 * Changes
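For readers migrating, a minimal sketch of the rename at a call site (the device count, mesh shape, and the `jax.sharding` import paths for `Mesh`/`PartitionSpec` are assumptions, not part of this diff):

# Hypothetical migration sketch; assumes >= 2 devices and that Mesh and
# PartitionSpec are importable from jax.sharding.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('x',))
old = jax.sharding.MeshPspecSharding(mesh, P('x'))  # deprecated spelling
new = NamedSharding(mesh, P('x'))                   # new spelling
assert type(old) is type(new)  # same class, two names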

jax/_src/sharding.py

@@ -194,7 +194,7 @@ class MeshPspecSharding(XLACompatibleSharding):
     _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
 
   def __repr__(self):
-    return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
+    return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
@@ -253,6 +253,10 @@ class MeshPspecSharding(XLACompatibleSharding):
     return sharding_spec.sharding_proto(special_axes=special_axes)
 
 
+# New name of MeshPspecSharding to match with PositionalSharding below.
+NamedSharding = MeshPspecSharding
+
+
 @functools.lru_cache()
 def _get_replicated_op_sharding():
   proto = xc.OpSharding()
@@ -383,7 +387,7 @@ class PmapSharding(XLACompatibleSharding):
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
-class ReshapeableDevicesSharding(XLACompatibleSharding):
+class PositionalSharding(XLACompatibleSharding):
 
   _devices: List[xc.Device]
   _ids: np.ndarray  # dtype DeviceIdSet
@@ -423,8 +427,8 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
     return self.remake(self._devices, new_ids)
 
   @classmethod
-  def remake(cls, devices: List[xc.Device], ids: np.ndarray
-             ) -> ReshapeableDevicesSharding:
+  def remake(
+      cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
     self = cls.__new__(cls)
     self._devices = devices
     self._ids = ids
@@ -436,7 +440,7 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
     return id(self._devices)
 
   def __eq__(self, other) -> bool:
-    return (isinstance(other, ReshapeableDevicesSharding) and
+    return (isinstance(other, PositionalSharding) and
             id(self._devices) == id(other._devices) and
             bool(np.all(self._ids == other._ids)))
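The `remake` classmethod above rebuilds a sharding around the same device list with a new id array, skipping `__init__` via `cls.__new__`. A standalone sketch of that pattern with a stub class (hypothetical, not JAX's actual DeviceIdSet machinery):

# Stub illustration of the remake pattern; GridSharding is hypothetical.
from typing import List
import numpy as np

class GridSharding:
  def __init__(self, devices: List[object]):
    self._devices = devices
    self._ids = np.arange(len(devices))

  @classmethod
  def remake(cls, devices: List[object], ids: np.ndarray) -> 'GridSharding':
    self = cls.__new__(cls)  # bypass __init__: caller supplies ids directly
    self._devices = devices
    self._ids = ids
    return self

  def reshape(self, *shape: int) -> 'GridSharding':
    return self.remake(self._devices, self._ids.reshape(shape))

s = GridSharding(list(range(8))).reshape(4, 2)
assert s._ids.shape == (4, 2)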

jax/sharding.py

@@ -15,9 +15,12 @@
 from jax._src.sharding import (
   Sharding as Sharding,
   XLACompatibleSharding as XLACompatibleSharding,
+  # TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
   MeshPspecSharding as MeshPspecSharding,
+  # New name of MeshPspecSharding to match PositionalSharding below.
+  NamedSharding as NamedSharding,
   SingleDeviceSharding as SingleDeviceSharding,
   PmapSharding as PmapSharding,
   OpShardingSharding as OpShardingSharding,
-  ReshapeableDevicesSharding as ReshapeableDevicesSharding,
+  PositionalSharding as PositionalSharding,
 )
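Both spellings now resolve to the same class object through the public module, so the deprecation is purely a rename; a quick check (valid for jax builds containing this change, before the eventual removal of the `MeshPspecSharding` name):

# Sanity check of the re-exports; assumes a jax build from this window.
import jax.sharding as shd

assert shd.NamedSharding is shd.MeshPspecSharding
assert shd.PositionalSharding.__name__ == 'PositionalSharding'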

tests/api_test.py

@@ -1577,8 +1577,8 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "device_put device specification must be a tree prefix of the "
-        r"corresponding value, got specification \(\(MeshPspecSharding\(.*\), "
-        r"MeshPspecSharding\(.*\)\), MeshPspecSharding\(.*\)\) for value tree "
+        r"corresponding value, got specification \(\(NamedSharding\(.*\), "
+        r"NamedSharding\(.*\)\), NamedSharding\(.*\)\) for value tree "
         r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
     ):
       jax.device_put((x, (y, z)), device=((s1, s2), s2))
@@ -1599,8 +1599,8 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "device_put device specification must be a tree prefix of the "
-        r"corresponding value, got specification \(MeshPspecSharding\(.*\), "
-        r"MeshPspecSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
+        r"corresponding value, got specification \(NamedSharding\(.*\), "
+        r"NamedSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
     ):
       jax.device_put((x, y, z), device=(s1, s2))
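These regex updates only track the new repr; the tree-prefix rule for `device_put` itself is unchanged. A small sketch of a specification that is a valid tree prefix (single-device shardings keep it runnable on one device; the variable names mirror the test):

# Sketch: (s1, s2) is a tree prefix of (x, (y, z)), so s2 covers both y and z.
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

x, y, z = jnp.ones(3), jnp.ones(3), jnp.ones(3)
s1 = s2 = SingleDeviceSharding(jax.devices()[0])
jax.device_put((x, (y, z)), device=(s1, s2))  # ok: prefix matches
# device=((s1, s2), s2) is not a prefix and raises the ValueError matched above.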

tests/array_test.py

@@ -692,7 +692,7 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
-        r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
+        r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
         r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 2"):
       new_mps.is_compatible_aval(shape)
@@ -733,9 +733,9 @@ class ShardingTest(jtu.JaxTestCase):
     value_shape = (8, 4)
 
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    mps = sharding.MeshPspecSharding(mesh, pspec)
+    mps = sharding.NamedSharding(mesh, pspec)
 
-    devices_sharding = sharding.ReshapeableDevicesSharding(jax.devices())
+    devices_sharding = sharding.PositionalSharding(jax.devices())
     devices_sharding = devices_sharding.reshape(shape).replicate(axes)
     if transpose:
       devices_sharding = devices_sharding.T
@@ -753,7 +753,7 @@ class ShardingTest(jtu.JaxTestCase):
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
-    devices_sharding = sharding.ReshapeableDevicesSharding(mesh.devices)
+    devices_sharding = sharding.PositionalSharding(mesh.devices)
 
     op1 = mps._to_xla_op_sharding(len(value_shape))
     op2 = devices_sharding._to_xla_op_sharding(len(value_shape))
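As these tests exercise, a `PositionalSharding` is manipulated like an array of device ids; a usage sketch (assumes 8 devices, e.g. a CPU run with 8 forced host devices):

# Usage sketch mirroring the test above; assumes len(jax.devices()) == 8.
import jax
from jax.sharding import PositionalSharding

s = PositionalSharding(jax.devices())  # flat sequence of 8 device ids
s = s.reshape(4, 2)                    # arrange the ids into a 4x2 grid
s = s.replicate(1)                     # replicate along the second grid axis
s = s.T                                # transpose, as in the transpose case
print(s.shape)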

tests/pjit_test.py

@@ -1084,7 +1084,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"One of with_sharding_constraint.*Sharding "
-        r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
+        r"NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
         r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 1"):
       pjit_f(jnp.array([1, 2, 3]))
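For reference, the sharding whose repr this test matches can be built directly (import paths as in the sketches above; the mesh uses the 2 devices implied by {'replica': 1, 'data': 1, 'mdl': 2}). Construction succeeds; the error only fires when the rank-4 spec meets a lower-rank value:

# Sketch of the rank rule behind the error above.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devs = np.array(jax.devices()[:2]).reshape(1, 1, 2)
mesh = Mesh(devs, ('replica', 'data', 'mdl'))
s = NamedSharding(mesh, P(None, ('mdl',), None, None))
print(s)  # NamedSharding(mesh={'replica': 1, 'data': 1, 'mdl': 2}, ...)
# Applying s to a rank-1 value such as jnp.array([1, 2, 3]) raises the
# "is only valid for values of rank at least 4" ValueError matched above.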