mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
avoid raw key arrays in typed key sharding test
This also lets us remove a guard on `config.jax_enable_custom_prng` in random tests.
This commit is contained in:
parent
404e3061b6
commit
bc44b99d05
@ -1925,22 +1925,22 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_make_array_from_callback(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
|
||||
shape = (len(devices),)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
def callback(index):
|
||||
i = jnp.arange(len(devices))[index[0]]
|
||||
return jax.vmap(random.PRNGKey)(i)
|
||||
return jax.vmap(random.key)(i)
|
||||
result = jax.make_array_from_callback(shape, sharding, callback)
|
||||
expected = jax.vmap(random.PRNGKey)(jnp.arange(len(devices)))
|
||||
expected = jax.vmap(random.key)(jnp.arange(len(devices)))
|
||||
self.assertArraysEqual(result, expected)
|
||||
|
||||
def test_make_array_from_single_device_arrays(self):
|
||||
devices = jax.devices()
|
||||
shape = (len(devices),) if config.jax_enable_custom_prng else (len(devices), 2)
|
||||
shape = (len(devices),)
|
||||
mesh = jtu.create_global_mesh((len(devices),), ('x',))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
|
||||
keys = random.split(random.PRNGKey(0), len(devices))
|
||||
keys = random.split(random.key(0), len(devices))
|
||||
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
|
||||
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
|
||||
self.assertArraysEqual(result, keys)
|
||||
|
Loading…
x
Reference in New Issue
Block a user