mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add support for MeshPspecSharding local_sharded_result_handler because SDA outputs from pjit can produce a MeshPspecSharding.
PiperOrigin-RevId: 470119499
This commit is contained in:
parent
cc8d406bdb
commit
96058d0197
@ -319,12 +319,20 @@ class KeyTy:
|
||||
(core.ShapedArray, output_type)]
|
||||
|
||||
# set up a grounded sharding (with a grounded sharding spec)
|
||||
trailing_sharding = [pxla.NoSharding()] * len(key_shape)
|
||||
phys_sharding_spec = pxla.ShardingSpec(
|
||||
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
|
||||
mesh_mapping=sharding.sharding_spec.mesh_mapping)
|
||||
phys_sharding = PmapSharding(devices=sharding.devices,
|
||||
sharding_spec=phys_sharding_spec)
|
||||
if isinstance(sharding, PmapSharding):
|
||||
trailing_sharding = [pxla.NoSharding()] * len(key_shape)
|
||||
phys_sharding_spec = pxla.ShardingSpec(
|
||||
sharding=(*sharding.sharding_spec.sharding, *trailing_sharding),
|
||||
mesh_mapping=sharding.sharding_spec.mesh_mapping)
|
||||
phys_sharding = PmapSharding(devices=sharding.devices,
|
||||
sharding_spec=phys_sharding_spec)
|
||||
elif isinstance(sharding, MeshPspecSharding):
|
||||
trailing_spec = [None] * len(key_shape)
|
||||
phys_sharding = MeshPspecSharding(
|
||||
sharding.mesh,
|
||||
pxla.PartitionSpec(*sharding.spec, *trailing_spec))
|
||||
else:
|
||||
assert False, f'impossible sharding {sharding} in local sharded result handler'
|
||||
|
||||
# set up grounded indices
|
||||
trailing_inds = [slice(None)] * len(key_shape)
|
||||
|
@ -1033,6 +1033,24 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsInstance(exe, stages.Compiled)
|
||||
self.assertArraysEqual(exe(x, x), x @ x)
|
||||
|
||||
def test_local_sharded_key_array_sda(self):
|
||||
input_shape = (8, 4)
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
seeds = jnp.arange(
|
||||
prod(input_shape), dtype=np.uint32).reshape(input_shape)
|
||||
|
||||
with mesh:
|
||||
def make_keys(seeds):
|
||||
make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl)
|
||||
return make_key(seeds)
|
||||
|
||||
f = pjit(make_keys, in_axis_resources=P(None), out_axis_resources=P(None))
|
||||
|
||||
out = f(seeds)
|
||||
self.assertIsInstance(out, jax.random.KeyArray)
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
|
||||
|
||||
class GDAPjitTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user