commit 9c4db8c962
Author: jax authors
Date:   2023-03-08 12:56:54 -08:00

    Merge pull request #14633 from mattjj:shmap-test-vmap

    PiperOrigin-RevId: 515117185

2 changed files with 119 additions and 10 deletions

--- a/tests/BUILD
+++ b/tests/BUILD

@@ -1053,6 +1053,11 @@ jax_test(
 jax_test(
     name = "shard_map_test",
     srcs = ["shard_map_test.py"],
+    shard_count = {
+        "cpu": 10,
+        "gpu": 10,
+        "tpu": 10,
+    },
 )
 jax_test(

--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py

@@ -15,6 +15,7 @@
 from functools import partial
 import itertools as it
 import math
+import operator as op
 import os
 from types import SimpleNamespace
 from typing import (Any, Sequence, Set, Iterable, Iterator, NamedTuple,
@@ -34,6 +35,7 @@ from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
+from jax._src import tree_util
 import jax.numpy as jnp
 from jax.experimental.shard_map import shard_map
@@ -598,8 +600,9 @@ def sample_shmap() -> Chooser:
   ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
   in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
                           for t in in_types) + ')'
-  name = f'{spec.name}_{mesh.shape}_{in_specs}_{out_specs}_{in_str}'
-  return name, spec.fun, mesh.shape, in_specs, out_specs, args, ref
+  jit = yield [True, False]
+  name = f'{spec.name}_{mesh.shape}_jit={jit}_{in_specs}_{out_specs}_{in_str}'
+  return name, spec.fun, mesh.shape, jit, in_specs, out_specs, args, ref

 def unmentioned(mesh: Mesh, pspec: P) -> Set[core.AxisName]:
   return set(mesh.axis_names) - {n for ns in pspec if ns is not None
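
Note on the new `jit = yield [True, False]` line: it uses the file's generator-based sampling protocol, in which a chooser yields a list of options and a driver (the `sample` helper referenced in the test decorators below, defined elsewhere in this file) sends back its pick. A minimal, hypothetical standalone driver illustrating that protocol:

# Hypothetical driver sketch; the real `sample` helper also enumerates and
# names test cases, which is omitted here.
import random

def drive(chooser, rng=random.Random(0)):
  try:
    options = next(chooser)                    # run to the first yield
    while True:
      options = chooser.send(rng.choice(list(options)))  # reply with a pick
  except StopIteration as e:
    return e.value                             # the chooser's return value

def toy_chooser():
  jit = yield [True, False]
  size = yield [4, 8]
  return f'jit={jit}_size={size}'

print(drive(toy_chooser()))  # e.g. jit=True_size=8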
@@ -689,6 +692,38 @@ def powerset(s: Iterable[T]) -> Iterator[Sequence[T]]:
   s = list(s)
   return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))

+# Vmap test helpers
+Arr = Any
+
+def sample_shmap_batched(bdim_size: int) -> Chooser:
+  name, *shmap_specs, args, ref = yield from sample_shmap()
+  bdims = yield all_bdims(*map(op.attrgetter('shape'), args))
+  batch_args = map(partial(batchify_arg, bdim_size), bdims, args)
+  return name + f'_vmap_{bdims}', bdims, *shmap_specs, batch_args, ref
+
+def all_bdims(*shapes: Tuple[int, ...]
+              ) -> Iterator[Sequence[Optional[int]]]:
+  bdims = ((None, *range(len(shape) + 1)) for shape in shapes)
+  return (t for t in it.product(*bdims) if not all(e is None for e in t))
+
+def batchify_arg(size: int, bdim: Optional[int], x: Arr) -> Arr:
+  if bdim is None:
+    return x
+  else:
+    iota = np.arange(1, size + 1, dtype=x.dtype).reshape(
+        [1 if i != bdim else -1 for i in range(len(x.shape) + 1)])
+    return np.expand_dims(x, bdim) * iota
+
+def args_slicer(args: Sequence[Arr], bdims: Sequence[Optional[int]]
+                ) -> Callable[[int], Sequence[Arr]]:
+  def slicer(x, bdim):
+    if bdim is None:
+      return lambda _: x
+    else:
+      return lambda i: x.take(indices=i, axis=bdim)
+  slicers = map(slicer, args, bdims)
+  return lambda i: [sl(i) for sl in slicers]
+
+
 class ShardMapSystematicTest(jtu.JaxTestCase):
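
A worked example of the three helpers above (hypothetical values; assumes the definitions in this hunk are in scope; the tests' reuse of `map(jnp.array, args)` relies on the JAX test convention of rebinding `map` to the list-returning `safe_map` imported at the top of the file):

import numpy as np

# all_bdims: one candidate batch-axis position (or None) per argument shape,
# excluding the case where no argument is batched at all.
bdims = list(all_bdims((2,), (2, 3)))
assert (None, None) not in bdims
assert len(bdims) == 3 * 4 - 1  # (None,0,1) x (None,0,1,2), minus all-None

# batchify_arg: insert a batch axis of length `size` at `bdim`, scaling
# slice i by i+1 so that different batch indices hold different values.
x = np.ones((2,), dtype=np.float32)
bx = batchify_arg(3, 0, x)
assert bx.shape == (3, 2) and (bx[1] == 2.).all()

# args_slicer: maps a batch index to the corresponding slice of each
# argument; unbatched arguments (bdim=None) pass through unchanged.
slicer = args_slicer((bx, x), (0, None))
s1 = slicer(1)
assert (s1[0] == bx[1]).all() and s1[1] is x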
@@ -699,16 +734,18 @@ class ShardMapSystematicTest(jtu.JaxTestCase):

   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
-  def test_eager_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
+  def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
     mesh = self.make_mesh(mesh)
     args = map(jnp.array, args)
     out = shard_map(fun, mesh, in_specs, out_specs)(*args)
     expected = ref(fun, mesh, in_specs, out_specs)(*args)
     self.assertAllClose(expected, out, check_dtypes=False)

   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
-  def test_jit_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
+  def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
     mesh = self.make_mesh(mesh)
     args = map(jnp.array, args)
     out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args)
     expected = ref(fun, mesh, in_specs, out_specs)(*args)
     self.assertAllClose(expected, out, check_dtypes=False)
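
Both tests check the same invariant: shard_map output must match the sequential reference, with and without jit. A minimal sketch of the pattern under test (illustrative only, assuming at least one available device; the real meshes, functions, and reference come from the samplers above):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Build a 1D mesh over whatever devices are available (at least one).
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

# shard_map splits inputs per in_specs, applies the function per shard,
# and reassembles outputs per out_specs.
f = shard_map(lambda a: a * 2, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(8.)
assert jnp.allclose(f(x), 2 * x)            # eager path
assert jnp.allclose(jax.jit(f)(x), 2 * x)   # jit path agrees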
@@ -716,31 +753,98 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
   @jax.default_matmul_precision("float32")
-  def test_grads(self, fun, mesh, in_specs, out_specs, args, _):
-    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
+  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
+    if xla_bridge.xla_client._version < 134:
+      raise unittest.SkipTest("requires later jaxlib version")
     mesh = self.make_mesh(mesh)
-    f = jax.jit(shard_map(fun, mesh, in_specs, out_specs))
     args = map(jnp.array, args)
+    f = shard_map(fun, mesh, in_specs, out_specs)
+    if jit:
+      f = jax.jit(f)
     jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
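
jtu.check_grads compares JAX's forward- and reverse-mode derivatives against numerical differencing up to the requested order, so the rewrite above only changes whether the function under test is jitted. A small standalone usage sketch (the function here is arbitrary, not from the diff):

import jax.numpy as jnp
from jax._src import test_util as jtu  # same import as the top of this file

f = lambda x: jnp.sin(x) * x
# Checks 1st- and 2nd-order derivatives against finite differences at the
# given tolerances.
jtu.check_grads(f, (jnp.float32(0.7),), order=2, atol=1e-2, rtol=1e-2)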
   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
   @jax.default_matmul_precision("float32")
-  def test_grads_closure(self, fun, mesh, in_specs, out_specs, args, _):
-    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
+  def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
+    if xla_bridge.xla_client._version < 134:
+      raise unittest.SkipTest("requires later jaxlib version")
     mesh = self.make_mesh(mesh)
     no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
     args, closed_over_args = partition_list(no_sharding, args)
     in_specs, _ = partition_list(no_sharding, in_specs)
     def f(x, *closed_over_args):
-      @jax.jit
       @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
       def g(*args):
         args = [x * arg for arg in args]
         args = merge_lists(no_sharding, args, closed_over_args)
         return fun(*args)
+      if jit:
+        g = jax.jit(g)
       return g(*args)
     jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
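
The closure test moves arguments whose in_specs carry no sharding into the closure and merges them back inside g, using partition_list and merge_lists from jax._src.util. A sketch of their assumed semantics:

from jax._src.util import merge_lists, partition_list

which = [False, True, False]
kept, moved = partition_list(which, ['a', 'b', 'c'])  # split by the mask
assert (kept, moved) == (['a', 'c'], ['b'])
assert merge_lists(which, kept, moved) == ['a', 'b', 'c']  # inverts the split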
+  @parameterized.named_parameters(
+      sample(config.FLAGS.jax_num_generated_cases,
+             partial(sample_shmap_batched, 5)))
+  def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
+    mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
+    f = shard_map(fun, mesh, in_specs, out_specs)
+    if jit:
+      f = jax.jit(f)
+    ans = jax.vmap(f, bdims)(*args)
+
+    args_slice = args_slicer(args, bdims)
+    expected_slices = [f(*args_slice(i)) for i in range(5)]
+    treedef = tree_util.tree_structure(ans)
+    if tree_util.treedef_is_strict_leaf(treedef):
+      expected = jnp.stack(expected_slices)
+    else:
+      slices = map(jnp.stack, zip(*expected_slices))
+      expected = tree_util.tree_unflatten(treedef, slices)
+    self.assertAllClose(ans, expected, check_dtypes=False)
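
The oracle here is the defining property of vmap: batching must agree with applying f to each batch slice and stacking the per-slice outputs leaf-wise. A standalone sketch of that check (using the public jax.tree_util in place of the file's jax._src import):

import jax
import jax.numpy as jnp
from jax import tree_util

f = lambda x, y: (x + y, x * y)            # pytree (tuple) output
xs, ys = jnp.arange(5.), jnp.ones(5)
ans = jax.vmap(f, (0, 0))(xs, ys)

slices = [f(xs[i], ys[i]) for i in range(5)]  # one call per batch index
treedef = tree_util.tree_structure(ans)
stacked = [jnp.stack(leaves)                  # stack leaf-wise across slices
           for leaves in zip(*map(tree_util.tree_leaves, slices))]
expected = tree_util.tree_unflatten(treedef, stacked)
assert all(jnp.allclose(a, e) for a, e in
           zip(tree_util.tree_leaves(ans), tree_util.tree_leaves(expected)))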
+  @parameterized.named_parameters(
+      sample(config.FLAGS.jax_num_generated_cases,
+             partial(sample_shmap_batched, 5)))
+  def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
+    raise unittest.SkipTest("need BatchTrace.post_process_shard_map")  # TODO
+    mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
+
+    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
+    args, closed_over_args = partition_list(no_sharding, args)
+    in_specs, _ = partition_list(no_sharding, in_specs)
+    explicit_bdims, closed_over_bdims = partition_list(no_sharding, bdims)
+
+    def f(x, *closed_over_args):
+      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
+      def g(*args):
+        args = [x * arg for arg in args]
+        args = merge_lists(no_sharding, args, closed_over_args)
+        return fun(*args)
+      if jit:
+        g = jax.jit(g)
+      if any(d is not None for d in explicit_bdims):
+        return jax.vmap(g, explicit_bdims)(*args)
+      else:
+        return g(*args)
+
+    xs = jnp.arange(5., dtype='float32')
+    ans = jax.vmap(f, (0, *closed_over_bdims))(xs, *closed_over_args)
+
+    args_slice = args_slicer((xs, *closed_over_args), (0, *closed_over_bdims))
+    expected_slices = [f(*args_slice(i)) for i in range(5)]
+    treedef = tree_util.tree_structure(ans)
+    if tree_util.treedef_is_strict_leaf(treedef):
+      expected = jnp.stack(expected_slices)
+    else:
+      slices = map(jnp.stack, zip(*expected_slices))
+      expected = tree_util.tree_unflatten(treedef, slices)
+    self.assertAllClose(ans, expected, check_dtypes=False)

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())