mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[sharding_in_types] Use set_mesh
API to trigger sharding_in_types instead of the config option.
PiperOrigin-RevId: 702814257
This commit is contained in:
parent
fa6585dea1
commit
9e2708eb57
@ -2193,15 +2193,8 @@ def lower_sharding_computation(
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
# TODO(yashkatariya): Thread it via jit path and remove the None check by
|
||||
# making tests go via set_mesh API always.
|
||||
devices_from_context = (
|
||||
None if mesh_lib.device_context.concrete_mesh is None
|
||||
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
|
||||
else:
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
# should be the same.
|
||||
unique_intermediate_shardings = util.stable_unique(
|
||||
|
@ -707,9 +707,6 @@ def get_abstract_mesh(in_avals):
|
||||
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
|
||||
f' another mesh: {a.sharding.mesh}')
|
||||
m = a.sharding.mesh # type: ignore
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
if m is None:
|
||||
return contextlib.nullcontext()
|
||||
assert isinstance(m, AbstractMesh)
|
||||
return m
|
||||
|
||||
@ -1791,8 +1788,12 @@ def _pjit_lower(
|
||||
lowering_platforms: tuple[str, ...] | None,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
if resource_env is not None else (None, 'jit'))
|
||||
if config.sharding_in_types.value:
|
||||
mesh = mesh_lib.device_context.concrete_mesh
|
||||
api_name = 'jit'
|
||||
else:
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
if resource_env is not None else (None, 'jit'))
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
|
@ -51,6 +51,7 @@ from jax._src import monitoring
|
||||
from jax._src import pjit as pjit_lib
|
||||
from jax._src import stages
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
@ -1442,6 +1443,16 @@ def with_and_without_mesh(f):
|
||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||
))(with_mesh_from_kwargs(f))
|
||||
|
||||
def with_user_mesh(sizes, names):
|
||||
def decorator(fn):
|
||||
def mesh_fn(*args, **kwargs):
|
||||
mesh = create_mesh(sizes, names)
|
||||
with mesh_lib.set_mesh(mesh):
|
||||
return fn(*args, **kwargs, mesh=mesh)
|
||||
return mesh_fn
|
||||
return decorator
|
||||
|
||||
|
||||
def create_mesh(mesh_shape, axis_names, iota_order=False):
|
||||
size = math.prod(mesh_shape)
|
||||
if len(jax.devices()) < size:
|
||||
|
@ -4630,38 +4630,16 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
ins, _ = f.lower(np.arange(8)).compile().input_shardings
|
||||
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
|
||||
|
||||
def test_sharding_in_types_with_set_mesh(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.skipTest("ShiT doesn't work with shardy")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
with mesh_lib.set_mesh(mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(x.sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
|
||||
|
||||
|
||||
def spec_regex(s):
|
||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||
|
||||
|
||||
@jtu.with_config(jax_sharding_in_types=True, jax_use_shardy_partitioner=False)
|
||||
@jtu.with_config(jax_use_shardy_partitioner=False)
|
||||
class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic_mul(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_basic_mul(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -4696,8 +4674,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = jax.jit(jax.grad(g))(arr)
|
||||
self.assertEqual(out.sharding, arr.sharding)
|
||||
|
||||
def test_fully_replicated_array_mul(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_fully_replicated_array_mul(self, mesh):
|
||||
np_inp1 = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr1 = jax.device_put(np_inp1, s)
|
||||
@ -4745,8 +4723,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'),
|
||||
('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce')
|
||||
)
|
||||
def test_dot_general(self, spec1, spec2, out_spec, collective_name):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh):
|
||||
np_inp1 = np.arange(16.).reshape(8, 2)
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
|
||||
@ -4781,8 +4759,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out[0].sharding, arr1.sharding)
|
||||
self.assertEqual(out[1].sharding, arr2.sharding)
|
||||
|
||||
def test_dot_general_out_type(self):
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
@jtu.with_user_mesh((4,), ('x',))
|
||||
def test_dot_general_out_type(self, mesh):
|
||||
np_inp1 = np.arange(16.).reshape(8, 2)
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
|
||||
@ -4824,8 +4802,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
"dot_general requires contracting dimensions to have consistent sharding",
|
||||
TypeError),
|
||||
)
|
||||
def test_dot_general_error(self, spec1, spec2, error_msg, error_type):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh):
|
||||
np_inp1 = np.arange(16).reshape(8, 2)
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))
|
||||
@ -4837,8 +4815,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(error_type, error_msg):
|
||||
f(arr1, arr2)
|
||||
|
||||
def test_dot_general_batch_error(self):
|
||||
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
|
||||
@jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
|
||||
def test_dot_general_batch_error(self, mesh):
|
||||
arr1 = jax.device_put(np.ones((8, 4, 2)),
|
||||
NamedSharding(mesh, P('x', 'y', 'z')))
|
||||
arr2 = jax.device_put(np.ones((8, 2, 4)),
|
||||
@ -4856,9 +4834,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
' have the consistent sharding'):
|
||||
jnp.einsum('abc,acz->abz', arr1, arr2)
|
||||
|
||||
def test_aval_repr(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('model', 'data'))
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('model', 'data'))
|
||||
def test_aval_repr(self, mesh):
|
||||
aval = core.ShapedArray((128, 64), np.float32,
|
||||
sharding=NamedSharding(mesh, P('model', 'data')))
|
||||
self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
|
||||
@ -4876,14 +4853,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('all', None, P('x', 'y'), P()),
|
||||
('first', 0, P('x', 'y'), P('y')),
|
||||
('second', 1, P('x', 'y'), P('x')),
|
||||
('first2', 0, P(('x', 'y'), None), P(None)),
|
||||
('all', None, P('x', 'y'), P(), True),
|
||||
('first', 0, P('x', 'y'), P('y'), True),
|
||||
('second', 1, P('x', 'y'), P('x'), True),
|
||||
('first2', 0, P(('x', 'y'), None), P(None), True),
|
||||
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
|
||||
)
|
||||
def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, in_spec)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -4907,14 +4884,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertIn('all-reduce', compiled_text)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('all', None, P('x', 'y'), P()),
|
||||
('first', 0, P('x', 'y'), P('y')),
|
||||
('second', 1, P('x', 'y'), P('x')),
|
||||
('first2', 0, P(('x', 'y'), None), P(None)),
|
||||
('all', None, P('x', 'y'), P(), True),
|
||||
('first', 0, P('x', 'y'), P('y'), True),
|
||||
('second', 1, P('x', 'y'), P('x'), True),
|
||||
('first2', 0, P(('x', 'y'), None), P(None), True),
|
||||
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
|
||||
)
|
||||
def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, in_spec)
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -4954,8 +4931,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('2', 2, P('x', 'y', None)),
|
||||
('-1', -1, P('x', 'y', None)),
|
||||
)
|
||||
def test_broadcast_in_dim(self, axis, out_spec):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_broadcast_in_dim(self, axis, out_spec, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -4980,8 +4957,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('3', 3),
|
||||
('4', 4),
|
||||
)
|
||||
def test_integer_pow(self, pow):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_integer_pow(self, pow, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5010,12 +4987,13 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Mesh for all inputs should be equal"):
|
||||
f(arr1, arr2)
|
||||
with config.sharding_in_types(True):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Mesh for all inputs should be equal"):
|
||||
f(arr1, arr2)
|
||||
|
||||
def test_sin_unop(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_sin_unop(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5032,8 +5010,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
self.assertIn('@Sharding', lowered_text)
|
||||
|
||||
def test_jnp_array(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_jnp_array(self, mesh):
|
||||
np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5048,8 +5026,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
f(arr)
|
||||
|
||||
def test_lax_transpose_rule(self):
|
||||
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
|
||||
@jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z'))
|
||||
def test_lax_transpose_rule(self, mesh):
|
||||
np_inp = np.arange(16).reshape(4, 2, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y', 'z'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5067,8 +5045,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
self.assertIn('@Sharding', lowered_text)
|
||||
|
||||
def test_broadcasted_iota_with_sharding(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_broadcasted_iota_with_sharding(self, mesh):
|
||||
np_inp = np.arange(4)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5094,8 +5072,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
_, out = g(arr)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
def test_einsum_with_out_type(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_einsum_with_out_type(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
|
||||
@ -5140,8 +5118,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out[0].sharding, arr3.sharding)
|
||||
self.assertEqual(out[1].sharding, arr4.sharding)
|
||||
|
||||
def test_einsum_inverse(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_einsum_inverse(self, mesh):
|
||||
np_inp = np.arange(64.)
|
||||
|
||||
@jax.jit
|
||||
@ -5179,9 +5157,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True),
|
||||
('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True)
|
||||
)
|
||||
@jtu.with_user_mesh((2,), ('x',))
|
||||
def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
|
||||
use_sharding_arg):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
use_sharding_arg, mesh):
|
||||
np_inp = np.arange(math.prod(src_shape),
|
||||
dtype=np.float32).reshape(src_shape)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec))
|
||||
@ -5209,8 +5187,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = jax.jit(jax.grad(g))(arr)
|
||||
self.assertEqual(out.sharding, arr.sharding)
|
||||
|
||||
def test_select(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_select(self, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr1 = jax.device_put(np_inp, s)
|
||||
@ -5234,8 +5212,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
TypeError, "select cases must have the same shardings"):
|
||||
f(arr1 == arr2, arr1, arr3)
|
||||
|
||||
def test_device_put_reshard(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_device_put_reshard(self, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
@ -5250,8 +5228,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
def test_shard_map_full_manual(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_shard_map_full_manual(self, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
@ -5275,8 +5253,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, (np_inp * np_inp) * 2)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
|
||||
|
||||
def test_shard_map_dot(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_shard_map_dot(self, mesh):
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
|
||||
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
|
||||
@ -5302,8 +5280,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
def test_slice(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_slice(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(4, 4)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
@ -5333,8 +5311,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"):
|
||||
f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y')))))
|
||||
|
||||
def test_squeeze(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_squeeze(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(4, 4, 1)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None)))
|
||||
|
||||
@ -5359,8 +5337,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = jax.jit(jax.grad(g))(arr)
|
||||
self.assertEqual(out.sharding, arr.sharding)
|
||||
|
||||
def test_pad(self):
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_pad(self, mesh):
|
||||
np_inp = np.arange(8.)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
|
||||
|
||||
@ -5401,8 +5379,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'))))
|
||||
f(arr, ((4, 4, 1),), None)
|
||||
|
||||
def test_concatenate(self):
|
||||
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 1), ('x', 'y'))
|
||||
def test_concatenate(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(4, 4)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr1 = jax.device_put(np_inp, s)
|
||||
@ -5443,8 +5421,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = jax.jit(jax.grad(g))(arr1, arr2)
|
||||
self.assertEqual(out.sharding, s)
|
||||
|
||||
def test_scan(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_scan(self, mesh):
|
||||
carry = jax.device_put(np.arange(16.).reshape(2, 8),
|
||||
NamedSharding(mesh, P(None, 'x')))
|
||||
arr = jax.device_put(np.arange(128.).reshape(8, 8, 2),
|
||||
@ -5481,8 +5459,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
ValueError, "0th dimension of all xs should be replicated"):
|
||||
f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None))))
|
||||
|
||||
def test_argminmax(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_argminmax(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
Loading…
x
Reference in New Issue
Block a user