diff --git a/CHANGELOG.md b/CHANGELOG.md
index 77e9d6b54..72afc06c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
 * Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
   require inputs to be array-like: i.e. lists and tuples cannot be used in
   place of arrays. Part of {jax-issue}`#7737`.
+* Deprecations
+  * `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
+    The `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
 
 ## jaxlib 0.3.24
 * Changes
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index f88803ae7..cfc748ada 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -194,7 +194,7 @@ class MeshPspecSharding(XLACompatibleSharding):
     _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
 
   def __repr__(self):
-    return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
+    return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
@@ -253,6 +253,10 @@ class MeshPspecSharding(XLACompatibleSharding):
     return sharding_spec.sharding_proto(special_axes=special_axes)
 
 
+# New name of MeshPspecSharding to match PositionalSharding below.
+NamedSharding = MeshPspecSharding
+
+
 @functools.lru_cache()
 def _get_replicated_op_sharding():
   proto = xc.OpSharding()
@@ -383,7 +387,7 @@ class PmapSharding(XLACompatibleSharding):
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
-class ReshapeableDevicesSharding(XLACompatibleSharding):
+class PositionalSharding(XLACompatibleSharding):
   _devices: List[xc.Device]
   _ids: np.ndarray  # dtype DeviceIdSet
 
@@ -423,8 +427,8 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
     return self.remake(self._devices, new_ids)
 
   @classmethod
-  def remake(cls, devices: List[xc.Device], ids: np.ndarray
-             ) -> ReshapeableDevicesSharding:
+  def remake(
+      cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
     self = cls.__new__(cls)
     self._devices = devices
     self._ids = ids
@@ -436,7 +440,7 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
     return id(self._devices)
 
   def __eq__(self, other) -> bool:
-    return (isinstance(other, ReshapeableDevicesSharding) and
+    return (isinstance(other, PositionalSharding) and
             id(self._devices) == id(other._devices) and
             bool(np.all(self._ids == other._ids)))
 
diff --git a/jax/sharding.py b/jax/sharding.py
index bdf5b6c25..7a120edfd 100644
--- a/jax/sharding.py
+++ b/jax/sharding.py
@@ -15,9 +15,12 @@
 from jax._src.sharding import (
   Sharding as Sharding,
   XLACompatibleSharding as XLACompatibleSharding,
+  # TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
   MeshPspecSharding as MeshPspecSharding,
+  # New name of MeshPspecSharding to match PositionalSharding below.
+  NamedSharding as NamedSharding,
   SingleDeviceSharding as SingleDeviceSharding,
   PmapSharding as PmapSharding,
   OpShardingSharding as OpShardingSharding,
-  ReshapeableDevicesSharding as ReshapeableDevicesSharding,
+  PositionalSharding as PositionalSharding,
 )
diff --git a/tests/api_test.py b/tests/api_test.py
index 180edabd4..a901c6639 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -1577,8 +1577,8 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "device_put device specification must be a tree prefix of the "
-        r"corresponding value, got specification \(\(MeshPspecSharding\(.*\), "
-        r"MeshPspecSharding\(.*\)\), MeshPspecSharding\(.*\)\) for value tree "
+        r"corresponding value, got specification \(\(NamedSharding\(.*\), "
+        r"NamedSharding\(.*\)\), NamedSharding\(.*\)\) for value tree "
         r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
     ):
       jax.device_put((x, (y, z)), device=((s1, s2), s2))
@@ -1599,8 +1599,8 @@ class APITest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
         "device_put device specification must be a tree prefix of the "
-        r"corresponding value, got specification \(MeshPspecSharding\(.*\), "
-        r"MeshPspecSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
+        r"corresponding value, got specification \(NamedSharding\(.*\), "
+        r"NamedSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
     ):
       jax.device_put((x, y, z), device=(s1, s2))
diff --git a/tests/array_test.py b/tests/array_test.py
index 48c5594cf..d1101b703 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -692,7 +692,7 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(
         ValueError,
-        r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
+        r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
         r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 2"):
       new_mps.is_compatible_aval(shape)
@@ -733,9 +733,9 @@ class ShardingTest(jtu.JaxTestCase):
     value_shape = (8, 4)
 
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    mps = sharding.MeshPspecSharding(mesh, pspec)
+    mps = sharding.NamedSharding(mesh, pspec)
 
-    devices_sharding = sharding.ReshapeableDevicesSharding(jax.devices())
+    devices_sharding = sharding.PositionalSharding(jax.devices())
     devices_sharding = devices_sharding.reshape(shape).replicate(axes)
     if transpose:
       devices_sharding = devices_sharding.T
@@ -753,7 +753,7 @@ class ShardingTest(jtu.JaxTestCase):
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
 
-    devices_sharding = sharding.ReshapeableDevicesSharding(mesh.devices)
+    devices_sharding = sharding.PositionalSharding(mesh.devices)
 
     op1 = mps._to_xla_op_sharding(len(value_shape))
     op2 = devices_sharding._to_xla_op_sharding(len(value_shape))
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 4374ed9b3..aee9dbd12 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1084,7 +1084,7 @@ class PJitTest(jtu.BufferDonationTestCase):
     with self.assertRaisesRegex(
         ValueError,
         r"One of with_sharding_constraint.*Sharding "
-        r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
+        r"NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
         r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
         "valid for values of rank at least 4, but was applied to a value of rank 1"):
       pjit_f(jnp.array([1, 2, 3]))
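
For reference, here is a minimal sketch of the two renamed APIs as exercised by the tests above. It is illustrative only: it assumes a runtime exposing 8 devices (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU), and it imports `Mesh` and `PartitionSpec` from `jax.sharding`, which is where later releases export them; at the time of this patch they lived under `jax.experimental`.

```python
# Illustrative sketch only; see the caveats above.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

# NamedSharding (formerly MeshPspecSharding): a mesh of named axes plus a
# PartitionSpec mapping array dimensions onto those axes.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
named = NamedSharding(mesh, PartitionSpec('x', 'y'))

# PositionalSharding (formerly ReshapeableDevicesSharding): the same devices
# addressed by position; reshape/replicate/T mirror the array_test usage.
positional = PositionalSharding(jax.devices()).reshape(4, 2)
replicated = positional.replicate(1)  # replicate over the second position axis
transposed = positional.T

# Either kind of sharding can be passed to device_put.
x = jax.device_put(np.arange(32.0).reshape(8, 4), named)
y = jax.device_put(np.arange(32.0).reshape(8, 4), positional)
```

Note that `NamedSharding` is a plain alias (`NamedSharding = MeshPspecSharding`), so existing `isinstance` checks against `MeshPspecSharding` keep working during the deprecation window.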