Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16)
Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.
PiperOrigin-RevId: 662938823
commit 229cbae5ea (parent df2e9c3836)
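The motivating case: a NamedSharding built from an AbstractMesh knows how many devices it spans (mesh.size) but cannot enumerate concrete devices, so call sites that relied on len(sharding.device_set) could not handle abstract meshes, while sharding.num_devices works for both. Below is a minimal sketch of the distinction (illustrative, not part of the commit; it assumes at least one local device and that device_set remains unavailable on abstract meshes at this revision):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

# Concrete mesh: both properties agree.
concrete = NamedSharding(mesh, P())
assert concrete.num_devices == len(concrete.device_set)

# Abstract mesh: only the shape and axis names survive, so num_devices
# (backed by mesh.size) is the only safe way to count devices.
abstract = NamedSharding(mesh.abstract_mesh, P())
assert abstract.num_devices == mesh.size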
@@ -489,7 +489,7 @@ class ArrayImpl(basearray.Array):
     """Returns the total global on-device size of the array in bytes."""
     arr = self._arrays[0]
     per_shard_size = arr.on_device_size_in_bytes()
-    return per_shard_size * len(self.sharding.device_set)
+    return per_shard_size * self.sharding.num_devices

   def devices(self) -> set[Device]:
     self._check_if_deleted()
@@ -418,7 +418,7 @@ def _device_put_sharding_impl(x, aval, device):
     return _different_device_order_reshard(x, s)

   if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and
-      x.is_fully_addressable and len(s.device_set) > 1 and
+      x.is_fully_addressable and s.num_devices > 1 and
       s._internal_device_list != x.sharding._internal_device_list and  # pytype: disable=attribute-error
       s.device_set == x.sharding.device_set):
     assert isinstance(s, Sharding)
@@ -144,6 +144,11 @@ class Sharding:
     """
     raise NotImplementedError('Subclasses should implement this method.')

+  @property
+  def num_devices(self) -> int:
+    """Number of devices that the sharding contains."""
+    raise NotImplementedError('Subclasses should implement this method.')
+
   @property
   def memory_kind(self) -> str | None:
     """Returns the memory kind of the sharding."""
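Because the base-class property only raises, every concrete sharding below overrides it; the payoff is that callers can count devices uniformly without materializing device_set. A quick sanity check, sketched under the assumption of at least one local device:

import jax
from jax.sharding import PositionalSharding, SingleDeviceSharding

dev = jax.devices()[0]
# Every builtin sharding answers num_devices directly.
assert SingleDeviceSharding(dev).num_devices == 1
assert PositionalSharding([dev]).num_devices == 1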
@@ -257,6 +257,10 @@ class NamedSharding(sharding.Sharding):
         memory_kind=memory_kind, _parsed_pspec=parsed_pspec,
         _manual_axes=_manual_axes)

+  @property
+  def num_devices(self) -> int:
+    return self.mesh.size
+
   @property
   def device_set(self) -> set[Device]:
     if isinstance(self.mesh, mesh_lib.AbstractMesh):
@@ -366,6 +370,10 @@ class SingleDeviceSharding(sharding.Sharding):
     return (self._device == other._device and
             self.memory_kind == other.memory_kind)

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @property
   def device_set(self) -> set[Device]:
     return {self._device}
@@ -501,6 +509,10 @@ class PmapSharding(sharding.Sharding):
     pmap_devices = np.array(devices)
     return cls(pmap_devices, sharding_spec)

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[Device]:
     return set(self.devices.flat)
@@ -707,6 +719,10 @@ class PositionalSharding(sharding.Sharding):

   # Sharding interface

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[xc.Device]:
     return set(self._devices)
@@ -826,6 +842,10 @@ class GSPMDSharding(sharding.Sharding):
           f"{len(num_ways_dim_sharded)}, but was applied to a value of rank "
           f"{len(aval_shape)}")

+  @property
+  def num_devices(self) -> int:
+    return len(self.device_set)
+
   @functools.cached_property
   def device_set(self) -> set[Device]:
     return set(self._devices)
@@ -1405,12 +1425,12 @@ def get_process_index_and_count(
   if (tensor_sharding.is_fully_addressable or
       tensor_sharding.is_fully_replicated):
     return (0, 1)
-  num_devices = len(tensor_sharding.device_set)
   # Get device to indices map, we don't care about the concrete
   # global shape here, only to get the distribution of shards across the tensor
   # using (num_devices, num_devices, ...) This is a universal shape that is
   # compatible with any mesh with num_devices.
-  device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims)
+  device_map = tensor_sharding.devices_indices_map(
+      (tensor_sharding.num_devices,) * ndims)

   # Get the slices for 'dim' for all devices.
   global_slice = {k: v[dim] for k, v in device_map.items()}
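The "universal shape" reasoning in the comment above: with a global shape of (num_devices,) * ndims, every sharded axis can be split into per-device unit slices, so the distribution of shards is recoverable without knowing the real array shape. An illustrative sketch, assuming four local devices (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=4):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
s = NamedSharding(mesh, P('x'))

# With a length-4 axis and 4 devices, each device owns one unit slice,
# exposing the shard layout independently of any concrete global shape.
for dev, idx in s.devices_indices_map((s.num_devices,)).items():
  print(dev.id, idx)  # e.g. 0 (slice(0, 1),), 1 (slice(1, 2),), ...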
@@ -1564,7 +1584,7 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
 def is_single_device_sharding(sharding: sharding.Sharding) -> bool:
   # Special case PmapSharding here because PmapSharding maps away an axis
   # and needs to be handled separately.
-  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
+  return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding)

 def make_key_array_phys_sharding(aval, sharding):
   if is_single_device_sharding(sharding):
@@ -35,6 +35,7 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import core
+from jax._src import prng
 from jax._src import test_util as jtu
 from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
 from jax._src.ad_checkpoint import saved_residuals
@@ -1756,6 +1757,22 @@ class ShardMapTest(jtu.JaxTestCase):
     )
     self.assertAllClose(v*v, f(v), check_dtypes=False)

+  def test_sharded_prng_with_abstract_mesh(self):
+    shape = (8, 2, 2)
+    mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
+
+    np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape)
+    key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl)
+    key = jax.device_put(key, NamedSharding(mesh, P()))
+
+    @jax.jit
+    def shard_key(key):
+      return shard_map(
+          lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key)
+
+    out = shard_key(key)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P()))
+
   def test_partial_auto_error_wsc_manual(self):
     mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))