Add a default PmapSharding option which matches exactly pmap's device placement.

PiperOrigin-RevId: 484289013
This commit is contained in:
Yash Katariya 2022-10-27 10:27:46 -07:00 committed by jax authors
parent 978dcde8d6
commit 9f80402845
2 changed files with 46 additions and 0 deletions

View File

@ -20,6 +20,7 @@ import operator as op
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import jax
from jax._src.util import safe_map, safe_zip
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
@ -329,6 +330,33 @@ class PmapSharding(XLACompatibleSharding):
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
f'devices={self.devices})')
# TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0) -> PmapSharding:
"""Creates a `PmapSharding` which matches the implicit device order used by
`pmap`.
Args:
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
"""
# The dtype doesn't matter here. Its only used for creating the
# sharding_spec.
aval = jax.ShapedArray(shape, np.int32)
sharding_spec = pxla._create_pmap_sharding_spec(aval, sharded_dim)
num_ways_sharded = None
for s in sharding_spec.sharding:
if isinstance(s, pxla.Unstacked):
num_ways_sharded = s.size
if num_ways_sharded is None:
raise NotImplementedError(
'`None` to sharded_dim is not supported. Please file a jax '
'issue if you need this feature.')
pmap_devices = jax.local_devices()[:num_ways_sharded]
return cls(pmap_devices, sharding_spec)
@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
return set(self.devices.flat)

View File

@ -761,6 +761,24 @@ class ShardingTest(jtu.JaxTestCase):
str(out.sharding) # doesn't crash
repr(out.sharding) # doesn't crash
@parameterized.named_parameters(
('sharded_dim_0', (4, 2), 0),
('sharded_dim_1_0', (4, 2), 1),
('sharded_dim_2', (4, 2, 4), 2),
('sharded_dim_1_1', (2, 4), 1)
)
def test_default_pmap_sharding(self, shape, sharded_dim):
if jax.device_count() < 4:
self.skipTest('Test needs >= 4 devices.')
ps = sharding.PmapSharding.default(shape, sharded_dim)
inp = jnp.arange(np.prod(shape)).reshape(shape)
compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings
self.assertEqual(ps._device_assignment, pmap_in_sharding._device_assignment)
self.assertEqual(ps.sharding_spec, pmap_in_sharding.sharding_spec)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())