KeyArray: support make_array_from_* APIs

This commit is contained in:
Jake VanderPlas 2023-05-04 16:32:49 -07:00
parent 2845df03fc
commit 4db717c52a
3 changed files with 37 additions and 0 deletions

View File

@ -597,6 +597,8 @@ def make_array_from_callback(
for device in sharding.addressable_devices
]
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
return ArrayImpl(aval, sharding, arrays, committed=True)
@ -642,6 +644,8 @@ def make_array_from_single_device_arrays(
# All input arrays should be committed. Checking it is expensive on
# single-controller systems.
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
# TODO(phawkins): ideally the cast() could be checked. Revisit this after
# removing DeviceArray.
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),

View File

@ -445,6 +445,17 @@ class KeyTyRules:
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
return handler
@staticmethod
def make_sharded_array(aval, sharding, arrays, committed):
phys_aval, = KeyTyRules.physical_avals(aval)
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
phys_arrays = [random_unwrap(arr) for arr in arrays]
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
phys_result = phys_handler(phys_arrays)
return PRNGKeyArrayImpl(aval.dtype.impl, phys_result)
# element-type-polymorphic primitive lowering rules
@staticmethod

View File

@ -1853,6 +1853,28 @@ class KeyArrayTest(jtu.JaxTestCase):
keys_on_device = jax.device_put_replicated(key, devices)
self.assertArraysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device)
def test_make_array_from_callback(self):
devices = jax.devices()
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
mesh = jtu.create_global_mesh((len(devices),), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
def callback(index):
i = jnp.arange(len(devices))[index[0]]
return jax.vmap(jax.random.PRNGKey)(i)
result = jax.make_array_from_callback(shape, sharding, callback)
expected = jax.vmap(jax.random.PRNGKey)(jnp.arange(len(devices)))
self.assertArraysEqual(result, expected)
def test_make_array_from_single_device_arrays(self):
devices = jax.devices()
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
mesh = jtu.create_global_mesh((len(devices),), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
keys = jax.random.split(jax.random.PRNGKey(0), len(devices))
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
self.assertArraysEqual(result, keys)
# TODO(frostig,mattjj): more polymorphic primitives tests