mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add a default PmapSharding option which exactly matches pmap's device placement.
PiperOrigin-RevId: 484289013
This commit is contained in:
parent
978dcde8d6
commit
9f80402845
@ -20,6 +20,7 @@ import operator as op
|
||||
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
|
||||
FrozenSet, Union, cast)
|
||||
|
||||
import jax
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -329,6 +330,33 @@ class PmapSharding(XLACompatibleSharding):
|
||||
return (f'PmapSharding(sharding_spec={self.sharding_spec}, '
|
||||
f'devices={self.devices})')
|
||||
|
||||
# TODO(yashkatariya): Expose `sharded_dim_size` in the API if required.
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0) -> PmapSharding:
  """Creates a `PmapSharding` that mirrors `pmap`'s implicit device order.

  Args:
    shape: The shape of the input array.
    sharded_dim: Dimension the input array is sharded on. Defaults to 0.
  """
  # Any dtype works here: the aval is used solely to derive a sharding spec.
  dummy_aval = jax.ShapedArray(shape, np.int32)
  spec = pxla._create_pmap_sharding_spec(dummy_aval, sharded_dim)

  # An `Unstacked` entry in the spec records how many devices the sharded
  # dimension is split across; absence means nothing was sharded.
  shard_count = None
  for dim_sharding in spec.sharding:
    if isinstance(dim_sharding, pxla.Unstacked):
      shard_count = dim_sharding.size
  if shard_count is None:
    raise NotImplementedError(
        '`None` to sharded_dim is not supported. Please file a jax '
        'issue if you need this feature.')

  # pmap places shards on the first N local devices, in order.
  return cls(jax.local_devices()[:shard_count], spec)
|
||||
|
||||
@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
  # `self.devices` is an array of Device objects; flatten it and collect the
  # unique devices into an unordered set.
  return {d for d in self.devices.flat}
|
||||
|
@ -761,6 +761,24 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
str(out.sharding) # doesn't crash
|
||||
repr(out.sharding) # doesn't crash
|
||||
|
||||
@parameterized.named_parameters(
    ('sharded_dim_0', (4, 2), 0),
    ('sharded_dim_1_0', (4, 2), 1),
    ('sharded_dim_2', (4, 2, 4), 2),
    ('sharded_dim_1_1', (2, 4), 1)
)
def test_default_pmap_sharding(self, shape, sharded_dim):
  """PmapSharding.default must reproduce pmap's own input sharding."""
  if jax.device_count() < 4:
    self.skipTest('Test needs >= 4 devices.')
  default_sharding = sharding.PmapSharding.default(shape, sharded_dim)

  # Lower/compile an identity pmap over the same input and extract the
  # sharding pmap actually applied to that argument.
  arr = jnp.arange(np.prod(shape)).reshape(shape)
  executable = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(arr).compile()
  pmap_in_sharding, = (
      executable._executable.unsafe_call.in_handler.in_shardings)

  # Both the device order and the sharding spec must agree with pmap's.
  self.assertEqual(default_sharding._device_assignment,
                   pmap_in_sharding._device_assignment)
  self.assertEqual(default_sharding.sharding_spec,
                   pmap_in_sharding.sharding_spec)
|
||||
|
||||
|
||||
# Script entry point: run this test module under absl's test runner, using
# JAX's custom test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user