# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np

from jax._src import test_util as jtu
from jax.experimental.shard_alike import shard_alike
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.parse_flags_with_absl()
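# The meshes below use up to 8 devices, so request 8 CPU devices up front.
# jtu.request_cpu_devices sets (or raises) the jax_num_cpu_devices flag,
# which subsumes the older XLA_FLAGS-based device-count override.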
jtu.request_cpu_devices(8)


class ShardAlikeDownstreamTest(jtu.JaxTestCase):

  def test_full_like(self):
    x = jnp.arange(16, dtype='float32').reshape(8, 2)
    mesh = jtu.create_mesh((8,), ("i",))
    x = jax.device_put(x, NamedSharding(mesh, P('i', None)))
    y = jnp.full_like(x, 1)
    self.assertEqual(x.sharding, y.sharding)


class ShardAlikeTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()

  def test_basic(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      y = x * x
      z = y * 2
      _, z = shard_alike(x, z)
      return z * 2

    out = f(inp)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, np_inp * np_inp * 4)

  def test_output_sharded_alike_input(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      y = x * 2
      return shard_alike(x, y)[1]

    out = f(inp)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, np_inp * 2)
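
  # Under jit, a freshly created array (jnp.arange) has no committed
  # sharding of its own; shard_alike makes it adopt the input's sharding.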
  def test_arange_shard_alike_jit(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      y = jnp.arange(16).reshape(8, 2)
      return shard_alike(x, y)[1]

    out = f(inp)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, np_inp)
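
  # shard_alike requires corresponding leaves to have matching shapes;
  # a mismatch should raise a ValueError.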
  def test_different_shapes(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x',))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      y = x @ x.T
      return shard_alike(x, y)[1]

    with self.assertRaisesRegex(
        ValueError, 'The leaves shapes of `x` and `y` should match'):
      f(inp)
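
  # Two chained shard_alike calls: the matmul contracts away the 'y' axis,
  # so z is sharded P('x'), and the fresh array `a` is brought to the same
  # sharding by the second shard_alike.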
  def test_double_shard_alike(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      y = x * 2
      _, y = shard_alike(x, y)
      z = y @ y.T
      a = jnp.arange(64).reshape(8, 8)
      return shard_alike(z, a)

    out1, out2 = f(inp)
    self.assertEqual(out1.sharding, NamedSharding(mesh, P('x')))
    self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
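
  # shard_alike also works in eager mode, outside of jit.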
  def test_shard_alike_eager(self):
    mesh = jtu.create_mesh((4, 1), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    def f(x):
      y = jnp.arange(16).reshape(8, 2)
      return shard_alike(x, y)[1]

    out = f(inp)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, np_inp)
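
  # shard_alike where one operand flows through a shard_map'd computation;
  # both outputs should end up with the mesh-aligned sharding P('x', 'y').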
  def test_shard_map(self):
    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(np_inp, s)

    def g(x):
      return jax.lax.psum(x, 'x')

    @jax.jit
    def f(x):
      y = x @ x.T
      s_out = shard_map(g, mesh, in_specs=P('x', 'y'),
                        out_specs=P(None, 'y'))(y)
      z = s_out.T @ s_out
      return shard_alike(y, z)

    out1, out2 = f(inp)
    # Of the candidate shardings P('x', 'y') and P('y'), shard_alike picks
    # the better option.
    self.assertEqual(out1.sharding, s)
    self.assertEqual(out2.sharding, s)
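
  # shard_alike must be transparent to autodiff. The inspect_array_sharding
  # callback checks that y_ picked up the input's P('x') sharding: shards of
  # shape (2,) spread over all 4 devices, not fully replicated.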
  def test_grad(self):
    mesh = jtu.create_mesh((4,), ('x',))
    np_inp = np.arange(8.)
    s = NamedSharding(mesh, P('x'))
    inp = jax.device_put(np_inp, s)

    def _cb(s):
      self.assertFalse(s.is_fully_replicated)
      self.assertLen(s.device_set, mesh.size)
      self.assertEqual(s.shard_shape(np_inp.shape), (2,))

    def f(x):
      y = jnp.arange(8.)
      x_, y_ = shard_alike(x, y)
      jax.debug.inspect_array_sharding(y_, callback=_cb)
      z = x_ + y_
      return jnp.sum(z)

    jax.grad(f)(inp)  # doesn't crash
    jax.grad(jax.jit(f))(inp)  # doesn't crash
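
  # Returning an input through shard_alike: calling the jitted function
  # twice must hit the C++ jit cache (one miss total), and the outputs keep
  # the constrained sharding.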
  def test_shard_input_as_output(self):
    mesh = jtu.create_mesh((4,), ('x',))
    np_inp = np.arange(8.)
    s = NamedSharding(mesh, P('x'))

    @jax.jit
    def f(x):
      y = jax.lax.with_sharding_constraint(x, s)
      z = y * 2
      return shard_alike(x, z)

    with jtu.count_pjit_cpp_cache_miss() as count:
      f(np_inp)
      out1, out2 = f(np_inp)
    self.assertEqual(count(), 1)
    self.assertTrue(s.is_equivalent_to(out1.sharding, np_inp.ndim))
    self.assertTrue(s.is_equivalent_to(out2.sharding, np_inp.ndim))

    @jax.jit
    def g(x):
      z = x * 2
      return shard_alike(x, z)

    arr = jax.device_put(np_inp, s)
    with jtu.count_pjit_cpp_cache_miss() as count:
      g(arr)
      out3, out4 = g(arr)
    self.assertEqual(count(), 1)
    self.assertEqual(out3.sharding, s)
    self.assertEqual(out4.sharding, s)
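
  # One committed sharded array plus one plain numpy input: both outputs
  # adopt the committed array's sharding, eagerly and under jit.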
  def test_shard_alike_inputs(self):
    mesh = jtu.create_mesh((2,), ('x',))
    np_inp = np.arange(8.)
    s = NamedSharding(mesh, P('x'))
    arr = jax.device_put(np_inp, s)

    def f(x, y):
      return shard_alike(x, y)

    eager_out1, eager_out2 = f(arr, np_inp)
    self.assertEqual(eager_out1.sharding, s)
    self.assertEqual(eager_out2.sharding, s)

    out1, out2 = jax.jit(f)(arr, np_inp)
    self.assertEqual(out1.sharding, s)
    self.assertEqual(out2.sharding, s)
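
  # vmap where only one operand is mapped: each vmapped row is shard_alike'd
  # with the unmapped, P('y')-sharded x, so the stacked output comes out as
  # P(None, 'y').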
  def test_vmap_one_mapped(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(2)
    s = NamedSharding(mesh, P('y'))
    inp = jax.device_put(np_inp, s)

    @jax.jit
    def f(x):
      def _shard_slice_like_arg(s):
        sharded_s, _ = shard_alike(s, x)
        return sharded_s

      replicated_x = jnp.tile(x, [8, 1])  # shape == (8, 2)
      return jax.vmap(_shard_slice_like_arg, in_axes=0)(replicated_x)

    out = f(inp)
    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y')))
    self.assertArraysEqual(out, np.tile(np_inp, [8, 1]))
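
  # Both operands mapped, along different axes (rows of x, columns of y):
  # the vmapped shard_alike brings both outputs to x's P('x', 'y') sharding.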
  def test_vmap_both_mapped(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    inp1 = jax.device_put(np_inp, s)

    np_inp2 = np.arange(16).reshape(2, 8)
    inp2 = jax.device_put(np_inp2, NamedSharding(mesh, P('y', 'x')))

    @jax.jit
    def f(x, y):
      return jax.vmap(shard_alike, in_axes=(0, 1))(x, y)

    out1, out2 = f(inp1, inp2)
    self.assertEqual(out1.sharding, s)
    self.assertEqual(out2.sharding, s)
    self.assertArraysEqual(out1, np_inp)
    self.assertArraysEqual(out2, np_inp2.T)
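
  # Even with a single-device mesh, the explicit NamedSharding should
  # propagate to the fresh arange.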
  def test_sharding_preserved_single_device(self):
    mesh = jax.sharding.Mesh([jax.devices()[0]], "x")
    s = NamedSharding(mesh, P("x"))

    x = jax.device_put(np.arange(8), s)
    _, y = shard_alike(x, jnp.arange(8))
    self.assertEqual(y.sharding, s)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())