# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from functools import partial
import itertools as it
import math
import operator as op
from types import SimpleNamespace
from typing import Any, NamedTuple, TypeVar
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

import jax
import jax.ad_checkpoint
from jax import api_util
from jax import lax
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import core
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
from jax._src.mesh import AxisType
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp

from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map


config.parse_flags_with_absl()
jtu.request_cpu_devices(8)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
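# The two assignments above shadow the builtins with the length-checking
# map/zip from jax._src.util; the originals remain as unsafe_map/unsafe_zip.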

# Helper for some tests.
def create_inputs(a_sharding, b_sharding):
  mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
  b, e, f = 8, 8, 8  # pylint: disable=invalid-name
  m1 = jax.device_put(
      jnp.arange(b * e).reshape((b, e)),
      jax.sharding.NamedSharding(mesh, a_sharding))
  m2 = jax.device_put(
      jnp.arange(e * f).reshape((e, f)),
      jax.sharding.NamedSharding(mesh, b_sharding))
  return mesh, m1, m2


class ShardMapTest(jtu.JaxTestCase):

  def test_identity(self):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    def identity(x):
      return x

    @jax.jit
    def fwd(a):
      c = shard_map(
          identity,
          mesh,
          in_specs=(P('z', ('x', 'y')),),
          out_specs=P('z', ('x', 'y')))(a)
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 2))

  def test_all_gather(self):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
    # all_gather_invariant primitive, which differs in its output replication
    # type compared to all_gather.
    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
    def fwd(a):
      return (
          lax.all_gather(a, 'z', axis=0, tiled=True),
          lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True),
      )
    c, d = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (8, 2))
    for i, a_shard in enumerate(np.split(a, 4, axis=1)):
      self.assertAllClose(c.addressable_data(2 * i), a_shard)
    self.assertEqual(d.addressable_data(0).shape, (4, 8))
    for i, a_shard in enumerate(np.split(a, 2, axis=0)):
      self.assertAllClose(d.addressable_data(i), a_shard)

  def test_all_gather_with_axis_index_groups(self):
    mesh, a, _ = create_inputs(P('x', ('y', 'z')), P(None, None))

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', ('y', 'z')),),
        out_specs=P('x', ('y', 'z')),
    )
    def fwd(a):
      return lax.all_gather(
          a, ('y', 'z'), axis_index_groups=((0, 1), (2, 3)), axis=-1, tiled=True
      )

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 4))
    for i, row_block in enumerate(np.split(a, 2, axis=0)):
      for j, block in enumerate(np.split(row_block, 2, axis=-1)):
        self.assertAllClose(c.addressable_data(4 * i + 2 * j), block)
        self.assertAllClose(c.addressable_data(4 * i + 2 * j + 1), block)

  def test_matmul_partial(self):
    raise unittest.SkipTest("invalid replication asserted by out_spec?")

    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.addressable_data(0).shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)), out_specs=P('z', None))
    def fwd(a):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 8))

  def test_matmul_reduce_scatter(self):
    mesh, a, b = create_inputs(P('z', 'y'), P('y', None))
    assert a.addressable_data(0).shape == (4, 4)

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', 'y'), P('y', None)),
             out_specs=P(('z', 'y'), None))
    def fwd(a, b):
      c = jnp.matmul(a, b)  # [B.z, F] {y.unreduced}
      return (
          lax.psum_scatter(c, 'y', scatter_dimension=0, tiled=True),
          lax.psum_scatter(c, ('z', 'y'), scatter_dimension=0, tiled=True),
      )

    expected = jnp.matmul(a, b)
    c, d = fwd(a, b)
    self.assertEqual(c.addressable_data(0).shape, (2, 8))
    self.assertAllClose(expected, c)
    self.assertEqual(d.addressable_data(0).shape, (1, 8))
    self.assertAllClose(expected[:4] + expected[4:], d)

  def test_reduce_scatter_with_axis_index_groups(self):
    axis_index_groups = ((0, 2, 4, 6), (1, 3, 5, 7))
    mesh, a, _ = create_inputs(P(None, ('x', 'y', 'z')), P(None, None))
    assert a.addressable_data(0).shape == (8, 1)

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(None, ('x', 'y', 'z')),),
        out_specs=P(None, ('x', 'y', 'z')),
    )
    def fwd(a):
      return lax.psum_scatter(
          a,
          ('x', 'y', 'z'),
          scatter_dimension=0,
          axis_index_groups=axis_index_groups,
          tiled=True,
      )

    c = fwd(a)

    self.assertEqual(c.addressable_data(0).shape, (2, 1))

    sum_of_even_columns = np.sum(a[..., axis_index_groups[0]], -1)
    for i, sums in enumerate(np.split(sum_of_even_columns, 4, 0)):
      self.assertAllClose(np.squeeze(c.addressable_data(2 * i), -1), sums)

    sum_of_odd_columns = np.sum(a[..., axis_index_groups[1]], -1)
    for i, sums in enumerate(np.split(sum_of_odd_columns, 4, 0)):
      self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums)

  def test_collective_permute(self):
    mesh = jtu.create_mesh((8,), 'x')
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)))

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
    )
    def fwd(a):
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    c = fwd(a)
    self.assertAllClose(c[1, :], a[0, :])

  def test_collective_permute_with_multiple_axis_names(self):
    mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((4, 16)),
        jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
    )

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', ('y', 'z')),),
        out_specs=P('x', ('y', 'z')),
    )
    def fwd(a):
      xy_axis_size = lax.psum(1, ('x', 'y'))
      yz_axis_size = lax.psum(1, ('y', 'z'))
      xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)]
      yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)]
      return (
          lax.ppermute(a, ('x', 'y'), perm=xy_perm),
          lax.ppermute(a, ('y', 'z'), perm=yz_perm),
      )

    c, d = fwd(a)
    for i in range(8):
      self.assertAllClose(
          a.addressable_data(i), c.addressable_data((i + 2) % 8)
      )
      self.assertAllClose(
          a.addressable_data(i), d.addressable_data(4 * (i // 4) + (i + 1) % 4)
      )

  @parameterized.named_parameters(
      dict(
          testcase_name='_single_axis_name', axis_name='x', mesh_axes=dict(x=8)
      ),
      dict(
          testcase_name='_multiple_axis_names',
          axis_name=('x', 'y'),
          mesh_axes=dict(x=4, y=2),
      ),
  )
  def test_all_to_all(self, axis_name, mesh_axes):
    mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P(axis_name, None)),
    )

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P(axis_name, None),),
        out_specs=P(None, axis_name),
    )
    def fwd(a):
      return lax.all_to_all(
          a, axis_name, split_axis=1, concat_axis=1, tiled=True
      )

    c = fwd(a)
    assert (c == jnp.reshape(a.T, (1, 64))).all()

  def test_all_to_all_with_axis_index_groups(self):
    mesh = jtu.create_mesh((4,), ('x',))
    a = jax.device_put(
        jnp.arange(4 * 4).reshape((4, 4)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )
    self.assertEqual(a.addressable_data(0).shape, (1, 4))

    @jax.jit
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(P('x', None),),
        out_specs=P(None, 'x'),
    )
    def fwd(a):
      return lax.all_to_all(
          a,
          'x',
          split_axis=1,
          concat_axis=0,
          axis_index_groups=((0, 1), (2, 3)),
          tiled=True,
      )

    c = fwd(a)

    # Each shard corresponds to a quadrant rather than a row.
    self.assertEqual(c.addressable_data(0).shape, (2, 2))
    for i, row_block in enumerate(np.split(a, 2, axis=0)):
      for j, block in enumerate(np.split(row_block, 2, axis=-1)):
        self.assertAllClose(block, c.addressable_data(2 * i + j))

  def test_all_to_all_grad(self):
    mesh = jtu.create_mesh((4,), 'x')
    a = jax.device_put(
        jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )
    self.assertEqual(a.addressable_data(0).shape, (2, 8))

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P(None, 'x')
    )
    def fwd(x):
      return lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (8, 2))
    self.assertAllClose(a, c)

    @jax.jit
    @partial(jax.grad, has_aux=True)
    def loss_and_grad(x):
      loss = fwd(x).sum() * 2
      return loss, loss

    grad, loss = loss_and_grad(a)
    self.assertEqual(loss, 2 * sum(range(64)))
    self.assertAllClose(grad, 2 * np.ones_like(a))

  def test_eager_repr(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = None

    @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
    def f(x):
      nonlocal s
      s = str(x)
      return x
    _ = f(np.arange(8 * 8.).reshape(8, 8))

    self.assertIsInstance(s, str)
    self.assertIn('at mesh coordinates', s)

  def test_jvp_basic(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    args = np.arange(4 * 4.).reshape(4, 4),
    jtu.check_grads(g, args, 2, ['fwd'])
    jtu.check_grads(jax.jit(g), args, 2, ['fwd'])

  def test_linearize_basic(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
    x = np.arange(4 * 4.).reshape(4, 4)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_linearize_basic_repres_jit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
                  in_specs=(P('x',),), out_specs=P('x',))
    x = np.arange(4.)

    y, y_dot = jax.jvp(g, [x], [x])

    y_, g_lin = jax.linearize(g, x)
    y_dot_ = g_lin(x)

    self.assertAllClose(y, y_, check_dtypes=False)
    self.assertAllClose(y_dot, y_dot_, check_dtypes=False)

  def test_replication_checker_eager(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      g(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = g2(x)  # doesn't crash

  def test_replication_checker_jit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = np.arange(8 * 8.).reshape(8, 8)

    def f(x):
      return 2 * x
    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)

    with self.assertRaisesRegex(ValueError, 'statically inferred'):
      jax.jit(g)(x)

    def f2(x):
      return jax.lax.psum(x, 'x')
    def g2(x):
      return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x)
    _ = jax.jit(g2)(x)  # doesn't crash

  def test_process_env_traces(self):
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
    x = np.arange(8.)

    def g(x):
      y = (3. * x).sum()
      z = shard_map(lambda x: 2 * x * y, mesh,
                    in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.))
      return z

    jtu.check_grads(g, (x,), modes=['fwd'], order=2)

  def test_eager_control_flow(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(2 * 2.).reshape(2, 2)

    def f(x):
      y = jax.lax.psum(x, ('x', 'y'))
      if y < 0:
        return x
      else:
        return -x

    def g(x):
      return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x)
    y = g(x)
    self.assertAllClose(y, -x, check_dtypes=False)

  def test_outer_jit_detects_shard_map_mesh(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
    _ = jax.jit(f)(jnp.array(2.0))  # doesn't crash

  def test_vmap_basic(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g)(x)
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_vmap_basic_axis_name(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g, axis_name='i')(x)
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_vmap_basic_axis_name_reuse_mesh_name(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    x = jnp.arange(8 * 8.).reshape(8, 8)

    def g(x):
      return shard_map(lambda x: 2. * x, mesh,
                       in_specs=P('y'), out_specs=P('y'))(x)
    y = jax.vmap(g, axis_name='x')(x)  # NOTE reuse same 'x' as on mesh
    self.assertAllClose(y, 2 * x, check_dtypes=False)

  def test_tree_prefix_error(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
    def f(x):
      return x

    x = jnp.arange(8 * 8.).reshape(8, 8)
    with self.assertRaisesRegex(ValueError, r'shard_map in_specs\[0\]'):
      f([x, x])

  def test_rank_errors(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def foo():
      return {'hi': [3.]}

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      shard_map(foo, mesh=mesh, in_specs=(), out_specs={'hi': P('x')})()

    with self.assertRaisesRegex(ValueError, 'which has length 1'):
      jax.jit(lambda: shard_map(foo, mesh=mesh,
                                in_specs=(), out_specs={'hi': P('x')})())()

    with self.assertRaisesRegex(ValueError, 'which has rank 0'):
      shard_map(foo, mesh=mesh, in_specs=({'hi': P('x')},), out_specs=())(
          {'hi': [jnp.array(3.)]})

    with self.assertRaisesRegex(ValueError,
                                r'consider using an in_specs entry of `P\(\)`'):
      shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.)

  def test_reverse_mode_ad(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x',), P(None)), out_specs=P('x',))
    def f(x, y):
      return jnp.sin(x) + 3 + jnp.tan(2.) * jnp.cos(x) + y

    x = jnp.arange(8.) / 10.
    y = jnp.arange(4.) / 10.
    jtu.check_grads(f, (x, y), modes=['fwd', 'rev'], order=2)

  def test_post_process(self):
    # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def f(x):
      @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
      def g(y):
        return jnp.sin(y) * jnp.sin(x).sum()
      return g(jnp.arange(8.))

    x = jnp.arange(8.)
    _, f_lin = jax.linearize(f, x)
    y_dot = f_lin(x)

    y_dot_expected = jnp.sin(jnp.arange(8.)) * (jnp.cos(x) * x).sum()
    self.assertAllClose(y_dot, y_dot_expected, check_dtypes=False)

  @jtu.run_on_devices('gpu', 'tpu')
  def test_axis_index(self):
    mesh = jtu.create_mesh((4,), 'x')

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x'))
    def f():
      return jax.lax.axis_index('x')[None]

    x = f()
    self.assertAllClose(x, jnp.arange(4), check_dtypes=False)

  def test_remat_basic(self):
    # this tests remat-of-shmap
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    # check param updating is handled
    @jax.remat
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jnp.sin(jnp.sin(x))

    x = jnp.arange(4.)
    g = jax.grad(lambda x: f(x).sum())(x)  # doesn't crash
    self.assertAllClose(g, jnp.cos(jnp.sin(x)) * jnp.cos(x), check_dtypes=False,
                        atol=1e-3, rtol=1e-3)
    saved_res = saved_residuals(f, x)
    self.assertLen(saved_res, 1)

    # also check residuals are handled correctly
    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f2(x):
      return jnp.sin(jnp.sin(x))

    g2 = jax.grad(lambda x: f2(x).sum())(x)  # doesn't crash
    self.assertAllClose(g2, jnp.cos(jnp.sin(x)) * jnp.cos(x),
                        check_dtypes=False, atol=1e-3, rtol=1e-3)
    saved_res = saved_residuals(f2, x)
    self.assertLen(saved_res, 2)

  def test_shmap_of_remat_basic(self):
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    x = jnp.arange(4.)

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    def f2(x):
      return jnp.sin(x)

    g2 = jax.grad(lambda x: f2(x).sum())(x)  # doesn't crash
    self.assertAllClose(g2, jnp.cos(x), check_dtypes=False)

  def test_remat_scalar_residuals(self):
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    @partial(jax.remat, policy=jax.checkpoint_policies.everything_saveable)
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jnp.sin(jnp.sin(jnp.sin(x.sum()))[None])

    x = jnp.arange(8.)
    _ = jax.grad(lambda x: f(x).sum())(x)  # doesn't crash
    jtu.check_grads(f, (x,), modes=['rev'], order=2, atol=1e-2, rtol=1e-2)

  def test_collectives_not_saved(self):
    # regression test for bug in cl/612416803
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    @jax.remat
    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jax.lax.all_gather(x, 'x') * jax.lax.all_gather(x, 'x')

    saved_res = saved_residuals(f, jnp.ones(4))
    self.assertLen(saved_res, 1)

  def test_check_rep_false_doesnt_hit_rep_rules(self):
    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))

    prim = core.Primitive('prim')  # no rep rule here!
    prim.multiple_results = True
    prim.def_impl(lambda: [])
    prim.def_abstract_eval(lambda: [])

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True)
    def f():
      prim.bind()

    with self.assertRaises(NotImplementedError):
      f()
    with self.assertRaises(NotImplementedError):
      jax.jit(f)()

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
    def f2():
      prim.bind()

    f2()
    jax.jit(f2)()

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False)
    def f3():
      jax.jit(prim.bind)()

    f3()
    jax.jit(f3)()

  def test_vmap_spmd_axis_name(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return x

    x = jnp.arange(4 * 4).reshape(4, 4)
    jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name='y'))(x).jaxpr
    e, = jaxpr.eqns
    self.assertIn('in_names', e.params)
    self.assertEqual(e.params['in_names'], ({0: ('y',), 1: ('x',)},))
    self.assertIn('out_names', e.params)
    self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))

  def test_vmap_of_grad_spmd_axis_name(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(
        shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False
    )
    def f(x):
      return jnp.sin(jnp.sum(x))

    x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4)
    put_x = jax.device_put(
        x,
        jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')),
    )
    vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x)
    vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x)
    self.assertArraysEqual(
        vmap_spmd_axisname_result, vmap_no_spmd_axisname_result
    )

  def test_vmap_spmd_axis_name_pair(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def f(x):
      return x

    x = jnp.arange(4 * 4).reshape(4, 4)
    jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name=('x', 'y')))(x).jaxpr
    e, = jaxpr.eqns
    self.assertIn('in_names', e.params)
    self.assertEqual(e.params['in_names'], ({0: ('x', 'y',)},))
    self.assertIn('out_names', e.params)
    self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))

  def test_nested_vmap_with_capture_spmd_axis_name(self):
    self.skipTest('https://github.com/jax-ml/jax/issues/23476')
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def to_map_with_capture(x, y):

      # We capture x from `to_map_with_capture`'s parameters.
      def with_capture(y_slice):
        # Inside of all the maps, we have 'mapped everything away'--we are just
        # adding two scalars, but one by fully mapping across each of the two
        # dimensions, the other by mapping across one and capturing the
        # resulting scalar.
        self.assertEqual(x.shape, ())
        self.assertEqual(y_slice.shape, ())
        return x + y_slice

      # This vmap i will refer to as 'inner vmap'.
      vmap_with_capture = jax.vmap(with_capture)
      shmap_vmap_capture = shard_map(
          vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y')
      )
      return shmap_vmap_capture(y)

    # And this one is the outer vmap.
    mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x')
    x = jnp.arange(2).reshape(2)
    y = jnp.arange(2 * 2).reshape(2, 2)
    # Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap
    # is over an axis of size 2. This is a problem at the moment.
    jax.make_jaxpr(mapped)(x, y).jaxpr

  def test_shard_map_abstract_mesh(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))

    def f(x):
      return shard_map(lambda x: x, mesh=mesh.abstract_mesh, in_specs=P('x'),
                       out_specs=P('x'))(x)

    out1 = jax.jit(f)(arr)
    self.assertArraysEqual(out1, np_inp)
    self.assertEqual(out1.sharding, NamedSharding(mesh, P('x')))

    out_eager = f(arr)
    self.assertArraysEqual(out_eager, np_inp)
    self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x')))

    out1, out2 = shard_map(lambda x, y: (x, y), mesh=mesh.abstract_mesh,
                           in_specs=P('x'), out_specs=P('x'))(np_inp, arr)
    self.assertArraysEqual(out1, np_inp)
    self.assertEqual(out1.sharding, NamedSharding(mesh, P('x')))
    self.assertArraysEqual(out2, np_inp)
    self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))

  def test_different_devices_shmap_abstract_mesh_cache_hit(self):
    if jax.device_count() < 4:
      self.skipTest('Requires >=4 devices')

    mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i')
    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i')
    abstract_mesh = mesh1.abstract_mesh

    @jax.jit
    def f(x):
      x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'),
                    out_specs=P('i'))(x)
      return jax.lax.sin(x)

    with (
        jtu.count_jit_tracing_cache_miss() as tracing_count,
        jtu.count_jit_and_pmap_lowerings() as lowering_count,
        jtu.count_jit_compilation_cache_miss() as compilation_count,
    ):
      a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P()))
      out_a = f(a)  # tracing and lowering cached

      # same num_devices but different devices.
      b = jax.device_put(out_a, NamedSharding(mesh2, P()))
      f(b)  # tracing and lowering cache *hit*

    self.assertEqual(tracing_count(), 1)
    self.assertEqual(lowering_count(), 1)
    self.assertEqual(compilation_count(), 2)  # 2 misses since devices differ.

  def test_shmap_abstract_mesh_errors(self):
    mesh = jtu.create_mesh((2,), ('x',))
    np_inp = np.arange(8)
    abstract_mesh = mesh.abstract_mesh

    with self.assertRaisesRegex(
        ValueError,
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument"):
      shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(jnp.arange(8))

    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
    mesh2 = jtu.create_mesh((2,), 'y')
    abs_mesh2 = mesh2.abstract_mesh
    with self.assertRaisesRegex(
        ValueError,
        'Mesh shape of the input.*does not match the mesh shape passed to'
        ' shard_map'):
      shard_map(lambda x: x, mesh=abs_mesh2, in_specs=P('y'),
                out_specs=P('y'))(arr)

    with self.assertRaisesRegex(
        ValueError,
        'Please pass `jax.Array`s with a `NamedSharding` as input to'
        ' `shard_map` when passing `AbstractMesh` to the mesh argument.'):
      shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(np_inp)

    arr_mesh2 = jax.device_put(np_inp, NamedSharding(mesh2, P('y')))
    with self.assertRaisesRegex(
        ValueError,
        'Mesh shape of the input.*does not match the mesh shape passed to'
        ' shard_map'):
      shard_map(lambda x, y: (x, y), mesh=abstract_mesh, in_specs=P('x'),
                out_specs=P('x'))(arr, arr_mesh2)

  @parameterized.parameters([True, False])
  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  @jtu.thread_unsafe_test()
  def test_debug_print_jit(self, jit):
    if config.use_shardy_partitioner.value:
      self.skipTest('TODO(b/384938613): Failing under shardy')
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      idx = jax.lax.axis_index('i')
      jax.debug.print("instance {i} has value x={x}", i=idx, x=x)
      y = jnp.cos(x)
      jax.debug.print("instance {i} has value y={y}", i=idx, y=y)
      return y

    if jit:
      f = jax.jit(f)

    x = jnp.arange(2 * len(jax.devices()))

    with jtu.capture_stdout() as output:
      f(x)
      jax.effects_barrier()
    for i in range(len(jax.devices())):
      self.assertIn(f'instance {i} has value', output())

  def test_debug_print_eager(self):
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      jax.debug.print("x={x}", x=x)
      y = jnp.cos(x)
      jax.debug.print("y={y}", y=y)
      return y

    x = jnp.arange(2 * len(jax.devices()))

    with jtu.capture_stdout() as output:
      f(x)
      jax.effects_barrier()
    for i in range(len(jax.devices())):
      self.assertIn(f'x=[{2*i} {2*i+1}]', output())

  def test_partial_eval_custom_axis_env(self):
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(_):
      _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
                            None, None, length=1)
      return idx

    xs = jnp.arange(16.)
    jax.eval_shape(jax.grad(lambda x: jax.remat(f)(x).sum().astype('float32')),
                   xs)

  @jax.legacy_prng_key('allow')
  def test_prngkeyarray_eager(self):
    # https://github.com/jax-ml/jax/issues/15398
    mesh = jtu.create_mesh((4,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))

    rng = jax.random.PRNGKey(0)
    sharded_rng = jax.random.split(rng, num=4)
    sharded_rng = jax.device_put(sharded_rng, sharding)

    def f(key):
      return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
                                dtype=jnp.int32)

    pspec = P('x') if config.enable_custom_prng.value else P('x', None)
    g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec)
    _ = g(sharded_rng)  # don't crash!

  def test_functools_partial_rank_error(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial
    def f(x):
      return x

    g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',))
    x = jnp.arange(4)
    with self.assertRaises(ValueError):
      g(x)

  def test_in_specs_none_error(self):
    mesh = jtu.create_mesh((4,), ('x',))

    def f(x): return x

    with self.assertRaisesRegex(TypeError, "but it was None"):
      shard_map(f, mesh, in_specs=None, out_specs=P())(3.)

    # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug
    # with self.assertRaises(TypeError):
    #   shard_map(f, mesh, in_specs=(None,), out_specs=P())(3.)

    shard_map(f, mesh, in_specs=P(), out_specs=P())(3.)  # doesn't crash

  def test_scan_rep_rule(self):
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))

    def f(x, y, z):
      x, y, z = x.sum(), y.sum(), z.sum()
      def body(c, _):
        c, *cs = c
        return (*cs, c), None
      out, _ = jax.lax.scan(body, (x, y, z), None, length=3)
      return [jnp.expand_dims(a, 0) for a in out]

    x = jnp.arange(4)

    # doesn't crash, because out_spec assumes no replication (and there is none)
    shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
              out_specs=P(('x', 'y')))(x, x, x)

    # does crash, because output incorrectly promises replication
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P('x'))(x, x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P('y'))(x, x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=P(None))(x, x, x)

    def g(x, y, z):
      x, y, z = x.sum(), y.sum(), z.sum()
      def body(c, _):
        return c, None
      out, _ = jax.lax.scan(body, (x, y, z), None, length=1)
      return [jnp.expand_dims(a, 0) for a in out]

    # doesn't crash, because everything matches
    shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
              out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x)

    # does crash, because the second guy is wrong
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
                out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)

  def test_cond_rep_rule(self):
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
    x = jnp.arange(4)

    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(True, true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x)

    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return y
      return jax.lax.cond(True, true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)

    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x)

    def f(x, y):
      def true_fn(x, y):
        return x
      def false_fun(x, y):
        return x + 1
      return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y)

    shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
    with self.assertRaisesRegex(ValueError, "require replication"):
      shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x)

    # https://github.com/jax-ml/jax/issues/24418
    def f(a):
      c = jax.lax.cond(jnp.any(a), lambda: 1, lambda: 0)
      return jnp.reshape(c, a.shape)

    mesh = jtu.create_mesh((2,), ('x',))
    a = jnp.array([True, False])
    shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a)

  def test_switch_rep_rule(self):
    mesh = jtu.create_mesh((2, 2,), ('x', 'y'))
    x = jnp.arange(4)

    def f(n, x, y):
      return jax.lax.switch(
          n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y)

    shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x)

  def test_eager_custom_jvp_basic(self):
    @jax.custom_jvp
    def foo(x):
      return 2. * x

    @foo.defjvp
    def foo_jvp(primals, tangents):
      (x,), (x_dot,) = primals, tangents
      return foo(x), 3. * x_dot

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
    y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
    self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
    self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

  def test_eager_custom_vjp_basic(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 3. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
    y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.))
    self.assertAllClose(y, (2. * jnp.arange(4.)).sum())
    self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_axis_index_basic(self, jit):
    def foo():
      return jax.lax.axis_index('x')[None]

    if jit:
      foo = jax.jit(foo)

    mesh = jtu.create_mesh((4,), ('x',))
    ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))()
    expected = jnp.arange(4.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_axis_index_twoaxes(self, jit):
    def foo():
      out1 = jax.lax.axis_index('i')[None, None]
      out2 = jax.lax.axis_index('j')[None, None]
      out3 = jax.lax.axis_index(('i', 'j'))[None, None]
      return out1, out2, out3

    if jit:
      foo = jax.jit(foo)

    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(),
                                 out_specs=P('i', 'j'))()
    expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2))
    expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2))
    expected3 = jnp.arange(8.).reshape(4, 2)
    self.assertAllClose(ans1, expected1, check_dtypes=False)
    self.assertAllClose(ans2, expected2, check_dtypes=False)
    self.assertAllClose(ans3, expected3, check_dtypes=False)

  def test_axis_index_eager(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P())
    def foo():
      val = jax.lax.psum(jax.lax.axis_index('x'), 'x')
      return 1. if val > 0 else -1.

    out = foo()  # doesn't crash
    self.assertEqual(out, 1.)

  def test_jaxpr_shardings_with_no_outputs(self):
    # https://github.com/jax-ml/jax/issues/15385
    mesh = jtu.create_mesh((4,), ('i',))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('i'))
    def f():
      return jax.lax.iota(jnp.dtype('int32'), 4)
    f()  # don't crash

    @partial(shard_map, mesh=mesh, in_specs=(P('i'),), out_specs=P('i'))
    def g(a_block):
      i = jnp.arange(a_block.shape[0])
      return i + a_block

    g(np.arange(32))  # don't crash

  def test_device_put(self):
    mesh = jtu.create_mesh((4,), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      return x + jax.device_put(1)

    x = jnp.arange(32.)
    f(x)  # doesn't crash
    jax.jit(f)(x)  # doesn't crash

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def g(x):
      return x + jax.device_put(1, jax.devices()[0])

    with self.assertRaisesRegex(ValueError, "got device"):
      g(x)

    # jit means device_puts are ignored, even those within shmap bodies, so no
    # error!
    jax.jit(g)(x)  # doesn't crash

  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def test_key_array_with_replicated_last_tile_dim(self):
    # See https://github.com/jax-ml/jax/issues/16137

    mesh = jtu.create_mesh((2, 4), ('i', 'j'))

    def f(rng):
      @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
               check_rep=False)
      def g(rng):
        return jnp.array([jax.random.normal(rng[0])])
      return g(jax.random.split(rng, 4))

    jax.jit(f)(jax.random.key(0))  # doesn't crash

  # same method appears in api_test.py:DCETest
  # TODO(mattjj): consider moving this method to be a helper in jtu
  def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool],
                        expected_used_inputs: list[bool],
                        expected_num_eqns: int | None = None,
                        check_diff: bool = True):
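    """Checks pe.dce_jaxpr of `jaxpr` against the expected results.

    Verifies which inputs DCE keeps, optionally the total equation count
    (including subjaxprs), that the pruned jaxpr still computes the kept
    outputs on random inputs, and (if `check_diff`) that it is differentiable.
    """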
    jaxpr_dce, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs)
    core.check_jaxpr(jaxpr_dce)
    self.assertEqual(used_inputs, expected_used_inputs)
    if expected_num_eqns is not None:
      all_jaxprs = it.chain([jaxpr_dce], core.subjaxprs(jaxpr_dce))
      num_eqns = sum(len(subjaxpr.eqns) for subjaxpr in all_jaxprs)
      self.assertEqual(num_eqns, expected_num_eqns, msg=str(jaxpr_dce))

    rand_ = jtu.rand_small(np.random.RandomState(0))
    rand = lambda v: rand_(v.aval.shape, v.aval.dtype)
    consts = [rand(v) for v in jaxpr.constvars]
    inputs = [rand(v) for v in jaxpr.invars]
    inputs_dce = [x for x, used in zip(inputs, used_inputs) if used]
    full_outs = core.eval_jaxpr(jaxpr, consts, *inputs)
    expected_outs_dce = [y for y, used in zip(full_outs, used_outputs) if used]
    outs = core.eval_jaxpr(jaxpr_dce, consts, *inputs_dce)
    self.assertAllClose(outs, expected_outs_dce)

    if check_diff and expected_num_eqns != 0:
      f = lambda *args: core.eval_jaxpr(jaxpr_dce, consts, *args)
      jtu.check_grads(f, inputs_dce, order=2, modes=['rev'])

  def test_returned_out_sharding(self):
    mesh = jtu.create_mesh((1, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    inp = jax.device_put(jnp.zeros((2, 2)), s)
    out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, inp)

  def test_dce(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    def f(x, y, z):
      @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')),
               out_specs=(P(None, None), P(None, 'i'), P('i', 'j')))
      def g(y, z):
        return jnp.sin(x), jnp.cos(z), jnp.tan(y)

      return g(y, z)

    x = jnp.zeros((4, 4))
    y = jnp.zeros((8, 8))
    z = jnp.zeros((16, 16))
    jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr
    self.assertLen(jaxpr.eqns, 1)
    self.assertLen(jaxpr.eqns[0].params['jaxpr'].eqns, 3)

    # If we use all outputs, nothing should be deleted.
    self.assert_dce_result(
        jaxpr, used_outputs=[True, True, True],
        expected_used_inputs=[True, True, True],
        expected_num_eqns=1 + 3,  # one outer eqn, three remain in body
        check_diff=False)

    # If we drop the last output, the second input should be dropped.
    self.assert_dce_result(
        jaxpr, used_outputs=[True, True, False],
        expected_used_inputs=[True, False, True],
        expected_num_eqns=1 + 2,  # one outer eqn, two remain in body
        check_diff=False)
    # If we drop the second output, the last input should be dropped.
    self.assert_dce_result(
        jaxpr, used_outputs=[True, False, True],
        expected_used_inputs=[True, True, False],
        expected_num_eqns=1 + 2,  # one outer eqn, two remain in body
        check_diff=False)
    # If we drop the latter two outputs, the latter two inputs should be dropped
    self.assert_dce_result(
        jaxpr, used_outputs=[True, False, False],
        expected_used_inputs=[True, False, False],
        expected_num_eqns=1 + 1,  # one outer eqn, two remain in body
        check_diff=False)

    # Finally, try dropping the closed-over value.
    self.assert_dce_result(
        jaxpr, used_outputs=[False, True, False],
        expected_used_inputs=[False, False, True],
        expected_num_eqns=1 + 1,  # one outer eqn, two remain in body
        check_diff=False)

  def test_post_process_partial_eval_with_scalar_res(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh,
                                     in_specs=P(), out_specs=P())())(2.0)
    self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)

  def test_sharding_metadata_in_hlo_attrs(self):
    mesh = Mesh(jax.devices(), ('i',))
    x = jnp.arange(len(jax.devices()), dtype='float32')
    y = jnp.array([3.], dtype='float32')

    def foo(x):
      x = jnp.sin(x)
      x = shard_map(lambda x: jnp.cos(x * y), mesh,
                    in_specs=P('i'), out_specs=P('i'))(x)
      x = shard_map(lambda x: jnp.cos(x * y), mesh,
                    in_specs=P('i'), out_specs=P('i'))(x)
      return x

    hlo_str = jax.jit(foo).lower(x).as_text("stablehlo", debug_info=True)
    if config.use_shardy_partitioner.value:
      if len(jax.devices()) > 1:
        self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
      else:
        # When devices == 1, the `sdy.manual_computation` is inlined.
        self.assertEqual(0, hlo_str.count('sdy.manual_computation'))
    else:
      self.assertIn('call @shmap_body', hlo_str)
      self.assertIn('call @shmap_body_0', hlo_str)
      self.assertIn('%arg0: tensor<1xf32>', hlo_str)
      self.assertIn('"[None]"', hlo_str)
      self.assertIn('%arg1: tensor<1xf32>', hlo_str)
      self.assertIn('"[(\'i\',)]"', hlo_str)
      self.assertIn(
          '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str
      )

  def test_rewrite_process_call(self):
    def f(x):
      return core.call_p.bind(
          lu.wrap_init(lambda x: [2. * x],
                       debug_info=api_util.debug_info("test", lambda x: [2. * x],
                                                      (x,), {})),
          x)[0] * x

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)
    y = jax.jit(g)(x)  # eager requires shmap to have ShardMapTrace.process_call
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  def test_rewrite_post_process_call(self):
    # We shouldn't hit post_process_call here because of RewriteTrace's dynamic
    # behavior (i.e. no data dependence).
    mesh = jtu.create_mesh((4,), ('x',))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
    def f(x):
      return core.call_p.bind(
          lu.wrap_init(lambda: [2. * x],
                       debug_info=api_util.debug_info("test", lambda: [2. * x],
                                                      (), {})))[0] * x

    x = jnp.arange(4.)
    y = f(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  @parameterized.parameters([True, False])
  def test_rewrite_process_custom_jvp_call(self, jit):
    @jax.custom_jvp
    def foo(x):
      return 2. * x

    @foo.defjvp
    def foo_jvp(primals, tangents):
      (x,), (x_dot,) = primals, tangents
      return foo(x), 2. * x_dot

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    if jit:
      g = jax.jit(g)

    x = jnp.arange(4.)
    y = g(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y2, y_dot = jax.jvp(g, (x,), (3 * x,))
    self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
    self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)

  @parameterized.parameters([True, False])
  def test_rewrite_process_custom_vjp_call(self, jit):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    if jit:
      g = jax.jit(g)

    x = jnp.arange(4.)
    y = g(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

  @parameterized.parameters([True, False])
  def test_rewrite_process_custom_vjp_call_match_more_replicated(self, jit):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return jnp.ones_like(y_bar),  # diff! more replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    if jit:
      g = jax.jit(g)

    x = jnp.arange(4.)
    y = g(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)

  def test_same_pspec_eager_shard_map(self):
    # This behavior is not guaranteed by JAX and this test can be changed if
    # the behavior changes.
    mesh = jtu.create_mesh((1, 4, 1), ('data', 'seq', 'model'))

    def f(x):
      return x * x + 2

    x = jnp.ones([2, 16, 4])
    x_spec = jax.sharding.PartitionSpec("data", "seq", "model")
    x = jax.device_put(x, jax.sharding.NamedSharding(mesh, x_spec))
    shard_f = shard_map(f, mesh=mesh, in_specs=x_spec, out_specs=x_spec)

    y = shard_f(x)
    self.assertEqual(x_spec, y.sharding.spec)

  @parameterized.parameters([True, False])
  def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit):
    @jax.custom_vjp
    def foo(x, y):
      del y
      return 2. * x

    def foo_fwd(x, y):
      return foo(x, y), y

    def foo_bwd(y, _):
      return y, None  # diff! x_bar less replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(lambda x, y: foo(x, y) * y, mesh,
                  in_specs=(P(), P('x')), out_specs=P('x'))
    if jit:
      g = jax.jit(g)

    x = jnp.arange(4.)
    y = jnp.arange(4 * 4.)

    z = g(x, y)
    self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)

    z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y)
    self.assertAllClose(z.sum(), z_, check_dtypes=False)
    self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
                        check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_rewrite_custom_vjp_call_jaxpr(self, jit):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    def foo_scan(x):
      y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
      return y

    mesh = jtu.create_mesh((4,), ('x',))
    g = shard_map(lambda x: foo_scan(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    if jit:
      g = jax.jit(g)

    x = jnp.arange(4.)
    y = g(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: g(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
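
  # The next few tests inspect jaxprs produced by transposing shard_map: an
  # unmapped-to-mapped fanout transposes to a psum2, a psum transposes to a
  # pbroadcast, and an identity body transposes to an empty jaxpr.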
  def test_transpose_identity(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,))
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def g(x):
      return jax.jit(lambda x: 1. * x)(x)

    jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEmpty(e1.outvars)
    self.assertLen(e2.params['jaxpr'].eqns, 1)

  def test_fanout_specs_transpose_to_psum(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e2, = e.params['jaxpr'].eqns
    self.assertEqual(str(e2.primitive), 'psum2')
    self.assertEqual(e2.params['axes'], ('x',))

  def test_fanin_psum_transposes_to_fanout(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.]))
    e, = jaxpr.jaxpr.eqns
    e1, = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'pbroadcast')

  def test_psum_with_implicit_fanout_self_transposes(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'psum2')
    self.assertEqual(str(e2.primitive), 'pbroadcast')

  def test_transpose_float0(self):
    mesh = jtu.create_mesh((4,), ('x',))

    s = jax.sharding.NamedSharding(mesh, P(None, 'x'))

    # vjp that triggers float0
    @jax.custom_vjp
    def f(x, _):
      return x
    def f_fwd(x, y):
      return x, jnp.zeros(shape=y.shape, dtype=np.int32)
    def f_rev(tmp, g):
      return (g, tmp)
    f.defvjp(f_fwd, f_rev)

    # trivial vjp that consumes float0
    @jax.custom_vjp
    def g(x, y):
      return x, y
    def g_fwd(x, y):
      return jax.vjp(lambda x, y: (x, y), x, y)
    def g_bwd(vjp_fn, result):
      return vjp_fn(result)
    g.defvjp(g_fwd, g_bwd)

    @partial(shard_map, mesh=mesh, in_specs=(P('x'), P()), out_specs=P())
    def f_shmapped(x, y):
      return jax.lax.psum(f(x, y).sum(), axis_name=('x'))

    @partial(shard_map, mesh=mesh, check_rep=False,
             in_specs=P('x'), out_specs=(P('x'), P()))
    def f_shmapped2(x, y):
      return g(x, y)

    def f_wrapper(x, y):
      x, y = jax.lax.map(lambda xs: f_shmapped2(xs[0], xs[1]), (x, y))
      return jax.lax.map(lambda xs: f_shmapped(xs[0], xs[1]), (x, y)).sum()

    @partial(jax.jit, in_shardings=s,
             out_shardings=jax.sharding.NamedSharding(mesh, P()))
    def example(x, y):
      return jax.grad(f_wrapper, allow_int=True, argnums=(0, 1))(x, y)

    x = np.zeros(shape=(8,16), dtype=np.float32)
    y = np.zeros(shape=(8,16), dtype=np.int32)
    # Doesn't crash.
    dx, dy = example(x, y)
    self.assertEqual(dy.dtype, jax.dtypes.float0)
  def test_rewrite_binops(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
    def f(x, y):
      return x * y

    jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e = e.params['jaxpr'].eqns[0]
    self.assertEqual(e.primitive.name, 'pbroadcast')
    self.assertEqual(e.params['axes'], ('x',))

  def test_rewrite_scan(self):
    mesh = jtu.create_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None,
                          length=2)
      return x

    jaxpr = jax.make_jaxpr(f)(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e, = e.params['jaxpr'].eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(e1.primitive.name, 'psum2')
    self.assertEqual(e2.primitive.name, 'pbroadcast')

  def test_check_rep_false_grads(self):
    if jtu.is_device_tpu(5, 'e'):
      self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e')

    # This test is redundant with the systematic tests below, but it serves as a
    # direct regression test for a bug.
    mesh = jtu.create_mesh((4,), ('heads',))

    def f(q, k, v):
      def body(q, k, v):
        return q * k[None, :] + v[None, :]

      out = shard_map(body, mesh, check_rep=False,
                      in_specs=(q_spec, kv_spec, kv_spec,),
                      out_specs=q_spec)(q, k, v)
      return out.sum()

    q_spec = P('heads', None)
    kv_spec = P(None)
    q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec))
    k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
    v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))

    if jtu.device_under_test() == 'tpu':
      rtol = 2e-2
    else:
      rtol = 1e-2
    jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol)

  def test_axis_env_extension_regression(self):
    def foo(x):
      i = jax.lax.axis_index('x')
      return jnp.exp(x) + i.astype(x.dtype)

    @partial(jax.remat, policy=lambda *args, **kwargs: True)
    def bar(x):
      return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),),
                       out_specs=P('x'), check_rep=False)(x)

    jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.))  # doesn't crash

  @parameterized.parameters(it.product([True, False], repeat=2))
  def test_res_forwarding_optimization(self, jit, remat):
    mesh = jtu.create_mesh((4,), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      return jax.lax.exp(x)
    if jit:
      f = jax.jit(f)
    if remat:
      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
      f = jax.remat(f, policy=policy)
    g = lambda x: f(x).sum()

    x = jnp.arange(16.)
    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
    e1, _, e2 = jaxpr.eqns
    self.assertLen(e1.outvars, 1)  # only primal output
    self.assertLen(e2.invars, 2)  # res and cotangent inputs
    self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1)

  @parameterized.parameters(it.product([True, False], repeat=2))
  def test_res_forwarding_optimization_complex(self, jit, remat):
    # like the above test, but a different function `f`
    mesh = jtu.create_mesh((4,), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      return jax.lax.exp(x.sum()) + x, jax.lax.exp(x)
    if jit:
      f = jax.jit(f)
    if remat:
      policy = jax.ad_checkpoint.checkpoint_policies.everything_saveable
      f = jax.remat(f, policy=policy)
    g = lambda x: sum(f(x)).sum()

    x = jnp.arange(16.)
    jaxpr_ = jax.make_jaxpr(jax.grad(g))(x)
    jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals))
    e1, _, e2 = jaxpr.eqns
    self.assertLen(e1.outvars, 2)  # one primal and one res output
    self.assertLen(e2.invars, 4)  # two res and two cotangent inputs
    self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1)

  @parameterized.parameters([True, False])
  def test_check_rep_failure_inside_rule(self, jit):
    mesh = jtu.create_mesh((4,), ('i',))

    def loss(w, x):
      @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
      def f(x):
        return jax.lax.psum(((w * x) ** 2).sum(), 'i')
      return f(x)

    if jit:
      loss = jax.jit(loss)

    jax.grad(loss)(3.0, jnp.arange(8.))  # don't crash

  def test_conv_general_dilated(self):
    mesh = jtu.create_mesh((4,), ('i',))

    dot = partial(lax.conv_general_dilated, window_strides=(),
                  padding='VALID', dimension_numbers=('NC', 'IO', 'NC'))

    @partial(shard_map, mesh=mesh, in_specs=(P(None, 'i'), P('i', None)),
             out_specs=P(None, None))
    def f(x, y):
      return lax.psum(dot(x, y), 'i')

    a = jnp.ones((16, 32))
    b = jnp.ones((32, 8))
    y = f(a, b)  # don't crash
    self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2)

  def test_cumsum(self):
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(8.)
    shard_map(jnp.cumsum, mesh=mesh, in_specs=P('i'), out_specs=P('i')
              )(x)  # don't crash

  def test_custom_jvp_inside_jit(self):
    mesh = jtu.create_mesh((4,), ('batch',))
    x = shard_map(jax.jit(jax.nn.relu),
                  mesh=mesh, in_specs=P('batch'),
                  out_specs=P('batch'))(jnp.arange(16.))  # don't crash

  def test_random_normal_rules(self):
    mesh = jtu.create_mesh((4,), ('i',))
    keys = jax.random.split(jax.random.key(0), 4)
    shard_map(lambda k: jax.random.normal(k[0], (1,)),
              mesh=mesh, in_specs=P('i'), out_specs=P('i'))(keys)  # don't crash

  def test_erf_rules(self):
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(16.)
    shard_map(jax.lax.erf,
              mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x)  # don't crash

  def test_error_for_variable_num_args(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def f(*args):
      return args[0] @ args[1]

    shard_f = shard_map(
        f, mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y'))

    with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"):
      shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))

  def test_custom_vjp_replication_error_message_hint(self):
    mesh = jtu.create_mesh((4,), 'i')

    @jax.custom_vjp
    def f(x):
      return jax.lax.psum(x, 'i')
    def f_fwd(x):
      return f(x), None
    def f_bwd(_, g):
      return jax.lax.psum(g, 'i'),
    f.defvjp(f_fwd, f_bwd)

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
    def g(x):
      return f(f(x))

    y, grad = jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4))
    # first psum sums, second psum multiplies by 4
    self.assertAllClose(y, (jnp.ones(4) * 4).sum(), check_dtypes=False)
    # two psums on the backward pass, each one multiplies by 4
    self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False)

  def test_repeated_psum_allowed(self):
    # https://github.com/jax-ml/jax/issues/19175
    mesh = jtu.create_mesh((4,), 'i')

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
    def g(x):
      return jax.lax.psum(jax.lax.psum(x, 'i'), 'i')

    y = g(jnp.arange(4.))
    self.assertAllClose(y, jnp.arange(4.).sum(keepdims=True) * 4,
                        check_dtypes=False)

  def test_approx_top_k(self):
    mesh = Mesh(np.array(jax.devices()[:2]), ('i',))

    x = jnp.array([3.0, 1.0, 4.0, 2.0])
    _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh, P('i'), P('i'))(x)

  def test_disable_jit(self):
    mesh = Mesh(np.array(jax.devices()[:2]), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(x):
      return x

    x = jnp.arange(8.)
    with jax.disable_jit():
      f(x)  # don't crash

  @parameterized.parameters(it.product(range(4), repeat=3))
  @jtu.run_on_devices("cpu")
  def test_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd):
    num_args = 3
    rng = np.random.RandomState(seed)
    mesh = Mesh(np.array(jax.devices()[:1]), ('i',))

    in_perm = rng.permutation(num_args)
    out_perm = rng.permutation(num_args)

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(inputs):
      inputs = [inputs[i] for i in in_perm]
      outputs = inputs[:num_input_fwd] + [
          jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i])
          for i in range(num_args - num_input_fwd)]
      return [outputs[i] for i in out_perm]

    jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1,
                    modes=['rev'], atol=1e-3, rtol=1e-3)

  def test_partial_auto(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
                           {AxisType.Manual: ('i',), AxisType.Auto: ('j',)})
      x = jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P(None, 'j')))
      return x * x

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', None),
                    out_specs=P('i', None),
                    auto=frozenset({'j'}))(x)
      return jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'in_shardings=[<@mesh, [{"i"}, {?}]>]'
          ' out_shardings=[<@mesh, [{"i"}, {?}]>] manual_axes={"i"}',
          f.lower(v).as_text(),
      )
    else:
      self.assertIn(
          'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,'
          ' replicated}}',
          f.lower(v).as_text('hlo'),
      )
    self.assertAllClose(v * v, f(v), check_dtypes=False)

  def test_partial_auto_explicit_no_use_mesh(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'),
                           axis_types=(AxisType.Explicit,) * 2)

    def g(x):
      self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
                           {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)})
      self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
      out = x * x
      self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
      return out

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', None),
                    out_specs=P('i', None),
                    auto=frozenset({'j'}))(x)
      self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
      return x

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    out = f(v)
    self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v * v, out, check_dtypes=False)

  @jtu.with_user_mesh((2, 2), ('i', 'j'))
  def test_partial_auto_explicit(self, mesh):
    def g(x):
      self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
                           {AxisType.Manual: ('i',), AxisType.Explicit: ('j',)})
      self.assertEqual(x.aval.sharding.spec, P(None, 'j'))
      out = x * x
      self.assertEqual(out.aval.sharding.spec, P(None, 'j'))
      return out

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', None),
                    out_specs=P('i', None),
                    auto=frozenset({'j'}))(x)
      self.assertEqual(x.aval.sharding.spec, P('i', 'j'))
      return x

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    out = f(v)
    self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v * v, out, check_dtypes=False)

    if config.use_shardy_partitioner.value:
      self.assertIn(
          'sdy.sharding_constraint %1 <@mesh, [{}, {"j"}]>',
          f.lower(v).as_text(),
      )
    else:
      self.assertIn(
          'mhlo.sharding = "{devices=[1,2,2]<=[2,2]T(1,0) last_tile_dims={manual}}"}',
          f.lower(v).as_text(),
      )

    @jax.jit
    def h(x):
      return jnp.sum(f(x))

    jax.grad(h)(v)  # doesn't crash
    jax.jit(jax.grad(h))(v)  # doesn't crash

  @jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l'))
  def test_partial_auto_explicit_multi_explicit(self, mesh):
    def g(x):
      self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict,
                           {AxisType.Manual: ('i', 'j'),
                            AxisType.Explicit: ('k', 'l')})
      self.assertEqual(x.aval.sharding.spec, P(None, None, 'k', 'l'))
      out = x.T
      self.assertEqual(out.aval.sharding.spec, P('l', 'k', None, None))
      return out

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', 'j', None, None),
                    out_specs=P('i', 'j', None, None),
                    auto=frozenset({'k', 'l'}))(x)
      self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None))
      return x

    v = jnp.arange(64.).reshape(4, 2, 2, 4)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j', 'k', 'l')))

    out = f(v)
    self.assertEqual(
        out.sharding, NamedSharding(mesh, P(('i', 'l'), ('j', 'k'), None, None)))

  def test_partial_auto_propagate_through(self):
    mesh = jtu.create_mesh((2, 2, 2), ('i', 'j', 'k'))
    sharding = jax.sharding.NamedSharding(mesh, P('i'))

    def g(x):
      return jax.lax.with_sharding_constraint(x * x, sharding)

    @jax.jit
    def f(x):
      return shard_map(
          g,
          mesh,
          in_specs=P(),
          out_specs=P(),
          check_rep=False,
          auto=frozenset({'i'}),
      )(x)

    v = jnp.arange(32.0).reshape(4, 8)
    v = jax.device_put(v, sharding)
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'in_shardings=[<@mesh, [{?}, {?}]>]'
          ' out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={"j", "k"}',
          f.lower(v).as_text(),
      )
    else:
      self.assertIn(
          'sharding={devices=[1,1,4,2]<=[2,4]T(1,0) last_tile_dims={manual,'
          ' replicated}}',
          f.lower(v).as_text('hlo'),
      )
    actual = f(v)
    self.assertAllClose(v * v, actual, check_dtypes=False)
    self.assertEqual(actual.sharding, sharding)

  def test_shmap_close_over_unused_params(self):
    mesh = jtu.create_mesh((2,), ("data",))

    def loss_fn(_, batch):
      return jnp.sum(batch)

    @jax.jit
    def update_fn(params, batch):
      def grad_fn(batch):
        return jax.value_and_grad(loss_fn)(params, batch)
      return shard_map(grad_fn, mesh=mesh, in_specs=P("data"), out_specs=P(),
                       check_rep=False)(batch)

    arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8),
                                 NamedSharding(mesh, P()))
    params = jnp.copy(arr_sharded)
    update_fn(params, arr_sharded)  # doesn't crash

  def test_shmap_close_over_unused_params_vmap(self):
    mesh = jtu.create_mesh((2,), ("data",))

    def loss_fn(params, batch):
      return jnp.sum(params) + jnp.sum(batch)

    @jax.jit
    def update_fn(params, batch):
      def grad_fn(batch):
        return jax.value_and_grad(loss_fn)(params, batch)
      return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"),
                       out_specs=P("data"), check_rep=False)(batch)

    arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8),
                                 NamedSharding(mesh, P()))
    params = jnp.copy(arr_sharded)
    update_fn(params, arr_sharded)  # doesn't crash

  def test_sharded_prng_with_abstract_mesh(self):
    shape = (8, 2, 2)
    mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))

    np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape)
    key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl)
    key = jax.device_put(key, NamedSharding(mesh, P()))

    @jax.jit
    def shard_key(key):
      return shard_map(
          lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key)

    out = shard_key(key)
    self.assertTrue(out.sharding.is_equivalent_to(NamedSharding(mesh, P()),
                                                  out.ndim))

  def test_partial_auto_error_wsc_manual(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      x = jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
      return x * x

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', None),
                    out_specs=P('i', None),
                    check_rep=False,
                    auto=frozenset({'j'}))(x)
      return jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    with self.assertRaisesRegex(ValueError, "manual"):
      f(v)

  def test_partial_auto_error_invalid_auto(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      x = jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
      return x * x

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', None),
                    out_specs=P('i', None),
                    check_rep=False,
                    auto=frozenset({'k'}))(x)
      return jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"):
      f(v)

  def test_partial_auto_error_wrong_in_specs(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      x = jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))
      return x * x

    @jax.jit
    def f(x):
      x = shard_map(g, mesh,
                    in_specs=P('i', 'j'),
                    out_specs=P('i', None),
                    check_rep=False,
                    auto=frozenset({'j'}))(x)
      return jax.lax.with_sharding_constraint(
          x, jax.sharding.NamedSharding(mesh, P('i', 'j')))

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"):
      f(v)

  def test_nested_partial_auto(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      return x * x

    def h(x):
      return shard_map(g, mesh,
                       in_specs=P(None, 'j'),
                       out_specs=P(None, 'j'))(x)

    @jax.jit
    def f(x):
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j'}))(x)

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*v, f(v), check_dtypes=False)

  def test_grad_nested_partial_auto(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      # manual: 'i', 'j'
      return x * x

    def h(x):
      # auto: 'j', manual: 'i'
      return shard_map(g, mesh,
                       in_specs=P(None, 'j'),
                       out_specs=P(None, 'j'))(x)

    @jax.jit
    def f(x):
      # auto: 'i', 'j'
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j'}))(x).sum()

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)

  def test_grad_nested_partial_auto_with_residuals(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def g(x):
      return x * x * x

    def h(x):
      return shard_map(g, mesh,
                       in_specs=P(None, 'j'),
                       out_specs=P(None, 'j'))(x)

    @jax.jit
    def f(x):
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j'}))(x).sum()

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False)

  def test_axis_size_1_partial_auto(self):
    mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))

    def h(x):
      return x * x

    @jax.jit
    def f(x):
      return shard_map(h, mesh,
                       in_specs=P('i', None),
                       out_specs=P('i', None),
                       check_rep=False,
                       auto=frozenset({'j', 'k'}))(x)

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertAllClose(v*v, f(v), check_dtypes=False)

  def test_partial_auto_of_pjit(self):
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))

    def h():
      def _make_zeros():
        return jnp.zeros(())
      s = jax.sharding.NamedSharding(mesh, P())
      y = jax.jit(_make_zeros, out_shardings=s)()
      return y.reshape((1,))

    def f():
      return shard_map(
          h, mesh, in_specs=(),
          out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

  def test_partial_auto_of_pjit_different_mesh(self):
    if config.use_shardy_partitioner.value:
      self.skipTest(
          'Shardy requires the mesh axis names to be the same across '
          'the entire computation.'
      )
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
    mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l'))

    def h():
      def _make_zeros():
        return jnp.zeros(())
      s = jax.sharding.NamedSharding(mesh2, P())
      y = jax.jit(_make_zeros, out_shardings=s)()
      return y.reshape((1,))

    def f():
      return shard_map(
          h, mesh, in_specs=(),
          out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

  def test_partial_auto_axis_index(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    out_sharding = NamedSharding(mesh, P('i', None))

    @partial(jax.jit, out_shardings=out_sharding)
    def f():
      return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1),
                       mesh, in_specs=P('i', None), out_specs=P('i', None),
                       check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1))

  def test_partial_auto_axis_index_degenerated_axis(self):
    mesh = jtu.create_mesh((1, 2), ('i', 'j'))
    out_sharding = NamedSharding(mesh, P('i', None))

    @partial(jax.jit, out_shardings=out_sharding)
    def f():
      return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 1),
                       mesh, in_specs=P('i', None), out_specs=P('i', None),
                       check_rep=False, auto=frozenset({'j'}))()

    self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))

  def test_partial_auto_ppermute(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(8.)

    def g(x):
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
      return jax.lax.ppermute(x, 'i', [(0, 1), (1, 2), (2, 3), (3, 0)])

    @jax.jit
    def f(x):
      return shard_map(g,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(x)

    y = f(x)  # don't crash
    self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]),
                        check_dtypes=False)

  # TODO(parkers,mattjj): get XLA to support this too
  # def test_partial_auto_all_to_all(self):
  #
  #   mesh = jtu.create_mesh((4, 2), ('i', 'j'))
  #   x = jnp.arange(128.).reshape(16, 8)
  #
  #   def g(x):
  #     x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
  #     return jax.lax.all_to_all(x, 'i', 0, 1, tiled=True)
  #
  #   @jax.jit
  #   def f(x):
  #     return shard_map(g,
  #                      mesh, in_specs=P('i', None), out_specs=P(None, 'i'),
  #                      check_rep=False, auto=frozenset({'j'}))(x)
  #
  #   f(x)  # don't crash

  def test_partial_auto_debug_print(self):
    if config.use_shardy_partitioner.value:
      raise unittest.SkipTest("shardy error")

    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(8.)

    def g(x):
      jax.debug.print('{}', x)

    @jax.jit
    def f(x):
      return shard_map(g,
                       mesh, in_specs=P('i'), out_specs=None,
                       check_rep=False, auto=frozenset({'j'}))(x)

    y = f(x)  # don't crash

  def test_partial_auto_of_random_keys(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8)

    @jax.jit
    def f(x):
      return shard_map(lambda k: k,
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(keys)

    y = f(keys)  # doesn't crash
    self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
                        check_dtypes=False)

  def test_partial_auto_of_random_keys_slice(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    keys = jax.random.split(jax.random.key(0), 8).reshape(4, 2)

    @jax.jit
    def f(x):
      return shard_map(lambda k: k[0],
                       mesh, in_specs=P('i'), out_specs=P('i'),
                       check_rep=False, auto=frozenset({'j'}))(x)

    f(keys)  # doesn't crash

  def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21032
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)

    jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs)  # don't crash

  def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self):
    # https://github.com/jax-ml/jax/pull/21056
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(jax.remat, policy=lambda *_, **__: True)
    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)

    jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs)  # don't crash

  def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
    # https://github.com/jax-ml/jax/issues/21236
    mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P('j'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(16.)

    ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text()
      )
    else:
      self.assertIn(
          "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text()
      )

  def test_vmap_spmd_axis_name_error(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('i'),
        out_specs=P('i'),
    )
    def f(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)
    with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
      jax.vmap(f, spmd_axis_name='i')(xs)

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=P('j'),
        out_specs=P(('i', 'j')),
        check_rep=False,
    )
    def g(x):
      return jnp.sin(x)

    xs = jnp.arange(4 * 16.).reshape(4, 16)
    with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"):
      jax.vmap(g, spmd_axis_name='i')(xs)

  def test_in_spec_none(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    x = jnp.arange(8).reshape(4, 2)

    def f(o, x):
      self.assertIs(o, obj)
      return jnp.sin(x)

    obj = object()
    y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    obj = None
    y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f2(o, x):
      self.assertIsInstance(o, dict)
      self.assertIs(o['a'], obj['a'])
      return jnp.sin(x)

    obj = {'a': object()}
    y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f3(x, o):
      self.assertIs(o, obj)
      return jnp.sin(x)

    obj = object()
    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    obj = None
    y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

    def f4(o1, o2, x, o3):
      self.assertIs(o1, obj1)
      self.assertIs(o2[0], obj2[0])
      self.assertIs(o2[1], obj2[1])
      self.assertIs(o3, obj3)
      return jnp.sin(x)

    obj1 = object()
    obj2 = (object(), object())
    obj3 = object()
    y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3)
    self.assertAllClose(y, jnp.sin(x), check_dtypes=False)

  def test_in_spec_none_divisibility_errors(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(4).reshape(2, 2)

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (None, P('i')), None)(object(), x)

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object())

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), None), None
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None,
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'divisible'):
      shard_map(lambda *_: None, mesh, ((None, None), P('i')), None,
                )((object(), object()), x)

  def test_in_spec_none_rank_errors(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
    x = jnp.arange(4)

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x)

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object())

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None,
                )(x, (object(), object()))

    with self.assertRaisesRegex(ValueError, 'rank'):
      shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None,
                )((object(), object()), x)

  def test_custom_linear_solve_rep_rules(self):
    # https://github.com/jax-ml/jax/issues/20162
    mesh = jtu.create_mesh((1,), ('i',))
    a = jnp.array(1).reshape(1, 1)
    b = jnp.array(1).reshape(1)

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(a, b):
      c = jnp.linalg.solve(a, b)
      return c

    _ = f(a, b)  # don't crash

  def test_temporary_error_suppression_flag(self):
    mesh = jtu.create_mesh((2,), ('i',))

    def f(x, y):
      z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True),
                    mesh=mesh, in_specs=(P(None), P('i')), out_specs=P(None),
                    check_rep=False,
                    )(x, y)
      return z

    y = jnp.arange(8)
    xs = jnp.arange(32).reshape(4, 8)
    with self.assertRaisesRegex(ValueError, 'vmap spmd_axis_name cannot appear in'):
      _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)

    with config.disable_vmap_shmap_error():
      _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)

  def test_in_spec_none_hashability(self):
    mesh = jtu.create_mesh((2,), ('i',))

    class A:
      def __hash__(self):
        raise Exception

    @partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=())
    def f(a):
      return ()

    f(A())  # don't crash

  def test_get_check_rep(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))

    def f(x, reduce_along, use_jit):
      out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along))

      @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec)
      def g(x):
        result = lax.psum(x, axis_name=reduce_along)
        def check_rep(result):
          self.assertEqual(
              jax.experimental.shard_map.get_replication(result),
              set(reduce_along))
          return result
        result = check_rep(result)
        result = jax.vmap(check_rep)(result)
        return result
      if use_jit:
        return jax.jit(g)(x)
      else:
        return g(x)

    for use_jit in [True, False]:
      x = np.zeros((8, 8), dtype=np.float32)
      f(x, reduce_along=('y',), use_jit=use_jit)
      f(x, reduce_along=('x',), use_jit=use_jit)
      f(x, reduce_along=('x', 'y'), use_jit=use_jit)

  def test_pmin(self):
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(8., dtype=np.float32)
    y = shard_map(lambda x: jax.lax.pmin(x, 'i'),
                  mesh=mesh, in_specs=P('i'), out_specs=P()
                  )(x)  # don't crash
    self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32))

  def test_pmax(self):
    mesh = jtu.create_mesh((4,), ('i',))
    x = jnp.arange(8., dtype=np.float32)
    y = shard_map(lambda x: jax.lax.pmax(x, 'i'),
                  mesh=mesh, in_specs=P('i'), out_specs=P()
                  )(x)  # don't crash
    self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32))


class FunSpec(NamedTuple):
  name: str
  num_inputs: int
  fun: Callable
  out_rep: Callable
  valid_types: Callable | None = None

fun_specs = [
    FunSpec('id', 1, lambda x: x, lambda r: r),
    FunSpec('flip', 2, lambda x, y: (y, x), lambda r_x, r_y: (r_y, r_x)),
    FunSpec('transpose', 1, lambda x: x.T, lambda r: r),
    FunSpec('ravel', 1, lambda x: x.ravel(), lambda r: r),
    FunSpec(
        'dot', 2, jnp.dot, lambda r1, r2: r1 & r2,
        lambda x1, x2: (x1.shape and x2.shape and
                        x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0]),
    ),
    FunSpec(
        'sin_dot_sin', 2,
        lambda x1, x2: jnp.sin(jnp.dot(jnp.sin(x1), x2)),
        lambda r1, r2: r1 & r2,
        lambda x1, x2: (x1.shape and x2.shape and
                        x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0])),
    FunSpec('relu', 1, lambda x: jax.nn.relu(x + 1) - 1, lambda r: r),
]

input_shapes = [
    jax.ShapeDtypeStruct(shape, jnp.dtype('float32'))
    # TODO(mattjj): 0 axis sizes lead to XLA sigfpe, file bug!
    for k in range(1, 4) for shape in it.permutations(range(1, 4), k)
    if not shape or len(set(shape)) > 1  # skip all-equal shapes, boring!
]

mesh_shapes = [
    (1,),
    (1, 1),
    (1, 2),
    (2, 2),
    (2, 4),
    (4, 2),
]

# Reference implementation of shard_map.

ShapeDtypeDuck = Any  # has shape and dtype attributes
Specs = Any  # pytree of PartitionSpec

def shmap_reference(
    body_in_types: Sequence[ShapeDtypeDuck],
    body_out_types: Sequence[ShapeDtypeDuck],
    out_types: Sequence[ShapeDtypeDuck],
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
  ) -> Callable:
  def f_shmapped(*args):
    outs = jax.tree.map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
    getters = [make_indexer(mesh, s, x) for s, x in zip(in_specs, args)]
    putters = jax.tree.map(partial(make_indexer, mesh), out_specs, outs)
    for idx in it.product(*map(range, mesh.shape.values())):
      args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
      assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
      out_shards = f(*args_shards)
      assert jax.tree.all(jax.tree.map(lambda y, r: y.shape == r.shape,
                                       out_shards, body_out_types))
      outs = jax.tree.map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
                          out_shards, outs, putters)
    return outs
  return f_shmapped

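# Note on `shmap_reference` above: it deliberately avoids shard_map itself. It
# allocates zero-filled outputs, then for every mesh index tuple it slices out
# each argument's block (via `make_indexer` below), applies the body `f`, and
# writes the resulting block back into the output. The systematic tests below
# call it as `ref(fun, mesh, in_specs, out_specs)(*args)` to cross-check
# shard_map's results.
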
def make_indexer(mesh: Mesh, spec: P, x: Any
                 ) -> Callable[[tuple[int, ...]], tuple[slice, ...]]:
  block_shape = [d // math.prod(mesh.shape[ax] for ax in (elt or ()))
                 for d, elt in zip(x.shape, spec)]
  def indexer(idx):
    starts = [0 if el is None else
              idx[list(mesh.shape).index(el)] if type(el) is not tuple else
              sum(idx[list(mesh.shape).index(el[i])]
                  * math.prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
              for el in spec]
    return tuple(slice(start * size, (start + 1) * size)
                 for start, size in zip(starts, block_shape))
  return indexer

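# A worked example of `make_indexer` above: with mesh.shape == {'i': 2, 'j': 2},
# spec == P('i', None), and a (4, 4) input, block_shape is [2, 4] and
# indexer((1, 0)) returns (slice(2, 4), slice(0, 4)), i.e. the block owned by
# mesh position i=1.

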
# The code below is similar to named_cases_from_sampler in test_util.py, but it
# uses generators instead of passing a "select" function around.

# To sample test cases efficiently, we construct a generator which yields to the
# caller to choose one of an iterable's options. That is, we can read 'yield' in
# this code as 'choose one'. To call functions which themselves need to make
# choices, we use 'yield from'. That is, we can read 'yield from' in this code
# as 'call this choice-making function'.
Option = Any
CaseSpec = tuple  # first element is a string test name
Chooser = Generator[Iterable[Option], Option, CaseSpec]

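# As an illustrative sketch only (not used by the tests), a minimal chooser
# might look like the following; `sample_one` below drives such a generator by
# sending back one choice per yield until the generator returns its case spec:
#
#   def sample_dtype_and_rank() -> Chooser:
#     dtype = yield ['float32', 'int32']  # 'choose one'
#     rank = yield range(1, 3)            # 'choose one'
#     return f'{dtype}_rank{rank}', dtype, rank
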
def sample_shmap() -> Chooser:
  spec = yield fun_specs
  mesh_shape = yield mesh_shapes
  axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)]
  mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)),
                         axis_names=axis_names)
  in_types = (tys for tys in it.product(input_shapes, repeat=spec.num_inputs)
              if not spec.valid_types or spec.valid_types(*tys))
  body_in_types = yield in_types
  body_out_types = jax.eval_shape(spec.fun, *body_in_types)
  in_types, in_specs = yield from make_in_specs(mesh, body_in_types)
  args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size
          for ty in in_types]
  out_reps = spec.out_rep(*map(partial(unmentioned, mesh), in_specs))
  out_specs = yield from make_out_specs(mesh, body_out_types, out_reps)
  out_types = jax.tree.map(partial(dilate, mesh), out_specs, body_out_types)
  ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
  in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
                          for t in in_types) + ')'
  jit = yield [True, False]
  name = f'{spec.name}_{mesh.shape}_jit={jit}_{in_specs}_{out_specs}_{in_str}'
  return name, spec.fun, mesh.shape, jit, in_specs, out_specs, args, ref

def unmentioned(mesh: Mesh, pspec: P) -> set[core.AxisName]:
  return set(mesh.axis_names) - {n for ns in pspec if ns is not None
                                 for n in (ns if type(ns) is tuple else [ns])}

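# For example, with mesh.axis_names == ('i', 'j'), unmentioned(mesh, P('i'))
# above is {'j'}: the mesh axes a spec does not mention, over which the
# corresponding block is treated as replicated.

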
# To drive the sampler, we have a `sample` function which just runs a loop.
def sample(num: int, make_gen: Callable[[], Chooser]) -> Iterator[CaseSpec]:
  rng = np.random.RandomState(0)
  seen: set[str] = set()
  while len(seen) < num:
    name, *case = sample_one(rng, make_gen())
    if name not in seen:
      seen.add(name)
      yield case

# To sample one test spec, we run the generator, getting back sequences of
# options from it and sending in our choices from those options until finally a
# test case spec is produced.
def sample_one(rng: np.random.RandomState, gen: Chooser) -> CaseSpec:
  lst = list(next(gen))
  try:
    while True:
      choice = lst[rng.randint(len(lst))]
      lst = list(gen.send(choice))
  except StopIteration as e:
    return e.value

# Next are some choice-making functions for shard_map test specifications.

MeshDuck = Any  # same attributes as a Mesh

def make_in_specs(mesh: MeshDuck, in_types: Sequence[ShapeDtypeDuck]
                  ) -> Chooser:
  pairs = []
  for ty in in_types:
    pair = yield from make_in_spec(mesh, ty)
    pairs.append(pair)
  return tuple(zip(*pairs))

def make_in_spec(mesh: Mesh, in_type_base: ShapeDtypeDuck) -> Chooser:
  assert len(list(powerset(mesh.shape)))
  subset = yield powerset(mesh.shape)
  elts = yield partitions(subset, len(in_type_base.shape))
  partition_spec = P(*(tuple(e) if e else None for e in elts))
  new_type = dilate(mesh, partition_spec, in_type_base)
  return new_type, partition_spec

def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck:
  new_shape = tuple(d * math.prod(mesh.shape[ax] for ax in (elt or ()))
                    for d, elt in zip(shape.shape, spec))
  return jax.ShapeDtypeStruct(new_shape, shape.dtype)

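# For example, with mesh.shape == {'i': 2, 'j': 2}, dilate(mesh,
# P(('i', 'j'), None), f32[3, 4]) above gives f32[12, 4]: each dimension grows
# by the product of the mesh axis sizes its spec entry mentions, turning a
# per-shard type into a global one.
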
def make_out_specs(
    mesh: MeshDuck, out_types: ShapeDtypeDuck | Sequence[ShapeDtypeDuck],
    out_reps: set[core.AxisName] | Sequence[set[core.AxisName]]
  ) -> Chooser:
  if type(out_types) is not tuple:
    out_spec = yield from make_out_spec(mesh, out_types, out_reps)  # type: ignore
    return out_spec
  else:
    out_specs = []
    for ty, rep in zip(out_types, out_reps):
      out_spec = yield from make_out_spec(mesh, ty, rep)  # type: ignore
      out_specs.append(out_spec)
    return tuple(out_specs)

def make_out_spec(
    mesh: Mesh, out_type: ShapeDtypeDuck, out_rep: set[core.AxisName]
  ) -> Chooser:
  subset = yield (s for s in powerset(mesh.shape)
                  if out_rep | set(s) == set(mesh.shape))
  elts = yield partitions(subset, len(out_type.shape))
  return P(*(tuple(e) if e else None for e in elts))

# Combinatorial helper functions

T = TypeVar('T')
def partitions(s: Sequence[T], k: int) -> Iterator[list[list[T]]]:
  for indices in it.product(range(k), repeat=len(s)):
    outs: list[list[T]] = [[] for _ in range(k)]
    for i, elt in zip(indices, s):
      outs[i].append(elt)
    yield outs

def powerset(s: Iterable[T]) -> Iterator[Sequence[T]]:
  s = list(s)
  return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))

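# For example, list(powerset('ij')) == [(), ('i',), ('j',), ('i', 'j')], and
# partitions(('i', 'j'), 2) yields every assignment of 'i' and 'j' to two
# ordered (possibly empty) groups, e.g. [['i'], ['j']] and [['i', 'j'], []].
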
# Vmap test helpers

Arr = Any

def sample_shmap_batched(bdim_size: int) -> Chooser:
  name, *shmap_specs, args, ref = yield from sample_shmap()
  bdims = yield all_bdims(*map(op.attrgetter('shape'), args))
  batch_args = map(partial(batchify_arg, bdim_size), bdims, args)
  return name + f'_vmap_{bdims}', bdims, *shmap_specs, batch_args, ref

def all_bdims(*shapes: tuple[int, ...]
              ) -> Iterator[Sequence[int | None]]:
  bdims = ((None, *range(len(shape) + 1)) for shape in shapes)
  return (t for t in it.product(*bdims) if not all(e is None for e in t))

def batchify_arg(size: int, bdim: int | None, x: Arr) -> Arr:
  if bdim is None:
    return x
  else:
    iota = np.arange(1, size + 1, dtype=x.dtype).reshape(
        [1 if i != bdim else -1 for i in range(len(x.shape) + 1)])
    return np.expand_dims(x, bdim) * iota

def args_slicer(args: Sequence[Arr], bdims: Sequence[int | None]
                ) -> Callable[[int], Sequence[Arr]]:
  def slicer(x, bdim):
    if bdim is None:
      return lambda _: x
    else:
      return lambda i: x.take(indices=i, axis=bdim)
  slicers = map(slicer, args, bdims)
  return lambda i: [sl(i) for sl in slicers]

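# For example, all_bdims((3, 4)) above yields (0,), (1,) and (2,), i.e. each
# way of inserting one batch axis (the all-None combination is skipped), and
# batchify_arg(5, 0, x) stacks scaled copies 1*x, ..., 5*x of x along the new
# leading axis so that the vmapped slices differ.

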
class ShardMapSystematicTest(jtu.JaxTestCase):

  @staticmethod
  def make_mesh(mesh_shape):
    return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    out = shard_map(fun, mesh, in_specs, out_specs)(*args)
    expected = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args)
    expected = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.parameters(
      (*params, check_rep)
      for params in sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)
      for check_rep in [True, False]
  )
  @jax.default_matmul_precision("float32")
  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
    if jit:
      f = jax.jit(f)
    jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
  @jax.default_matmul_precision("float32")
  def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
    mesh = self.make_mesh(mesh)
    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
    args, closed_over_args = partition_list(no_sharding, args)
    in_specs, _ = partition_list(no_sharding, in_specs)
    def f(x, *closed_over_args):
      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
      def g(*args):
        args = [x * arg for arg in args]
        args = merge_lists(no_sharding, args, closed_over_args)
        return fun(*args)
      if jit:
        g = jax.jit(g)
      return g(*args)
    jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value,
             partial(sample_shmap_batched, 5)))
  def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)

    f = shard_map(fun, mesh, in_specs, out_specs)
    if jit:
      f = jax.jit(f)
    ans = jax.vmap(f, bdims)(*args)

    args_slice = args_slicer(args, bdims)
    expected_slices = [f(*args_slice(i)) for i in range(5)]
    treedef = jax.tree.structure(ans)
    if tree_util.treedef_is_strict_leaf(treedef):
      expected = jnp.stack(expected_slices)
    else:
      slices = map(jnp.stack, zip(*expected_slices))
      expected = jax.tree.unflatten(treedef, slices)
    tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
    self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

  @parameterized.parameters(
      sample(jtu.NUM_GENERATED_CASES.value,
             partial(sample_shmap_batched, 5)))
  def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)

    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
    args, closed_over_args = partition_list(no_sharding, args)
    in_specs, _ = partition_list(no_sharding, in_specs)
    explicit_bdims, closed_over_bdims = partition_list(no_sharding, bdims)

    def f(x, *closed_over_args):
      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
      def g(*args):
        args = [x * arg for arg in args]
        args = merge_lists(no_sharding, args, closed_over_args)
        return fun(*args)
      if jit:
        g = jax.jit(g)
      if any(d is not None for d in explicit_bdims):
        return jax.vmap(g, explicit_bdims)(*args)
      else:
        return g(*args)

    xs = jnp.arange(5., dtype='float32')
    ans = jax.vmap(f, (0, *closed_over_bdims))(xs, *closed_over_args)

    args_slice = args_slicer((xs, *closed_over_args), (0, *closed_over_bdims))
    expected_slices = [f(*args_slice(i)) for i in range(5)]
    treedef = jax.tree.structure(ans)
    if tree_util.treedef_is_strict_leaf(treedef):
      expected = jnp.stack(expected_slices)
    else:
      slices = map(jnp.stack, zip(*expected_slices))
      expected = jax.tree.unflatten(treedef, slices)
    tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
    self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)


@jtu.pytest_mark_if_available('multiaccelerator')
class CustomPartitionerTest(jtu.JaxTestCase):

  def skip_if_custom_partitioning_not_supported(self):
    if jtu.is_cloud_tpu():
      raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")

  def test_custom_partitioning(self):
    self.skip_if_custom_partitioning_not_supported()

    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.addressable_data(0).shape == (4, 2)

    def partition(mesh, arg_shapes, result_shape):
      def lower_fn(x):
        return x

      return (
          mesh,
          lower_fn,
          arg_shapes[0].sharding,
          (arg_shapes[0].sharding,),
      )

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      return arg_shapes[0].sharding

    def propagate_user_sharding(mesh, user_shape):
      return user_shape.sharding

    @custom_partitioning
    def f(x):
      return x

    f.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition,
        propagate_user_sharding=propagate_user_sharding,
    )

    @jax.jit
    def fwd(a):
      c = shard_map(
          f,
          mesh,
          check_rep=False,
          in_specs=(P('z', ('x', 'y')),),
          out_specs=P('z', ('x', 'y')))(a)
      return c

    c = fwd(a)
    self.assertEqual(c.addressable_data(0).shape, (4, 2))

  def test_partially_sharded_dim_with_auto(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))

    def g(x):
      return jnp.sum(x)[None]

    @jax.jit
    def f(x):
      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(('i', 'j'))))
      re = shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'),
                     check_rep=False, auto=frozenset({'j'}))(x)
      re = jax.lax.with_sharding_constraint(re, NamedSharding(mesh, P(('i', 'j'))))
      return re

    self.assertAllClose(f(jnp.arange(8.)), jnp.array([1., 5., 9., 13.]))


@jtu.with_config(jax_use_shardy_partitioner=True)
# TODO(phawkins): enable this test unconditionally once shardy is the default.
@unittest.skipIf(sdy is None, "shardy is not enabled")
class SdyIntegrationTest(jtu.JaxTestCase):

  # Verify we can lower to a `ManualComputationOp`.
  def test_shardy_collective_permute(self):
    mesh = jtu.create_mesh((2,), ('x',))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
    )
    def fwd(a):
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text())


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())