mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rename ReshapeableDevicesSharding
to PositionalSharding
and add an alias NamedSharding
for MeshPspecSharding
.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months. PiperOrigin-RevId: 485753078
This commit is contained in:
parent
5448ea6c10
commit
cc5af7ed98
@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
|
||||
require inputs to be array-like: i.e. lists and tuples cannot be used in place
|
||||
of arrays. Part of {jax-issue}`#7737`.
|
||||
* Deprecations
|
||||
* `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
|
||||
`jax.sharding.MeshPspecSharding` name will be removed in 3 months.
|
||||
|
||||
## jaxlib 0.3.24
|
||||
* Changes
|
||||
|
@ -194,7 +194,7 @@ class MeshPspecSharding(XLACompatibleSharding):
|
||||
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
|
||||
|
||||
def __repr__(self):
|
||||
return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
|
||||
return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
@ -253,6 +253,10 @@ class MeshPspecSharding(XLACompatibleSharding):
|
||||
return sharding_spec.sharding_proto(special_axes=special_axes)
|
||||
|
||||
|
||||
# New name of MeshPspecSharding to match with PositionalSharding below.
|
||||
NamedSharding = MeshPspecSharding
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_replicated_op_sharding():
|
||||
proto = xc.OpSharding()
|
||||
@ -383,7 +387,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
|
||||
|
||||
|
||||
class ReshapeableDevicesSharding(XLACompatibleSharding):
|
||||
class PositionalSharding(XLACompatibleSharding):
|
||||
_devices: List[xc.Device]
|
||||
_ids: np.ndarray # dtype DeviceIdSet
|
||||
|
||||
@ -423,8 +427,8 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
|
||||
return self.remake(self._devices, new_ids)
|
||||
|
||||
@classmethod
|
||||
def remake(cls, devices: List[xc.Device], ids: np.ndarray
|
||||
) -> ReshapeableDevicesSharding:
|
||||
def remake(
|
||||
cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
|
||||
self = cls.__new__(cls)
|
||||
self._devices = devices
|
||||
self._ids = ids
|
||||
@ -436,7 +440,7 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
|
||||
return id(self._devices)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return (isinstance(other, ReshapeableDevicesSharding) and
|
||||
return (isinstance(other, PositionalSharding) and
|
||||
id(self._devices) == id(other._devices) and
|
||||
bool(np.all(self._ids == other._ids)))
|
||||
|
||||
|
@ -15,9 +15,12 @@
|
||||
from jax._src.sharding import (
|
||||
Sharding as Sharding,
|
||||
XLACompatibleSharding as XLACompatibleSharding,
|
||||
# TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
|
||||
MeshPspecSharding as MeshPspecSharding,
|
||||
# New name of MeshPspecSharding to match PositionalSharding below.
|
||||
NamedSharding as NamedSharding,
|
||||
SingleDeviceSharding as SingleDeviceSharding,
|
||||
PmapSharding as PmapSharding,
|
||||
OpShardingSharding as OpShardingSharding,
|
||||
ReshapeableDevicesSharding as ReshapeableDevicesSharding,
|
||||
PositionalSharding as PositionalSharding,
|
||||
)
|
||||
|
@ -1577,8 +1577,8 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"device_put device specification must be a tree prefix of the "
|
||||
r"corresponding value, got specification \(\(MeshPspecSharding\(.*\), "
|
||||
r"MeshPspecSharding\(.*\)\), MeshPspecSharding\(.*\)\) for value tree "
|
||||
r"corresponding value, got specification \(\(NamedSharding\(.*\), "
|
||||
r"NamedSharding\(.*\)\), NamedSharding\(.*\)\) for value tree "
|
||||
r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
|
||||
):
|
||||
jax.device_put((x, (y, z)), device=((s1, s2), s2))
|
||||
@ -1599,8 +1599,8 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"device_put device specification must be a tree prefix of the "
|
||||
r"corresponding value, got specification \(MeshPspecSharding\(.*\), "
|
||||
r"MeshPspecSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
|
||||
r"corresponding value, got specification \(NamedSharding\(.*\), "
|
||||
r"NamedSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
|
||||
):
|
||||
jax.device_put((x, y, z), device=(s1, s2))
|
||||
|
||||
|
@ -692,7 +692,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||
r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
||||
"valid for values of rank at least 4, but was applied to a value of rank 2"):
|
||||
new_mps.is_compatible_aval(shape)
|
||||
@ -733,9 +733,9 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
value_shape = (8, 4)
|
||||
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mps = sharding.MeshPspecSharding(mesh, pspec)
|
||||
mps = sharding.NamedSharding(mesh, pspec)
|
||||
|
||||
devices_sharding = sharding.ReshapeableDevicesSharding(jax.devices())
|
||||
devices_sharding = sharding.PositionalSharding(jax.devices())
|
||||
devices_sharding = devices_sharding.reshape(shape).replicate(axes)
|
||||
if transpose:
|
||||
devices_sharding = devices_sharding.T
|
||||
@ -753,7 +753,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
|
||||
|
||||
devices_sharding = sharding.ReshapeableDevicesSharding(mesh.devices)
|
||||
devices_sharding = sharding.PositionalSharding(mesh.devices)
|
||||
|
||||
op1 = mps._to_xla_op_sharding(len(value_shape))
|
||||
op2 = devices_sharding._to_xla_op_sharding(len(value_shape))
|
||||
|
@ -1084,7 +1084,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"One of with_sharding_constraint.*Sharding "
|
||||
r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||
r"NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
||||
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
||||
pjit_f(jnp.array([1, 2, 3]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user