mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Rename ReshapeableDevicesSharding
to PositionalSharding
and add an alias NamedSharding
for MeshPspecSharding
.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months. PiperOrigin-RevId: 485753078
This commit is contained in:
parent
5448ea6c10
commit
cc5af7ed98
@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
|
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
|
||||||
require inputs to be array-like: i.e. lists and tuples cannot be used in place
|
require inputs to be array-like: i.e. lists and tuples cannot be used in place
|
||||||
of arrays. Part of {jax-issue}`#7737`.
|
of arrays. Part of {jax-issue}`#7737`.
|
||||||
|
* Deprecations
|
||||||
|
* `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
|
||||||
|
`jax.sharding.MeshPspecSharding` name will be removed in 3 months.
|
||||||
|
|
||||||
## jaxlib 0.3.24
|
## jaxlib 0.3.24
|
||||||
* Changes
|
* Changes
|
||||||
|
@ -194,7 +194,7 @@ class MeshPspecSharding(XLACompatibleSharding):
|
|||||||
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
|
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'MeshPspecSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
|
return f'NamedSharding(mesh={dict(self.mesh.shape)}, partition_spec={self.spec})'
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
if not hasattr(self, '_hash'):
|
if not hasattr(self, '_hash'):
|
||||||
@ -253,6 +253,10 @@ class MeshPspecSharding(XLACompatibleSharding):
|
|||||||
return sharding_spec.sharding_proto(special_axes=special_axes)
|
return sharding_spec.sharding_proto(special_axes=special_axes)
|
||||||
|
|
||||||
|
|
||||||
|
# New name of MeshPspecSharding to match with PositionalSharding below.
|
||||||
|
NamedSharding = MeshPspecSharding
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache()
|
@functools.lru_cache()
|
||||||
def _get_replicated_op_sharding():
|
def _get_replicated_op_sharding():
|
||||||
proto = xc.OpSharding()
|
proto = xc.OpSharding()
|
||||||
@ -383,7 +387,7 @@ class PmapSharding(XLACompatibleSharding):
|
|||||||
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
|
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
|
||||||
|
|
||||||
|
|
||||||
class ReshapeableDevicesSharding(XLACompatibleSharding):
|
class PositionalSharding(XLACompatibleSharding):
|
||||||
_devices: List[xc.Device]
|
_devices: List[xc.Device]
|
||||||
_ids: np.ndarray # dtype DeviceIdSet
|
_ids: np.ndarray # dtype DeviceIdSet
|
||||||
|
|
||||||
@ -423,8 +427,8 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
|
|||||||
return self.remake(self._devices, new_ids)
|
return self.remake(self._devices, new_ids)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def remake(cls, devices: List[xc.Device], ids: np.ndarray
|
def remake(
|
||||||
) -> ReshapeableDevicesSharding:
|
cls, devices: List[xc.Device], ids: np.ndarray) -> PositionalSharding:
|
||||||
self = cls.__new__(cls)
|
self = cls.__new__(cls)
|
||||||
self._devices = devices
|
self._devices = devices
|
||||||
self._ids = ids
|
self._ids = ids
|
||||||
@ -436,7 +440,7 @@ class ReshapeableDevicesSharding(XLACompatibleSharding):
|
|||||||
return id(self._devices)
|
return id(self._devices)
|
||||||
|
|
||||||
def __eq__(self, other) -> bool:
|
def __eq__(self, other) -> bool:
|
||||||
return (isinstance(other, ReshapeableDevicesSharding) and
|
return (isinstance(other, PositionalSharding) and
|
||||||
id(self._devices) == id(other._devices) and
|
id(self._devices) == id(other._devices) and
|
||||||
bool(np.all(self._ids == other._ids)))
|
bool(np.all(self._ids == other._ids)))
|
||||||
|
|
||||||
|
@ -15,9 +15,12 @@
|
|||||||
from jax._src.sharding import (
|
from jax._src.sharding import (
|
||||||
Sharding as Sharding,
|
Sharding as Sharding,
|
||||||
XLACompatibleSharding as XLACompatibleSharding,
|
XLACompatibleSharding as XLACompatibleSharding,
|
||||||
|
# TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
|
||||||
MeshPspecSharding as MeshPspecSharding,
|
MeshPspecSharding as MeshPspecSharding,
|
||||||
|
# New name of MeshPspecSharding to match PositionalSharding below.
|
||||||
|
NamedSharding as NamedSharding,
|
||||||
SingleDeviceSharding as SingleDeviceSharding,
|
SingleDeviceSharding as SingleDeviceSharding,
|
||||||
PmapSharding as PmapSharding,
|
PmapSharding as PmapSharding,
|
||||||
OpShardingSharding as OpShardingSharding,
|
OpShardingSharding as OpShardingSharding,
|
||||||
ReshapeableDevicesSharding as ReshapeableDevicesSharding,
|
PositionalSharding as PositionalSharding,
|
||||||
)
|
)
|
||||||
|
@ -1577,8 +1577,8 @@ class APITest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
"device_put device specification must be a tree prefix of the "
|
"device_put device specification must be a tree prefix of the "
|
||||||
r"corresponding value, got specification \(\(MeshPspecSharding\(.*\), "
|
r"corresponding value, got specification \(\(NamedSharding\(.*\), "
|
||||||
r"MeshPspecSharding\(.*\)\), MeshPspecSharding\(.*\)\) for value tree "
|
r"NamedSharding\(.*\)\), NamedSharding\(.*\)\) for value tree "
|
||||||
r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
|
r"PyTreeDef\(\(\*, \(\*, \*\)\)\)."
|
||||||
):
|
):
|
||||||
jax.device_put((x, (y, z)), device=((s1, s2), s2))
|
jax.device_put((x, (y, z)), device=((s1, s2), s2))
|
||||||
@ -1599,8 +1599,8 @@ class APITest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
"device_put device specification must be a tree prefix of the "
|
"device_put device specification must be a tree prefix of the "
|
||||||
r"corresponding value, got specification \(MeshPspecSharding\(.*\), "
|
r"corresponding value, got specification \(NamedSharding\(.*\), "
|
||||||
r"MeshPspecSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
|
r"NamedSharding\(.*\)\) for value tree PyTreeDef\(\(\*, \*, \*\)\)."
|
||||||
):
|
):
|
||||||
jax.device_put((x, y, z), device=(s1, s2))
|
jax.device_put((x, y, z), device=(s1, s2))
|
||||||
|
|
||||||
|
@ -692,7 +692,7 @@ class ShardingTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
r"Sharding MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
r"Sharding NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||||
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
||||||
"valid for values of rank at least 4, but was applied to a value of rank 2"):
|
"valid for values of rank at least 4, but was applied to a value of rank 2"):
|
||||||
new_mps.is_compatible_aval(shape)
|
new_mps.is_compatible_aval(shape)
|
||||||
@ -733,9 +733,9 @@ class ShardingTest(jtu.JaxTestCase):
|
|||||||
value_shape = (8, 4)
|
value_shape = (8, 4)
|
||||||
|
|
||||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||||
mps = sharding.MeshPspecSharding(mesh, pspec)
|
mps = sharding.NamedSharding(mesh, pspec)
|
||||||
|
|
||||||
devices_sharding = sharding.ReshapeableDevicesSharding(jax.devices())
|
devices_sharding = sharding.PositionalSharding(jax.devices())
|
||||||
devices_sharding = devices_sharding.reshape(shape).replicate(axes)
|
devices_sharding = devices_sharding.reshape(shape).replicate(axes)
|
||||||
if transpose:
|
if transpose:
|
||||||
devices_sharding = devices_sharding.T
|
devices_sharding = devices_sharding.T
|
||||||
@ -753,7 +753,7 @@ class ShardingTest(jtu.JaxTestCase):
|
|||||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||||
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
|
mps = sharding.MeshPspecSharding(mesh, P('x', 'y'))
|
||||||
|
|
||||||
devices_sharding = sharding.ReshapeableDevicesSharding(mesh.devices)
|
devices_sharding = sharding.PositionalSharding(mesh.devices)
|
||||||
|
|
||||||
op1 = mps._to_xla_op_sharding(len(value_shape))
|
op1 = mps._to_xla_op_sharding(len(value_shape))
|
||||||
op2 = devices_sharding._to_xla_op_sharding(len(value_shape))
|
op2 = devices_sharding._to_xla_op_sharding(len(value_shape))
|
||||||
|
@ -1084,7 +1084,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
r"One of with_sharding_constraint.*Sharding "
|
r"One of with_sharding_constraint.*Sharding "
|
||||||
r"MeshPspecSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
r"NamedSharding\(mesh={'replica': 1, 'data': 1, 'mdl': 2}, "
|
||||||
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
r"partition_spec=PartitionSpec\(None, \('mdl',\), None, None\)\) is only "
|
||||||
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
"valid for values of rank at least 4, but was applied to a value of rank 1"):
|
||||||
pjit_f(jnp.array([1, 2, 3]))
|
pjit_f(jnp.array([1, 2, 3]))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user