mirror of https://github.com/ROCm/jax.git

commit 9c4db8c962

Merge pull request #14633 from mattjj:shmap-test-vmap

PiperOrigin-RevId: 515117185
@@ -1053,6 +1053,11 @@ jax_test(
 jax_test(
     name = "shard_map_test",
     srcs = ["shard_map_test.py"],
+    shard_count = {
+        "cpu": 10,
+        "gpu": 10,
+        "tpu": 10,
+    },
 )

 jax_test(
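The added shard_count splits shard_map_test into 10 parallel shards on each backend. For reference (not part of this change), Bazel tells each shard which slice to run via the TEST_TOTAL_SHARDS and TEST_SHARD_INDEX environment variables; a minimal sketch of shard selection under that contract, with a hypothetical my_shard helper:

import os

def my_shard(cases):
  # Hypothetical helper, for illustration only: keep this shard's slice.
  total = int(os.environ.get("TEST_TOTAL_SHARDS", "1"))
  index = int(os.environ.get("TEST_SHARD_INDEX", "0"))
  return [c for i, c in enumerate(cases) if i % total == index]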
@@ -15,6 +15,7 @@
 from functools import partial
 import itertools as it
 import math
+import operator as op
 import os
 from types import SimpleNamespace
 from typing import (Any, Sequence, Set, Iterable, Iterator, NamedTuple,

@@ -34,6 +35,7 @@ from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
+from jax._src import tree_util
 import jax.numpy as jnp

 from jax.experimental.shard_map import shard_map
@@ -598,8 +600,9 @@ def sample_shmap() -> Chooser:
   ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
   in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
                           for t in in_types) + ')'
-  name = f'{spec.name}_{mesh.shape}_{in_specs}_{out_specs}_{in_str}'
-  return name, spec.fun, mesh.shape, in_specs, out_specs, args, ref
+  jit = yield [True, False]
+  name = f'{spec.name}_{mesh.shape}_jit={jit}_{in_specs}_{out_specs}_{in_str}'
+  return name, spec.fun, mesh.shape, jit, in_specs, out_specs, args, ref

 def unmentioned(mesh: Mesh, pspec: P) -> Set[core.AxisName]:
   return set(mesh.axis_names) - {n for ns in pspec if ns is not None
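The new `jit = yield [True, False]` line threads one more sampled decision through the chooser protocol these tests use: the generator yields a sequence of options at each decision point, receives the harness's pick back from the yield, and its return value is the finished test case. A minimal sketch of a driver for that protocol, assuming only what the visible code shows (the file's real `sample` helper also handles case naming and counts):

import random
from typing import Any, Generator

Chooser = Generator[Any, Any, Any]  # yields option sequences, receives picks

def drive(make_chooser, rng: random.Random):
  # Illustrative driver: answer every yield with one random choice and
  # capture the generator's return value as the sampled test case.
  gen = make_chooser()
  try:
    options = next(gen)
    while True:
      options = gen.send(rng.choice(list(options)))
  except StopIteration as e:
    return e.value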
@@ -689,6 +692,38 @@ def powerset(s: Iterable[T]) -> Iterator[Sequence[T]]:
   s = list(s)
   return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))

+# Vmap test helpers
+
+Arr = Any
+
+def sample_shmap_batched(bdim_size: int) -> Chooser:
+  name, *shmap_specs, args, ref = yield from sample_shmap()
+  bdims = yield all_bdims(*map(op.attrgetter('shape'), args))
+  batch_args = map(partial(batchify_arg, bdim_size), bdims, args)
+  return name + f'_vmap_{bdims}', bdims, *shmap_specs, batch_args, ref
+
+def all_bdims(*shapes: Tuple[int, ...]
+              ) -> Iterator[Sequence[Optional[int]]]:
+  bdims = ((None, *range(len(shape) + 1)) for shape in shapes)
+  return (t for t in it.product(*bdims) if not all(e is None for e in t))
+
+def batchify_arg(size: int, bdim: Optional[int], x: Arr) -> Arr:
+  if bdim is None:
+    return x
+  else:
+    iota = np.arange(1, size + 1, dtype=x.dtype).reshape(
+        [1 if i != bdim else -1 for i in range(len(x.shape) + 1)])
+    return np.expand_dims(x, bdim) * iota
+
+def args_slicer(args: Sequence[Arr], bdims: Sequence[Optional[int]]
+                ) -> Callable[[int], Sequence[Arr]]:
+  def slicer(x, bdim):
+    if bdim is None:
+      return lambda _: x
+    else:
+      return lambda i: x.take(indices=i, axis=bdim)
+  slicers = map(slicer, args, bdims)
+  return lambda i: [sl(i) for sl in slicers]


 class ShardMapSystematicTest(jtu.JaxTestCase):
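Assuming the helper definitions above are in scope, a small illustration of what they produce (shapes and values chosen arbitrarily):

import numpy as np

# all_bdims enumerates every way to place a batch axis in each argument,
# skipping the all-None combination (vmap needs at least one mapped axis
# unless an explicit axis_size is given):
list(all_bdims((3,), (2, 2)))
# -> [(None, 0), (None, 1), (None, 2), (0, None), (0, 0), ...]

# batchify_arg inserts a batch axis of the given size and scales the
# slices by 1..size so they are distinguishable:
x = np.ones((2,), np.float32)
batchify_arg(3, 0, x)     # shape (3, 2): rows are 1*x, 2*x, 3*x
batchify_arg(3, None, x)  # returned unchanged: no batch axis requested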
@@ -699,16 +734,18 @@ class ShardMapSystematicTest(jtu.JaxTestCase):

   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
-  def test_eager_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
+  def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
     mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
     out = shard_map(fun, mesh, in_specs, out_specs)(*args)
     expected = ref(fun, mesh, in_specs, out_specs)(*args)
     self.assertAllClose(expected, out, check_dtypes=False)

   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
-  def test_jit_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
+  def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
     mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
     out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args)
     expected = ref(fun, mesh, in_specs, out_specs)(*args)
     self.assertAllClose(expected, out, check_dtypes=False)
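For context, the API these tests exercise: shard_map applies a function to per-device shards of its inputs, as described by a mesh and PartitionSpecs. A minimal self-contained sketch (assumes at least two devices, e.g. CPU with --xla_force_host_platform_device_count=2):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:2]), axis_names=('x',))

def body(block):  # runs once per shard; each device sees a (1, 4) block
  return block * 2

f = shard_map(body, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
x = jnp.arange(8.).reshape(2, 4)
print(f(x))  # equals 2 * x, computed shard-by-shard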
@@ -716,31 +753,98 @@
   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
+  @jax.default_matmul_precision("float32")
-  def test_grads(self, fun, mesh, in_specs, out_specs, args, _):
-    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
+  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
+    if xla_bridge.xla_client._version < 134:
+      raise unittest.SkipTest("requires later jaxlib version")
     mesh = self.make_mesh(mesh)
-    f = jax.jit(shard_map(fun, mesh, in_specs, out_specs))
     args = map(jnp.array, args)
+    f = shard_map(fun, mesh, in_specs, out_specs)
+    if jit:
+      f = jax.jit(f)
     jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
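jtu.check_grads compares JAX's automatic derivatives against numerical differentiation up to the requested order. A simplified scalar sketch of that comparison (first order only; the real utility checks JVPs and VJPs in both forward and reverse mode, with tolerances per order):

import jax

def naive_check_grad(f, x, eps=1e-3, tol=1e-2):
  ad = jax.grad(f)(x)                         # autodiff derivative
  fd = (f(x + eps) - f(x - eps)) / (2 * eps)  # central finite difference
  assert abs(ad - fd) <= tol * max(1.0, abs(fd)), (ad, fd)

naive_check_grad(lambda x: x ** 3.0, 0.7)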

   @parameterized.named_parameters(
       sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
+  @jax.default_matmul_precision("float32")
-  def test_grads_closure(self, fun, mesh, in_specs, out_specs, args, _):
-    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
+  def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
+    if xla_bridge.xla_client._version < 134:
+      raise unittest.SkipTest("requires later jaxlib version")
     mesh = self.make_mesh(mesh)
     no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
     args, closed_over_args = partition_list(no_sharding, args)
     in_specs, _ = partition_list(no_sharding, in_specs)
     def f(x, *closed_over_args):
-      @jax.jit
       @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
       def g(*args):
         args = [x * arg for arg in args]
         args = merge_lists(no_sharding, args, closed_over_args)
         return fun(*args)
+      if jit:
+        g = jax.jit(g)
       return g(*args)
     jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
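test_grads_closure differentiates through values that the shard_map'd function closes over rather than receives as sharded arguments. The partition_list and merge_lists helpers from jax._src.util split a list by a boolean mask and stitch it back together; minimal re-implementations matching their usage here (a sketch inferred from that usage, not the actual library source):

def partition_list(mask, xs):
  # Split xs into (items where mask is False, items where mask is True).
  return ([x for m, x in zip(mask, xs) if not m],
          [x for m, x in zip(mask, xs) if m])

def merge_lists(mask, false_items, true_items):
  # Inverse of partition_list: re-interleave according to the mask.
  fs, ts = iter(false_items), iter(true_items)
  return [next(ts) if m else next(fs) for m in mask]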

+  @parameterized.named_parameters(
+      sample(config.FLAGS.jax_num_generated_cases,
+             partial(sample_shmap_batched, 5)))
+  def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
+    mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
+
+    f = shard_map(fun, mesh, in_specs, out_specs)
+    if jit:
+      f = jax.jit(f)
+    ans = jax.vmap(f, bdims)(*args)
+
+    args_slice = args_slicer(args, bdims)
+    expected_slices = [f(*args_slice(i)) for i in range(5)]
+    treedef = tree_util.tree_structure(ans)
+    if tree_util.treedef_is_strict_leaf(treedef):
+      expected = jnp.stack(expected_slices)
+    else:
+      slices = map(jnp.stack, zip(*expected_slices))
+      expected = tree_util.tree_unflatten(treedef, slices)
+    self.assertAllClose(ans, expected, check_dtypes=False)
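The property test_vmap checks, in miniature: vmapping a function must agree with applying it to each batch slice and stacking the results. A tiny standalone version:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * 2.0
xs = jnp.arange(6.).reshape(3, 2)   # a batch of 3 slices

ans = jax.vmap(f, in_axes=0)(xs)
expected = jnp.stack([f(xs[i]) for i in range(3)])
assert jnp.allclose(ans, expected)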
+
+  @parameterized.named_parameters(
+      sample(config.FLAGS.jax_num_generated_cases,
+             partial(sample_shmap_batched, 5)))
+  def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
+    raise unittest.SkipTest("need BatchTrace.post_process_shard_map")  # TODO
+    mesh = self.make_mesh(mesh)
+    args = map(jnp.array, args)
+
+    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
+    args, closed_over_args = partition_list(no_sharding, args)
+    in_specs, _ = partition_list(no_sharding, in_specs)
+    explicit_bdims, closed_over_bdims = partition_list(no_sharding, bdims)
+
+    def f(x, *closed_over_args):
+      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
+      def g(*args):
+        args = [x * arg for arg in args]
+        args = merge_lists(no_sharding, args, closed_over_args)
+        return fun(*args)
+      if jit:
+        g = jax.jit(g)
+      if any(d is not None for d in explicit_bdims):
+        return jax.vmap(g, explicit_bdims)(*args)
+      else:
+        return g(*args)
+
+    xs = jnp.arange(5., dtype='float32')
+    ans = jax.vmap(f, (0, *closed_over_bdims))(xs, *closed_over_args)
+
+    args_slice = args_slicer((xs, *closed_over_args), (0, *closed_over_bdims))
+    expected_slices = [f(*args_slice(i)) for i in range(5)]
+    treedef = tree_util.tree_structure(ans)
+    if tree_util.treedef_is_strict_leaf(treedef):
+      expected = jnp.stack(expected_slices)
+    else:
+      slices = map(jnp.stack, zip(*expected_slices))
+      expected = tree_util.tree_unflatten(treedef, slices)
+    self.assertAllClose(ans, expected, check_dtypes=False)


 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
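When the mapped function returns a pytree rather than a single array, the expected value above is assembled leaf-by-leaf: flatten each per-slice output, stack corresponding leaves, and rebuild the structure. A standalone sketch using the public jax.tree_util API:

import jax.numpy as jnp
from jax import tree_util

slices = [{'a': jnp.ones(2) * i, 'b': jnp.zeros(3)} for i in range(5)]
treedef = tree_util.tree_structure(slices[0])
leaves = [tree_util.tree_leaves(s) for s in slices]
stacked = [jnp.stack(ls) for ls in zip(*leaves)]
expected = tree_util.tree_unflatten(treedef, stacked)
print(tree_util.tree_map(lambda x: x.shape, expected))  # {'a': (5, 2), 'b': (5, 3)}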