add systematic pdot tests, utility functions

Run lots of tests with e.g.

```
env JAX_NUM_GENERATED_CASES=1000 python tests/xmap_test.py PDotTests
```
This commit is contained in:
Matthew Johnson 2021-01-18 10:02:24 -08:00
parent 60a87fd2da
commit c02d8041f4
2 changed files with 272 additions and 18 deletions

View File

@ -352,11 +352,11 @@ def axis_index(axis_name):
return axis_index_p.bind(axis_name=axis_name)
def pdot(x, y, axis_name, pos_contract=((), ())):
def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
return pdot_p.bind(x, y, axis_name=axis_name,
pos_contract=pos_contract, pos_batch=[(), ()])
pos_contract=pos_contract, pos_batch=pos_batch)
### parallel primitives
@ -870,6 +870,7 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
@pdot_p.def_abstract_eval
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
# TODO: avals with names, check inputs are mapped along axis_name, eliminate
if not len(set(axis_name)) == len(axis_name): raise ValueError
return lax.dot_general_p.abstract_eval(
x, y, dimension_numbers=[pos_contract, pos_batch],
precision=None, preferred_element_type=None)

View File

@ -14,11 +14,14 @@
# flake8: noqa
from contextlib import contextmanager
import functools
import itertools
import itertools as it
import os
import unittest
from itertools import product, permutations
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
Any, Hashable, Iterable, Iterator, Union)
from unittest import SkipTest, skip, skipIf
import numpy as np
@ -33,7 +36,8 @@ from jax import vmap
from jax import lax
from jax.experimental.maps import Mesh, mesh, xmap
from jax.lib import xla_bridge
from jax._src.util import curry, unzip2
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
from jax.interpreters import pxla
from jax.config import config
@ -63,20 +67,20 @@ def tearDownModule():
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
@curry
def with_mesh(named_shape, f):
if not named_shape:
return f
def new_f(*args, **kwargs):
axis_names, shape = unzip2(named_shape)
size = np.prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
return f(*args, **kwargs)
return new_f
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
yield
class XMapTest(jtu.JaxTestCase):
@ -338,12 +342,222 @@ class XMapTestSPMD(XMapTest):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
AxisIndices = Tuple[int, ...]
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
AxisNames = Tuple[str, ...]
class PdotTestSpec:
# The axis indices stored by a PdotTestSpec are all positional indices
# *before* taking mapping into account.
map_cont: MatchedAxisIndices
pos_cont: MatchedAxisIndices
map_batch: MatchedAxisIndices
pos_batch: MatchedAxisIndices
all_names: AxisNames
contract_names: AxisNames
batch_names: AxisNames
def __init__(self, map_cont, pos_cont, map_batch, pos_batch):
self.map_cont = map_cont
self.pos_cont = pos_cont
self.map_batch = map_batch
self.pos_batch = pos_batch
names = gen_axis_names()
self.contract_names = [next(names) for _ in range(len(map_cont[0]))]
self.batch_names = [next(names) for _ in range(len(map_batch[0]))]
self.all_names = self.contract_names + self.batch_names
@property
def dot_general_dim_nums(self):
lhs_contract = (*self.map_cont[0], *self.pos_cont[0])
rhs_contract = (*self.map_cont[1], *self.pos_cont[1])
lhs_batch = (*self.map_batch[0], *self.pos_batch[0])
rhs_batch = (*self.map_batch[1], *self.pos_batch[1])
return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
@property
def pos_contract_after_mapping(self):
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_cont[0]]
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_cont[1]]
return (lhs, rhs)
@property
def pos_batch_after_mapping(self):
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_batch[0]]
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_batch[1]]
return (lhs, rhs)
@property
def _lhs_mapped(self):
return {*self.map_cont[0], *self.map_batch[0]}
@property
def _rhs_mapped(self):
return {*self.map_cont[1], *self.map_batch[1]}
@property
def lhs_in_axes(self):
axis_indices = [*self.map_cont[0], *self.map_batch[0]]
return dict(zip(axis_indices, self.all_names))
@property
def rhs_in_axes(self):
axis_indices = [*self.map_cont[1], *self.map_batch[1]]
return dict(zip(axis_indices, self.all_names))
def all_pdot_specs(lhs_shape, rhs_shape):
for matching in axis_matchings(lhs_shape, rhs_shape):
for lists in partitions(matching, 4):
yield PdotTestSpec(*map(unzip2, lists))
def axis_matchings(lhs_shape, rhs_shape):
def helper(start, exc1, exc2):
yield ()
for i in range(start, len(lhs_shape)):
d1 = lhs_shape[i]
if i not in exc1:
for j, d2 in enumerate(rhs_shape):
if d1 == d2 and j not in exc2:
for matches in helper(i + 1, exc1 | {i}, exc2 | {j}):
yield ((i, j), *matches)
return helper(0, set(), set())
def partitions(s, k):
for indices in product(range(k), repeat=len(s)):
outs = [[] for _ in range(k)]
for i, elt in zip(indices, s):
outs[i].append(elt)
yield outs
def powerset(s):
s = list(s)
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
def gen_axis_names():
names = 'ijkl'
for n in it.count(1):
for chars in product(names, repeat=n):
yield ''.join(chars)
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
def schedules(sizes: Dict[str, int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
"""Test utility generating xmap parallel schedules from logical names & sizes.
Args:
sizes: dict mapping logical axis name to its corresponding size.
Returns:
A generator producing finitely many values, where each value is a pair in
which the first element is a value suitable for xmap's axis_resources
argument and the second element is a list of pairs with the first element
representing a generated physical mesh axis name and the second element
representing a corresponding generated mesh axis size. The generated mesh
names/sizes can be used to define a physical mesh in tests.
This function doesn't generate schedules which map distinct logical axis names
to the same parallel resource name. It only generates parallel resources; the
rest are implicitly left for vectorization. Parallel resource names are
generated by prepending an 'r', 'r1', or 'r2' to the corresponding logical
name.
Exa,mples:
>>> for sched in schedules({'i': 2, 'j': 4}):
... print(sched)
({}, [])
({'i': 'ri'}, [('ri', 1)])
({'i': 'ri'}, [('ri', 2)])
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 1)])
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 2)])
({'i': ('r1i', 'r2i')}, [('r1i', 2), ('r2i', 1)])
({'j': 'rj'}, [('rj', 1)])
({'j': 'rj'}, [('rj', 2)])
({'j': 'rj'}, [('rj', 4)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 1)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 2)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 4)])
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 1)])
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 2)])
({'j': ('r1j', 'r2j')}, [('r1j', 4), ('r2j', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 2)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 4)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 2)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 4), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 4), ('r2j', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 2), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 2), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 2), ('r2i', 1)])
"""
def divisors(n: int) -> List[int]:
return [m for m in range(1, n + 1) if not n % m]
def divisors2(n: int) -> Iterator[Tuple[int, int]]:
for k1 in divisors(n):
for k2 in divisors(n // k1):
yield (k1, k2)
# choose a subset of logical axis names to map to parallel resources
for names in powerset(sizes):
# partition that set of logical axis names into two subsets: one subset to
# map to one parallel resource axis and a second subset to map to two
# parallel resource axes.
for names1, names2 in partitions(names, 2):
# to avoid generating too many complex cases, we skip generating cases
# where more than one logical axis name is to be mapped to two parallel
# resource axes. comment out this line to generate more complex tests.
if len(names2) > 1: continue
# make up parallel resource axis names for each logical axis
axis_resources1 = ((name, 'r' + name) for name in names1)
axis_resources2 = ((name, ('r1' + name, 'r2' + name)) for name in names2)
axis_resources = dict(it.chain(axis_resources1, axis_resources2))
# make up sizes for each resource axis, where the size must divide the
# corresponding logical axis
for mesh_sizes1 in product(*(divisors(sizes[n]) for n in names1)):
for mesh_sizes2 in product(*(divisors2(sizes[n]) for n in names2)):
mesh_data1 = (('r' + name, size) for name, size in zip(names1, mesh_sizes1))
mesh_data2 = (pair for name, (size1, size2) in zip(names2, mesh_sizes2)
for pair in [('r1' + name, size1), ('r2' + name, size2)])
mesh_data = list(it.chain(mesh_data1, mesh_data2))
yield axis_resources, mesh_data
def schedules_from_pdot_spec(
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
logical_sizes = {
name: shape[ax]
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
(rhs_shape, spec.rhs_in_axes)]
for ax, name in in_axes.items()}
yield from schedules(logical_sizes)
class PDotTests(jtu.JaxTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
@ignore_xmap_warning()
@with_mesh([('r1', 2)])
@ -402,6 +616,45 @@ class PDotTests(jtu.JaxTestCase):
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{next(test_counter)}",
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
"axis_resources": axis_resources, "mesh_data": mesh_data}
for test_counter in [it.count()]
for lhs_shape, rhs_shape in product(
[(2,), (2, 4, 2, 1)],
repeat=2)
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
for axis_resources, mesh_data in schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape)))
@ignore_xmap_warning()
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
mesh_data):
rng = jtu.rand_default(self.rng())
lhs = rng(lhs_shape, np.float32)
rhs = rng(rhs_shape, np.float32)
def pdot_fun(x, y):
# print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
# f' axis_name={contract_names},\n'
# f' pos_contract={spec.pos_contract_after_mapping}\n'
# f' pos_batch={spec.pos_batch_after_mapping})')
return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
pos_batch=pdot_spec.pos_batch_after_mapping,
pos_contract=pdot_spec.pos_contract_after_mapping)
fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
out_axes=[*pdot_spec.batch_names, ...],
axis_resources=axis_resources)
with with_mesh(mesh_data):
result = fun(lhs, rhs)
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(result, expected, check_dtypes=False,
atol=tol, rtol=tol)
class XMapErrorTest(jtu.JaxTestCase):