Merge pull request #14513 from mattjj:shmap-test

PiperOrigin-RevId: 510330159
commit 8962d2f701
Author: jax authors
Date:   2023-02-16 21:21:20 -08:00
2 changed files with 236 additions and 6 deletions

jax/experimental/shard_map.py

@@ -313,7 +313,6 @@ class ShardMapPrimitive(core.Primitive):
return map(core.full_lower, core.apply_todos(todos, outs))
def get_bind_params(self, params):
"""Goes from jaxpr form to python traceable form."""
new_params = dict(params)
jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
@@ -522,7 +521,7 @@ def _unmatch(mesh, src_tup, x):
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
) -> None:
-  fail = [a if not max(n, default=0) < a.ndim else no_fail
+  fail = [a if n and not max(n) < a.ndim else no_fail
for n, a in zip(names, avals)]
if any(f is not no_fail for f in fail): raise _SpecError(fail)
class _SpecError(Exception): pass
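The one-line change above fixes _check_names for arguments with no sharded dimensions: names here maps positional array dimensions to mesh axis names, and with an empty dict the old predicate evaluates max(n, default=0) < a.ndim as 0 < 0 for a rank-0 aval, spuriously raising _SpecError for unsharded scalars. A minimal sketch of the two predicates, with illustrative values:

    names, ndim = {}, 0                               # fully replicated scalar
    old_fail = not max(names, default=0) < ndim       # True: wrongly flagged
    new_fail = bool(names) and not max(names) < ndim  # False: accepted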
@@ -654,7 +653,6 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
custom_derivatives.__dict__.values()):
if isinstance(o, core.Primitive): register_standard(o)
register_standard(xla.xla_call_p)
register_standard(lax_parallel.ppermute_p) # doesn't change replication
@register_rule(lax_parallel.psum_p)
@@ -693,6 +691,10 @@ def _axis_index_rule(mesh, *, axis_name):
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
@register_rule(xla.xla_call_p)
def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr, in_rep)
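For orientation: a replication rule receives, per input, the set of mesh axes over which that input is known to be replicated, and returns the same for each output; the new _jit_rule, like _pjit_rule above it, simply defers to _output_rep on the call's jaxpr. A hedged sketch of the rule shape for a hypothetical primitive (my_prim_p is a stand-in name, not a real JAX primitive):

    @register_rule(my_prim_p)
    def _my_prim_rule(mesh, *in_rep, **params):
      # a typical elementwise-style rule: an output is replicated over a
      # mesh axis only if every input is
      return set.intersection(*in_rep)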
# Batching
def _shard_map_batch(
@@ -754,7 +756,7 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
result = shard_map_p.bind(f_jvp, *args, **params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t
-                  for p, t in zip(primal_out, tangent_out)]
+                 for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
@@ -783,7 +785,6 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits(
f, (*in_knowns,), (*in_avals_sharded,))
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
@as_hashable_function(closure=out_names_thunk)
def known_out_names():

tests/shard_map_test.py

@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools as it
import os
from types import SimpleNamespace
from typing import NamedTuple, Callable, Optional
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
@@ -26,12 +30,16 @@ from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge
from jax._src.util import safe_zip, safe_map, prod, partition_list, merge_lists
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
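As in other JAX internals, the map and zip builtins are shadowed here by the length-checking versions from jax._src.util, with the originals kept reachable as unsafe_map and unsafe_zip. For example:

    zip([1, 2], [3, 4])   # [(1, 3), (2, 4)]
    zip([1, 2], [3])      # raises an error: lengths must match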
# Helper for some tests.
def create_inputs(a_sharding, b_sharding):
x, y, z = 2, 2, 2 # pylint: disable=invalid-name
@@ -64,7 +72,7 @@ def setUpModule():
if len(jax.devices()) < 8:
raise unittest.SkipTest("tests require 8 devices")
if not jax.config.jax_array:
raise unittest.SkipTest("test requires jax_array")
raise unittest.SkipTest("tests require jax_array")
# Reset to the previous configuration in case other test modules are run.
def tearDownModule():
@@ -474,5 +482,226 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
class CaseSpec(NamedTuple):
name: str
num_inputs: int
fun: Callable
out_rep: Callable
valid_types: Optional[Callable] = None
fun_specs = [
CaseSpec('id', 1, lambda x: x, lambda r: r),
CaseSpec('flip', 2, lambda x, y: (y, x), lambda r_x, r_y: (r_y, r_x)),
CaseSpec('transpose', 1, lambda x: x.T, lambda r: r),
CaseSpec('ravel', 1, lambda x: x.ravel(), lambda r: r),
CaseSpec('dot', 2, jnp.dot,
lambda r1, r2: r1 & r2,
lambda x1, x2: (x1.shape and x2.shape and
x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0]),
),
CaseSpec('sin_dot_sin', 2,
lambda x1, x2: jnp.sin(jnp.dot(jnp.sin(x1), x2)),
lambda r1, r2: r1 & r2,
lambda x1, x2: (x1.shape and x2.shape and
x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0]),
),
]
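Each CaseSpec pairs a function under test with an out_rep rule mapping input replication (the set of mesh axes over which each input is replicated) to output replication, plus an optional predicate restricting valid input shapes. For dot, the output is replicated over exactly the axes both operands are replicated over; with illustrative sets:

    r1, r2 = {'i', 'j'}, {'j'}
    r1 & r2                      # {'j'}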
input_shapes = [
jax.ShapeDtypeStruct(shape, jnp.dtype('float32'))
# TODO(mattjj): 0 axis sizes lead to XLA sigfpe, file bug!
for k in range(1, 4) for shape in it.permutations(range(1, 4), k)
if not shape or len(set(shape)) > 1 # skip all-equal shapes, boring!
]
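Note that it.permutations already yields shapes with distinct entries, so the len(set(shape)) > 1 filter in effect just drops the rank-1 shapes, leaving the twelve length-2 and length-3 arrangements of {1, 2, 3}:

    [s for k in range(1, 4) for s in it.permutations(range(1, 4), k)
     if not s or len(set(s)) > 1][:4]
    # [(1, 2), (1, 3), (2, 1), (2, 3)]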
mesh_shapes = [
(1,),
(1, 1),
(1, 2),
(2, 2),
(2, 4),
(4, 2),
]
def make_in_specs(mesh, in_types):
pairs = []
for ty in in_types:
pair = yield from make_in_spec(mesh, ty)
pairs.append(pair)
return list(zip(*pairs))
def make_in_spec(mesh, in_type_base):
assert len(list(powerset(mesh.shape)))
subset = yield powerset(mesh.shape)
elts = yield partitions(subset, len(in_type_base.shape))
partition_spec = P(*(tuple(e) if e else None for e in elts))
new_type = dilate(mesh, partition_spec, in_type_base)
return new_type, partition_spec
def dilate(mesh, spec, shape):
new_shape = tuple(d * prod(mesh.shape[ax] for ax in (elt or ()))
for d, elt in zip(shape.shape, spec))
return jax.ShapeDtypeStruct(new_shape, shape.dtype)
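dilate turns a per-shard type into the corresponding global type by scaling each dimension by the product of the mesh axis sizes its spec entry names. With illustrative values:

    mesh = SimpleNamespace(shape=dict(i=2, j=2), axis_names=('i', 'j'))
    ty = jax.ShapeDtypeStruct((3, 5), jnp.dtype('float32'))
    dilate(mesh, P(('i', 'j'), None), ty).shape   # (12, 5): 3 * (2*2) rows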
def make_out_specs(mesh, out_types, out_reps):
if type(out_types) is not tuple:
out_spec = yield from make_out_spec(mesh, out_types, out_reps)
return out_spec
else:
out_specs = []
for ty, rep in zip(out_types, out_reps):
out_spec = yield from make_out_spec(mesh, ty, rep)
out_specs.append(out_spec)
return tuple(out_specs)
def make_out_spec(mesh, out_type, out_rep):
subset = yield (s for s in powerset(mesh.shape)
if out_rep | set(s) == set(mesh.shape))
elts = yield partitions(subset, len(out_type.shape))
return P(*(tuple(e) if e else None for e in elts))
def partitions(s, k):
for indices in it.product(range(k), repeat=len(s)):
outs = [[] for _ in range(k)]
for i, elt in zip(indices, s):
outs[i].append(elt)
yield outs
def powerset(s):
s = list(s)
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
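partitions(s, k) enumerates every way to deal the elements of s into k ordered buckets, and powerset is the standard itertools recipe. For example:

    list(powerset(('i', 'j')))
    # [(), ('i',), ('j',), ('i', 'j')]
    list(partitions(('i', 'j'), 2))
    # [[['i', 'j'], []], [['i'], ['j']], [['j'], ['i']], [[], ['i', 'j']]]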
def unmentioned(mesh, pspec):
return set(mesh.axis_names) - {n for ns in pspec if ns is not None
for n in (ns if type(ns) is tuple else [ns])}
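unmentioned returns the mesh axes a PartitionSpec leaves out, i.e. the axes over which the corresponding argument is replicated. Reusing the illustrative mesh from the dilate example:

    unmentioned(mesh, P(('i',), None))   # {'j'}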
def shmap_reference(body_in_types, body_out_types, out_types,
f, mesh, in_specs, out_specs):
def f_shmapped(*args):
outs = jax.tree_map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
getters = [make_indexer(mesh, s, x) for s, x in zip(in_specs, args)]
putters = jax.tree_map(partial(make_indexer, mesh), out_specs, outs)
for idx in it.product(*map(range, mesh.shape.values())):
args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
out_shards = f(*args_shards)
assert jax.tree_util.tree_all(jax.tree_map(lambda y, r: y.shape == r.shape,
out_shards, body_out_types))
outs = jax.tree_map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
out_shards, outs, putters)
return outs
return f_shmapped
def make_indexer(mesh, spec, x):
block_shape = [d // prod(mesh.shape[ax] for ax in (elt or ()))
for d, elt in zip(x.shape, spec)]
def indexer(idx):
starts = [0 if el is None else
idx[list(mesh.shape).index(el)] if type(el) is not tuple else
sum(idx[list(mesh.shape).index(el[i])]
* prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
for el in spec]
return tuple(slice(start * size, (start + 1) * size)
for start, size in zip(starts, block_shape))
return indexer
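make_indexer maps a full mesh index to per-dimension slices selecting one shard, with tuples of axis names indexed major-to-minor. A worked example with the same illustrative mesh:

    idxr = make_indexer(mesh, P(('i',), None), np.zeros((4, 2)))
    idxr((1, 0))   # (slice(2, 4), slice(0, 2)): the 'i'=1 row block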
def sample_shmap():
spec = yield fun_specs
mesh_shape = yield mesh_shapes
axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)]
mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)),
axis_names=axis_names)
in_types = (tys for tys in it.product(input_shapes, repeat=spec.num_inputs)
if not spec.valid_types or spec.valid_types(*tys))
body_in_types = yield in_types
body_out_types = jax.eval_shape(spec.fun, *body_in_types)
in_types, in_specs = yield from make_in_specs(mesh, body_in_types)
args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size
for ty in in_types]
out_reps = spec.out_rep(*map(partial(unmentioned, mesh), in_specs))
out_specs = yield from make_out_specs(mesh, body_out_types, out_reps)
out_types = jax.tree_map(partial(dilate, mesh), out_specs, body_out_types)
ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
for t in in_types) + ')'
name = f'{spec.name}_{mesh.shape}_{in_specs}_{out_specs}_{in_str}'
return name, spec.fun, mesh.shape, in_specs, out_specs, args, ref
def sample(num, make_gen):
rng = np.random.RandomState(0)
seen = set()
while len(seen) < num:
name, *case = sample_one(rng, make_gen())
if name not in seen:
seen.add(name)
yield name, *case
def sample_one(rng, gen):
lst = list(next(gen))
try:
while True:
choice = lst[rng.randint(len(lst))]
lst = list(gen.send(choice))
except StopIteration as e:
return e.value
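sample_shmap and the make_* helpers above are written as generators that yield an iterable of options and receive the driver's choice back via send; sample_one drives one such generator to exhaustion and returns its return value. A toy generator (not from this commit) showing the protocol:

    def toy_case():
      color = yield ['red', 'green']
      size = yield range(3)
      return f'{color}-{size}'

    sample_one(np.random.RandomState(0), toy_case())   # e.g. 'green-1'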
class ShardMapSystematicTest(jtu.JaxTestCase):
@staticmethod
def make_mesh(mesh_shape):
shape, axis_names = tuple(mesh_shape.values()), tuple(mesh_shape)
if len(jax.devices()) < prod(shape):
raise unittest.SkipTest("too few devices for test")
m = Mesh(np.array(jax.devices()[:prod(shape)]).reshape(shape), axis_names)
return m
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
def test_eager_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
out = shard_map(fun, mesh, in_specs, out_specs)(*args)
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
def test_jit_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args)
expected = ref(fun, mesh, in_specs, out_specs)(*args)
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads(self, fun, mesh, in_specs, out_specs, args, _):
raise unittest.SkipTest("internal xla failures") # TODO(b/269660532)
mesh = self.make_mesh(mesh)
f = jax.jit(shard_map(fun, mesh, in_specs, out_specs))
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, in_specs, out_specs, args, _):
raise unittest.SkipTest("internal xla failures") # TODO(b/269660532)
mesh = self.make_mesh(mesh)
no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
args, closed_over_args = partition_list(no_sharding, args)
in_specs, _ = partition_list(no_sharding, in_specs)
def f(x, *closed_over_args):
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
def g(*args):
args = [x * arg for arg in args]
args = merge_lists(no_sharding, args, closed_over_args)
return fun(*args)
return g(*args)
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())