mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
KeyArray: support make_array_from_* APIs
This commit is contained in:
parent
2845df03fc
commit
4db717c52a
@ -597,6 +597,8 @@ def make_array_from_callback(
|
||||
for device in sharding.addressable_devices
|
||||
]
|
||||
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
|
||||
return ArrayImpl(aval, sharding, arrays, committed=True)
|
||||
|
||||
|
||||
@ -642,6 +644,8 @@ def make_array_from_single_device_arrays(
|
||||
# All input arrays should be committed. Checking it is expensive on
|
||||
# single-controller systems.
|
||||
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
|
||||
# TODO(phawkins): ideally the cast() could be checked. Revisit this after
|
||||
# removing DeviceArray.
|
||||
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
|
||||
|
@ -445,6 +445,17 @@ class KeyTyRules:
|
||||
return PRNGKeyArrayImpl(aval.dtype.impl, phys_handler(bufs))
|
||||
return handler
|
||||
|
||||
@staticmethod
|
||||
def make_sharded_array(aval, sharding, arrays, committed):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
phys_arrays = [random_unwrap(arr) for arr in arrays]
|
||||
|
||||
phys_sharding = make_key_array_phys_sharding(aval, sharding, False)
|
||||
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed, False)
|
||||
phys_result = phys_handler(phys_arrays)
|
||||
return PRNGKeyArrayImpl(aval.dtype.impl, phys_result)
|
||||
|
||||
# element-type-polymorphic primitive lowering rules
|
||||
|
||||
@staticmethod
|
||||
|
@ -1853,6 +1853,28 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
keys_on_device = jax.device_put_replicated(key, devices)
|
||||
self.assertArraysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device)
|
||||
|
||||
def test_make_array_from_callback(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
def callback(index):
|
||||
i = jnp.arange(len(devices))[index[0]]
|
||||
return jax.vmap(jax.random.PRNGKey)(i)
|
||||
result = jax.make_array_from_callback(shape, sharding, callback)
|
||||
expected = jax.vmap(jax.random.PRNGKey)(jnp.arange(len(devices)))
|
||||
self.assertArraysEqual(result, expected)
|
||||
|
||||
def test_make_array_from_single_device_arrays(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
keys = jax.random.split(jax.random.PRNGKey(0), len(devices))
|
||||
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
|
||||
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
|
||||
self.assertArraysEqual(result, keys)
|
||||
|
||||
# TODO(frostig,mattjj): more polymorphic primitives tests
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user