Add support for MeshPspecSharding local_sharded_result_handler because SDA outputs from pjit can produce a MeshPspecSharding.

PiperOrigin-RevId: 470119499
This commit is contained in:
Yash Katariya 2022-08-25 17:13:33 -07:00 committed by jax authors
parent cc8d406bdb
commit 96058d0197
2 changed files with 32 additions and 6 deletions

View File

@ -319,12 +319,20 @@ class KeyTy:
(core.ShapedArray, output_type)]
# set up a grounded sharding (with a grounded sharding spec)
trailing_sharding = [pxla.NoSharding()] * len(key_shape)
phys_sharding_spec = pxla.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
phys_sharding = PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
if isinstance(sharding, PmapSharding):
trailing_sharding = [pxla.NoSharding()] * len(key_shape)
phys_sharding_spec = pxla.ShardingSpec(
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
mesh_mapping=sharding.sharding_spec.mesh_mapping)
phys_sharding = PmapSharding(devices=sharding.devices,
sharding_spec=phys_sharding_spec)
elif isinstance(sharding, MeshPspecSharding):
trailing_spec = [None] * len(key_shape)
phys_sharding = MeshPspecSharding(
sharding.mesh,
pxla.PartitionSpec(*sharding.spec, *trailing_spec))
else:
assert False, f'impossible sharding {sharding} in local sharded result handler'
# set up grounded indices
trailing_inds = [slice(None)] * len(key_shape)

View File

@ -1033,6 +1033,24 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(exe, stages.Compiled)
self.assertArraysEqual(exe(x, x), x @ x)
def test_local_sharded_key_array_sda(self):
input_shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
seeds = jnp.arange(
prod(input_shape), dtype=np.uint32).reshape(input_shape)
with mesh:
def make_keys(seeds):
make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl)
return make_key(seeds)
f = pjit(make_keys, in_axis_resources=P(None), out_axis_resources=P(None))
out = f(seeds)
self.assertIsInstance(out, jax.random.KeyArray)
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
class GDAPjitTest(jtu.JaxTestCase):