mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Add a test using inputs with different device orders for a single colocated Python call
PiperOrigin-RevId: 708461989
This commit is contained in:
parent
e560c6a45c
commit
0e0fc0ac03
@ -300,6 +300,48 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
# around 15 seconds.
|
||||
self.assertLess(elapsed_time, 10)
|
||||
|
||||
def testInputsWithDifferentDeviceOrders(self):
|
||||
cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2]
|
||||
if len(cpu_devices) < 2:
|
||||
self.skipTest("Not enough CPU devices")
|
||||
|
||||
@colocated_python.colocated_python
|
||||
def add(x: jax.Array, y: jax.Array) -> jax.Array:
|
||||
arrays = [
|
||||
x.addressable_shards[1].data + y.addressable_shards[0].data,
|
||||
x.addressable_shards[0].data + y.addressable_shards[1].data,
|
||||
]
|
||||
return jax.make_array_from_single_device_arrays(
|
||||
y.shape, y.sharding, arrays
|
||||
)
|
||||
|
||||
# The execution will use mixed device orders. We should specialize the
|
||||
# function with devices to avoid the argument-dependent device selection.
|
||||
add = add.specialize(devices=cpu_devices)
|
||||
|
||||
mesh1 = jax.sharding.Mesh([cpu_devices[0], cpu_devices[1]], "x")
|
||||
sharding1 = jax.sharding.NamedSharding(
|
||||
mesh1, jax.sharding.PartitionSpec("x")
|
||||
)
|
||||
mesh2 = jax.sharding.Mesh([cpu_devices[1], cpu_devices[0]], "x")
|
||||
sharding2 = jax.sharding.NamedSharding(
|
||||
mesh2, jax.sharding.PartitionSpec("x")
|
||||
)
|
||||
|
||||
x = np.array([0, 2])
|
||||
x = jax.device_put(x, sharding1)
|
||||
y = np.array([4, 8])
|
||||
y = jax.device_put(y, sharding2)
|
||||
|
||||
out = add(x, y)
|
||||
|
||||
self.assertEqual(out.sharding, sharding2)
|
||||
out_device_list = [shard.device for shard in out.addressable_shards]
|
||||
self.assertEqual(out_device_list, [cpu_devices[1], cpu_devices[0]])
|
||||
|
||||
out = jax.device_get(out)
|
||||
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user