mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #14513 from mattjj:shmap-test
PiperOrigin-RevId: 510330159
This commit is contained in:
commit
8962d2f701
@ -313,7 +313,6 @@ class ShardMapPrimitive(core.Primitive):
|
||||
return map(core.full_lower, core.apply_todos(todos, outs))
|
||||
|
||||
def get_bind_params(self, params):
|
||||
"""Goes from jaxpr form to python traceable form."""
|
||||
new_params = dict(params)
|
||||
jaxpr = new_params.pop('jaxpr')
|
||||
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr), jaxpr, ())
|
||||
@ -522,7 +521,7 @@ def _unmatch(mesh, src_tup, x):
|
||||
|
||||
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
                 ) -> None:
  """Raise ``_SpecError`` if any names entry indexes past its aval's rank.

  ``names`` and ``avals`` are walked in lockstep. An entry fails when it is
  non-empty and its largest element is not strictly less than ``aval.ndim``.
  (Presumably each entry maps dimension indices to mesh axis names, so
  ``max(n)`` is the highest dimension mentioned -- confirm against callers.)

  Raises:
    _SpecError: carrying a list parallel to ``avals`` that holds the failing
      aval at each bad position and ``no_fail`` everywhere else.
  """
  # An empty entry mentions no dimension at all, so it is valid for every
  # rank, including rank 0. The previous ``max(n, default=0) < a.ndim`` form
  # rejected that case, because ``0 < 0`` is false.
  fail = [a if n and not max(n) < a.ndim else no_fail
          for n, a in zip(names, avals)]
  if any(f is not no_fail for f in fail): raise _SpecError(fail)
|
||||
class _SpecError(Exception): pass
|
||||
@ -654,7 +653,6 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
|
||||
custom_derivatives.__dict__.values()):
|
||||
if isinstance(o, core.Primitive): register_standard(o)
|
||||
|
||||
register_standard(xla.xla_call_p)
|
||||
register_standard(lax_parallel.ppermute_p) # doesn't change replication
|
||||
|
||||
@register_rule(lax_parallel.psum_p)
|
||||
@ -693,6 +691,10 @@ def _axis_index_rule(mesh, *, axis_name):
|
||||
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
|
||||
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
|
||||
|
||||
@register_rule(xla.xla_call_p)
def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
  """Rule registered for ``xla.xla_call_p``.

  Delegates to ``_output_rep`` on the call's ``jaxpr`` with the incoming
  ``in_rep`` values; all other primitive params are accepted and ignored.
  """
  out_rep = _output_rep(mesh, jaxpr, in_rep)
  return out_rep
|
||||
|
||||
# Batching
|
||||
|
||||
def _shard_map_batch(
|
||||
@ -754,7 +756,7 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
result = shard_map_p.bind(f_jvp, *args, **params)
|
||||
primal_out, tangent_out = tree_unflatten(out_tree(), result)
|
||||
tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t
|
||||
for p, t in zip(primal_out, tangent_out)]
|
||||
for p, t in zip(primal_out, tangent_out)]
|
||||
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
|
||||
ad.JVPTrace.process_shard_map = _shard_map_jvp
|
||||
|
||||
@ -783,7 +785,6 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
f = _promote_scalar_residuals(f)
|
||||
f_known, aux = pe.partial_eval_wrapper_nounits(
|
||||
f, (*in_knowns,), (*in_avals_sharded,))
|
||||
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
|
||||
|
||||
@as_hashable_function(closure=out_names_thunk)
|
||||
def known_out_names():
|
||||
|
@ -12,10 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from typing import NamedTuple, Callable, Optional
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
@ -26,12 +30,16 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.util import safe_zip, safe_map, prod, partition_list, merge_lists
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
# Helper for some tests.
|
||||
def create_inputs(a_sharding, b_sharding):
|
||||
x, y, z = 2, 2, 2 # pylint: disable=invalid-name
|
||||
@ -64,7 +72,7 @@ def setUpModule():
|
||||
if len(jax.devices()) < 8:
|
||||
raise unittest.SkipTest("tests require 8 devices")
|
||||
if not jax.config.jax_array:
|
||||
raise unittest.SkipTest("test requires jax_array")
|
||||
raise unittest.SkipTest("tests require jax_array")
|
||||
|
||||
# Reset to previous configuration in case other test modules will be run.
|
||||
def tearDownModule():
|
||||
@ -474,5 +482,226 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
|
||||
|
||||
|
||||
class CaseSpec(NamedTuple):
  """One function under test plus the metadata the case sampler needs."""
  # Short label; becomes the prefix of the generated test-case name.
  name: str
  # How many array arguments ``fun`` takes.
  num_inputs: int
  # The function handed to shard_map.
  fun: Callable
  # Called with one set per input (the mesh axes that input's spec leaves
  # unmentioned); its result is passed on to the output-spec sampler.
  out_rep: Callable
  # Optional predicate over candidate input types; ``None`` accepts all.
  valid_types: Optional[Callable] = None
|
||||
|
||||
def _contraction_dims_match(x1, x2):
  """Shape predicate shared by the dot-based cases.

  Truthy only when both operands have at least one dimension and the last
  dimension of ``x1`` equals the dimension of ``x2`` it is compared against:
  the second-to-last when ``x2`` has more than one dimension, else its only
  one.
  """
  return (x1.shape and x2.shape and
          x1.shape[-1] == x2.shape[-2 if x2.ndim > 1 else 0])


# Functions exercised by the systematic tests. Each ``out_rep`` receives one
# set per input and returns the corresponding value(s) for the output(s).
fun_specs = [
    CaseSpec(name='id', num_inputs=1,
             fun=lambda x: x,
             out_rep=lambda r: r),
    CaseSpec(name='flip', num_inputs=2,
             fun=lambda x, y: (y, x),
             out_rep=lambda r_x, r_y: (r_y, r_x)),
    CaseSpec(name='transpose', num_inputs=1,
             fun=lambda x: x.T,
             out_rep=lambda r: r),
    CaseSpec(name='ravel', num_inputs=1,
             fun=lambda x: x.ravel(),
             out_rep=lambda r: r),
    CaseSpec(name='dot', num_inputs=2,
             fun=jnp.dot,
             out_rep=lambda r1, r2: r1 & r2,
             valid_types=_contraction_dims_match),
    CaseSpec(name='sin_dot_sin', num_inputs=2,
             fun=lambda x1, x2: jnp.sin(jnp.dot(jnp.sin(x1), x2)),
             out_rep=lambda r1, r2: r1 & r2,
             valid_types=_contraction_dims_match),
]
|
||||
|
||||
# Candidate per-shard input types: float32 arrays whose shapes are orderings
# of distinct sizes drawn from {1, 2, 3}, for ranks 1 through 3.
# TODO(mattjj): 0 axis sizes lead to XLA sigfpe, file bug!
# NOTE(review): permutations never repeat an element, so the filter below
# ("skip all-equal shapes, boring!") ends up discarding exactly the rank-1
# shapes; only rank-2 and rank-3 shapes remain. Confirm this is intended.
input_shapes = [
    jax.ShapeDtypeStruct(dims, jnp.dtype('float32'))
    for dims in it.chain.from_iterable(
        it.permutations(range(1, 4), rank) for rank in range(1, 4))
    if len(dims) == 0 or len(set(dims)) > 1
]
|
||||
|
||||
# Device-mesh shapes to sample from; the largest use 8 devices.
mesh_shapes = [(1,), (1, 1), (1, 2), (2, 2), (2, 4), (4, 2)]
|
||||
|
||||
def make_in_specs(mesh, in_types):
  """Choice-generator that picks an input spec for every body input type.

  Delegates to ``make_in_spec`` once per element of ``in_types`` (forwarding
  its yields to the driver) and returns the collected ``(type, spec)`` pairs
  transposed, i.e. ``[types, specs]``.
  """
  type_spec_pairs = []
  for body_type in in_types:
    type_spec_pairs.append((yield from make_in_spec(mesh, body_type)))
  # Turn [(type, spec), ...] into [(type, ...), (spec, ...)].
  return list(zip(*type_spec_pairs))
|
||||
|
||||
def make_in_spec(mesh, in_type_base):
  """Choice-generator for one input: returns ``(dilated_type, spec)``.

  First yields every subset of the mesh axis names and receives the chosen
  one; then yields every way of distributing that subset over the dimensions
  of ``in_type_base`` and receives the chosen distribution. The resulting
  PartitionSpec has a tuple of axis names for each dimension that received
  any, and ``None`` for the others.
  """
  assert len(list(powerset(mesh.shape)))
  chosen_axes = yield powerset(mesh.shape)
  per_dim_axes = yield partitions(chosen_axes, len(in_type_base.shape))
  entries = [tuple(axes) if axes else None for axes in per_dim_axes]
  pspec = P(*entries)
  return dilate(mesh, pspec, in_type_base), pspec
|
||||
|
||||
def dilate(mesh, spec, shape):
  """Scale each dimension of ``shape`` by the mesh axes its spec entry names.

  Every dimension size is multiplied by the product of ``mesh.shape[ax]``
  over the axes in the matching ``spec`` entry; a falsy entry (e.g. ``None``)
  leaves the dimension unchanged. The dtype is carried over.
  """
  def axes_factor(entry):
    return prod(mesh.shape[ax] for ax in (entry or ()))

  scaled_dims = tuple(dim * axes_factor(entry)
                      for dim, entry in zip(shape.shape, spec))
  return jax.ShapeDtypeStruct(scaled_dims, shape.dtype)
|
||||
|
||||
def make_out_specs(mesh, out_types, out_reps):
  """Choice-generator for output specs, mirroring the output structure.

  When ``out_types`` is exactly a ``tuple``, one spec is chosen per
  ``(type, rep)`` pair and a tuple of specs is returned; otherwise a single
  spec is chosen for the lone output.
  """
  if type(out_types) is tuple:
    chosen = []
    for ty, rep in zip(out_types, out_reps):
      chosen.append((yield from make_out_spec(mesh, ty, rep)))
    return tuple(chosen)
  return (yield from make_out_spec(mesh, out_types, out_reps))
|
||||
|
||||
def make_out_spec(mesh, out_type, out_rep):
  """Choice-generator for a single output's PartitionSpec.

  Only offers axis subsets that, together with ``out_rep``, cover every mesh
  axis name. The chosen subset is then distributed over the dimensions of
  ``out_type``, exactly as for inputs.
  """
  def covers_all_axes(axes):
    return out_rep | set(axes) == set(mesh.shape)

  chosen_axes = yield filter(covers_all_axes, powerset(mesh.shape))
  per_dim_axes = yield partitions(chosen_axes, len(out_type.shape))
  entries = [tuple(axes) if axes else None for axes in per_dim_axes]
  return P(*entries)
|
||||
|
||||
def partitions(s, k):
  """Yield every way to distribute the elements of ``s`` into ``k`` lists.

  Each yielded value is a fresh list of ``k`` lists (some possibly empty);
  within a list, elements keep their order from ``s``.
  """
  for assignment in it.product(range(k), repeat=len(s)):
    yield [[item for slot, item in zip(assignment, s) if slot == target]
           for target in range(k)]
|
||||
|
||||
def powerset(s):
  """Return an iterator over all subsets of ``s`` as tuples, smallest first."""
  items = list(s)
  subsets_by_size = (it.combinations(items, size)
                     for size in range(len(items) + 1))
  return it.chain.from_iterable(subsets_by_size)
|
||||
|
||||
def unmentioned(mesh, pspec):
  """Return the set of mesh axis names that ``pspec`` never refers to.

  Entries of ``pspec`` may be ``None`` (skipped), a tuple of names, or a
  single bare name.
  """
  remaining = set(mesh.axis_names)
  mentioned = set()
  for entry in pspec:
    if entry is None:
      continue
    mentioned.update(entry if type(entry) is tuple else [entry])
  return remaining - mentioned
|
||||
|
||||
|
||||
def shmap_reference(body_in_types, body_out_types, out_types,
                    f, mesh, in_specs, out_specs):
  """Build a sequential reference implementation to compare shard_map against.

  Args:
    body_in_types: expected per-shard type of each input of ``f``.
    body_out_types: expected per-shard type(s) of the output(s) of ``f``.
    out_types: type(s) of the full, assembled output(s).
    f: the per-shard function.
    mesh: object whose ``shape`` maps axis names to sizes.
    in_specs: one spec per input, consumed by ``make_indexer``.
    out_specs: spec(s) structured like the output(s).

  Returns:
    A function of the full input arrays. It starts from zero-filled outputs,
    visits every mesh coordinate in a Python loop, slices the matching block
    out of each input, applies ``f``, and writes each result into the
    matching block of the corresponding output.
  """
  # ``jax.tree_util.tree_map`` is the stable spelling; the top-level
  # ``jax.tree_map`` alias is deprecated.
  tree_map = jax.tree_util.tree_map

  def f_shmapped(*args):
    outs = tree_map(lambda y: jnp.zeros(y.shape, y.dtype), out_types)
    getters = [make_indexer(mesh, s, x) for s, x in zip(in_specs, args)]
    putters = tree_map(partial(make_indexer, mesh), out_specs, outs)
    for idx in it.product(*map(range, mesh.shape.values())):
      args_shards = [x[indexer(idx)] for x, indexer in zip(args, getters)]
      # Sanity-check the slicing against the declared per-shard types.
      assert all(x.shape == r.shape for x, r in zip(args_shards, body_in_types))
      out_shards = f(*args_shards)
      assert jax.tree_util.tree_all(tree_map(lambda y, r: y.shape == r.shape,
                                             out_shards, body_out_types))
      outs = tree_map(lambda y, out, indexer: out.at[indexer(idx)].set(y),
                      out_shards, outs, putters)
    return outs
  return f_shmapped
|
||||
|
||||
def make_indexer(mesh, spec, x):
  """Return a function mapping a mesh coordinate to the block of ``x`` it owns.

  The block size along each dimension is that dimension's size divided by the
  product of the mesh axis sizes named in the matching ``spec`` entry. For a
  tuple entry, the block number is the row-major linearization of the named
  axes' coordinates (earlier names vary slowest); for a single name it is
  that axis' coordinate; for ``None`` it is 0.
  """
  block_shape = [dim // prod(mesh.shape[ax] for ax in (entry or ()))
                 for dim, entry in zip(x.shape, spec)]

  def indexer(idx):
    mesh_axes = list(mesh.shape)

    def coord(axis_name):
      return idx[mesh_axes.index(axis_name)]

    def block_number(entry):
      if entry is None:
        return 0
      if type(entry) is not tuple:
        return coord(entry)
      return sum(coord(axis_name) * prod(mesh.shape[e] for e in entry[pos+1:])
                 for pos, axis_name in enumerate(entry))

    starts = [block_number(entry) for entry in spec]
    return tuple(slice(start * size, (start + 1) * size)
                 for start, size in zip(starts, block_shape))
  return indexer
|
||||
|
||||
|
||||
def sample_shmap():
  """Choice-generator describing one random shard_map test case.

  Each ``yield`` hands the driver an iterable of options and receives the
  selected one. In order, the driver picks: a function spec, a mesh shape,
  per-shard input types accepted by the spec's ``valid_types``, an input spec
  per argument, and output spec(s) compatible with the computed ``out_reps``.

  Returns:
    ``(name, fun, mesh_shape, in_specs, out_specs, args, ref)`` where
    ``mesh_shape`` maps axis names to sizes, ``args`` are concrete NumPy
    inputs of the full (dilated) shapes, and ``ref`` is ``shmap_reference``
    partially applied to the body input/output types and the full output
    types.
  """
  spec = yield fun_specs
  mesh_shape = yield mesh_shapes
  axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)]
  # A lightweight stand-in for a real mesh: only ``shape`` and ``axis_names``
  # are needed while sampling, so no devices are touched here.
  mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)),
                         axis_names=axis_names)
  in_types = (tys for tys in it.product(input_shapes, repeat=spec.num_inputs)
              if not spec.valid_types or spec.valid_types(*tys))
  body_in_types = yield in_types
  body_out_types = jax.eval_shape(spec.fun, *body_in_types)
  in_types, in_specs = yield from make_in_specs(mesh, body_in_types)
  # Distinct values in [0, 1) so that misplaced shards are detectable.
  args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size
          for ty in in_types]
  out_reps = spec.out_rep(*map(partial(unmentioned, mesh), in_specs))
  out_specs = yield from make_out_specs(mesh, body_out_types, out_reps)
  # ``jax.tree_util.tree_map`` is the stable spelling; the top-level
  # ``jax.tree_map`` alias is deprecated.
  out_types = jax.tree_util.tree_map(partial(dilate, mesh), out_specs,
                                     body_out_types)
  ref = partial(shmap_reference, body_in_types, body_out_types, out_types)
  in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short()
                          for t in in_types) + ')'
  name = f'{spec.name}_{mesh.shape}_{in_specs}_{out_specs}_{in_str}'
  return name, spec.fun, mesh.shape, in_specs, out_specs, args, ref
|
||||
|
||||
def sample(num, make_gen):
  """Yield ``num`` test cases with distinct names, drawn with a fixed seed.

  ``make_gen`` is called afresh for every draw and the resulting
  choice-generator is driven by ``sample_one``. Draws whose name (the first
  element of the case) was already produced are discarded.
  """
  rng = np.random.RandomState(0)
  emitted_names = set()
  while len(emitted_names) < num:
    name, *params = sample_one(rng, make_gen())
    if name in emitted_names:
      continue
    emitted_names.add(name)
    yield (name, *params)
|
||||
|
||||
def sample_one(rng, gen):
  """Drive a choice-generator to completion, choosing uniformly at random.

  Every value ``gen`` yields is materialized as a list; one element is picked
  with ``rng.randint`` and sent back in. The generator's return value is the
  result.
  """
  options = list(next(gen))
  while True:
    try:
      picked = options[rng.randint(len(options))]
      options = list(gen.send(picked))
    except StopIteration as finished:
      return finished.value
|
||||
|
||||
|
||||
class ShardMapSystematicTest(jtu.JaxTestCase):
  """Runs shard_map on randomly sampled cases and checks the results."""

  @staticmethod
  def make_mesh(mesh_shape):
    """Build a ``Mesh`` from an {axis name: size} mapping, or skip the test."""
    axis_names = tuple(mesh_shape)
    shape = tuple(mesh_shape.values())
    num_needed = prod(shape)
    if len(jax.devices()) < num_needed:
      raise unittest.SkipTest("too few devices for test")
    device_grid = np.array(jax.devices()[:num_needed]).reshape(shape)
    return Mesh(device_grid, axis_names)

  @parameterized.named_parameters(
      sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
  def test_eager_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
    # Eager shard_map must agree with the sequential reference.
    mesh = self.make_mesh(mesh)
    actual = shard_map(fun, mesh, in_specs, out_specs)(*args)
    want = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(want, actual, check_dtypes=False)

  @parameterized.named_parameters(
      sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
  def test_jit_against_ref(self, fun, mesh, in_specs, out_specs, args, ref):
    # Same comparison, with the shard_mapped function under jax.jit.
    mesh = self.make_mesh(mesh)
    jitted = jax.jit(shard_map(fun, mesh, in_specs, out_specs))
    actual = jitted(*args)
    want = ref(fun, mesh, in_specs, out_specs)(*args)
    self.assertAllClose(want, actual, check_dtypes=False)

  @parameterized.named_parameters(
      sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
  @jax.default_matmul_precision("float32")
  def test_grads(self, fun, mesh, in_specs, out_specs, args, _):
    # Disabled for now; everything after the raise is kept for re-enabling.
    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
    mesh = self.make_mesh(mesh)
    jitted = jax.jit(shard_map(fun, mesh, in_specs, out_specs))
    jtu.check_grads(jitted, args, order=2, atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(
      sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
  @jax.default_matmul_precision("float32")
  def test_grads_closure(self, fun, mesh, in_specs, out_specs, args, _):
    # Disabled for now; everything after the raise is kept for re-enabling.
    raise unittest.SkipTest("internal xla failures")  # TODO(b/269660532)
    mesh = self.make_mesh(mesh)
    # Inputs whose spec is all-None are closed over instead of being passed
    # through shard_map as arguments.
    no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
    args, closed_over_args = partition_list(no_sharding, args)
    in_specs, _ = partition_list(no_sharding, in_specs)

    def f(x, *closed_over_args):
      @jax.jit
      @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
      def g(*args):
        scaled = [x * arg for arg in args]
        merged = merge_lists(no_sharding, scaled, closed_over_args)
        return fun(*merged)
      return g(*args)

    jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Run under absltest with JAX's own test loader.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)
|
||||
|
Loading…
x
Reference in New Issue
Block a user