Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.

PiperOrigin-RevId: 662938823
Yash Katariya 2024-08-14 09:02:20 -07:00 committed by jax authors
parent df2e9c3836
commit 229cbae5ea
5 changed files with 47 additions and 5 deletions
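
Illustrative sketch (not part of the commit): how the new num_devices property is used in place of len(sharding.device_set). The mesh construction below is only an example; any concrete NamedSharding behaves the same way.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A concrete 1-D mesh over whatever local devices are available.
mesh = Mesh(np.array(jax.devices()), ('x',))
s = NamedSharding(mesh, P('x'))

# num_devices reports how many devices the sharding spans; for a concrete
# mesh it matches the older len(s.device_set) spelling this commit replaces.
assert s.num_devices == mesh.size == len(s.device_set)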


@@ -489,7 +489,7 @@ class ArrayImpl(basearray.Array):
     """Returns the total global on-device size of the array in bytes."""
     arr = self._arrays[0]
     per_shard_size = arr.on_device_size_in_bytes()
-    return per_shard_size * len(self.sharding.device_set)
+    return per_shard_size * self.sharding.num_devices

   def devices(self) -> set[Device]:
     self._check_if_deleted()


@@ -418,7 +418,7 @@ def _device_put_sharding_impl(x, aval, device):
       return _different_device_order_reshard(x, s)

     if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and
-        x.is_fully_addressable and len(s.device_set) > 1 and
+        x.is_fully_addressable and s.num_devices > 1 and
         s._internal_device_list != x.sharding._internal_device_list and  # pytype: disable=attribute-error
         s.device_set == x.sharding.device_set):
       assert isinstance(s, Sharding)


@@ -144,6 +144,11 @@ class Sharding:
     """
     raise NotImplementedError('Subclasses should implement this method.')

+  @property
+  def num_devices(self) -> int:
+    """Number of devices that the sharding contains."""
+    raise NotImplementedError('Subclasses should implement this method.')
+
   @property
   def memory_kind(self) -> str | None:
     """Returns the memory kind of the sharding."""


@@ -257,6 +257,10 @@ class NamedSharding(sharding.Sharding):
                memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
                _manual_axes=_manual_axes)

+  @property
+  def num_devices(self) -> int:
+    return self.mesh.size
+
   @property
   def device_set(self) -> set[Device]:
     if isinstance(self.mesh, mesh_lib.AbstractMesh):
@@ -366,6 +370,10 @@ class SingleDeviceSharding(sharding.Sharding):
     return (self._device == other._device and
             self.memory_kind == other.memory_kind)

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @property
   def device_set(self) -> set[Device]:
     return {self._device}
@@ -501,6 +509,10 @@ class PmapSharding(sharding.Sharding):
     pmap_devices = np.array(devices)
     return cls(pmap_devices, sharding_spec)

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[Device]:
     return set(self.devices.flat)
@@ -707,6 +719,10 @@ class PositionalSharding(sharding.Sharding):
   # Sharding interface

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[xc.Device]:
     return set(self._devices)
@@ -826,6 +842,10 @@ class GSPMDSharding(sharding.Sharding):
           f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
           f"{len(aval_shape)}")

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[Device]:
     return set(self._devices)
@@ -1405,12 +1425,12 @@ def get_process_index_and_count(
   if (tensor_sharding.is_fully_addressable or
       tensor_sharding.is_fully_replicated):
     return (0, 1)
-  num_devices = len(tensor_sharding.device_set)
   # Get device to indices map, we don't care about the concrete
   # global shape here, only to get the distribution of shards across the tensor
   # using (num_devices, num_devices, ...) This is a universal shape that is
   # compatible with any mesh with num_devices.
-  device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims)
+  device_map = tensor_sharding.devices_indices_map(
+      (tensor_sharding.num_devices,) * ndims)

   # Get the slices for 'dim' for all devices.
   global_slice = {k: v[dim] for k, v in device_map.items()}
@@ -1564,7 +1584,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
 def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.test_pjit_single_device_sharding_add
-  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
+  return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)

 def make_key_array_phys_sharding(aval, sharding):
   if is_single_device_sharding(sharding):
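
Aside (not part of the diff): NamedSharding.num_devices returns self.mesh.size rather than len(self.device_set) because an AbstractMesh knows its axis names and sizes but holds no concrete devices, so device_set cannot be built for it. A minimal sketch of the distinction, assuming the NamedSharding constructor accepts an abstract mesh, as the isinstance(self.mesh, mesh_lib.AbstractMesh) check above implies; mesh.abstract_mesh is the same helper the new test below uses.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
abstract_s = NamedSharding(mesh.abstract_mesh, P('x'))  # sharding over an AbstractMesh

# The abstract mesh still knows how many devices it stands for, so the new
# property works without any concrete devices attached ...
print(abstract_s.num_devices)  # equals mesh.size

# ... whereas abstract_s.device_set needs real devices and is expected to be
# rejected when the mesh is abstract (see the isinstance check in the diff).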


@@ -35,6 +35,7 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import core
+from jax._src import prng
 from jax._src import test_util as jtu
 from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
 from jax._src.ad_checkpoint import saved_residuals
@@ -1756,6 +1757,22 @@ class ShardMapTest(jtu.JaxTestCase):
     )
     self.assertAllClose(v*v, f(v), check_dtypes=False)

+  def test_sharded_prng_with_abstract_mesh(self):
+    shape = (8, 2, 2)
+    mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
+    np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape)
+    key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl)
+    key = jax.device_put(key, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def shard_key(key):
+      return shard_map(
+          lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key)
+
+    out = shard_key(key)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P()))
+
   def test_partial_auto_error_wsc_manual(self):
     mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))