avoid raw key arrays in typed key sharding test

This also lets us remove a guard on `config.jax_enable_custom_prng` in
random tests.
This commit is contained in:
Roy Frostig 2023-06-30 20:38:26 -07:00
parent 404e3061b6
commit bc44b99d05

View File

@ -1925,22 +1925,22 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_make_array_from_callback(self):
devices = jax.devices()
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
shape = (len(devices),)
mesh = jtu.create_global_mesh((len(devices),), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
def callback(index):
i = jnp.arange(len(devices))[index[0]]
return jax.vmap(random.PRNGKey)(i)
return jax.vmap(random.key)(i)
result = jax.make_array_from_callback(shape, sharding, callback)
expected = jax.vmap(random.PRNGKey)(jnp.arange(len(devices)))
expected = jax.vmap(random.key)(jnp.arange(len(devices)))
self.assertArraysEqual(result, expected)
def test_make_array_from_single_device_arrays(self):
devices = jax.devices()
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
shape = (len(devices),)
mesh = jtu.create_global_mesh((len(devices),), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
keys = random.split(random.PRNGKey(0), len(devices))
keys = random.split(random.key(0), len(devices))
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
self.assertArraysEqual(result, keys)