mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add systematic pdot tests, utility functions
Run lots of tests with e.g. ``` env JAX_NUM_GENERATED_CASES=1000 python tests/xmap_test.py PDotTests ```
This commit is contained in:
parent
60a87fd2da
commit
c02d8041f4
@ -352,11 +352,11 @@ def axis_index(axis_name):
|
||||
return axis_index_p.bind(axis_name=axis_name)
|
||||
|
||||
|
||||
def pdot(x, y, axis_name, pos_contract=((), ())):
|
||||
def pdot(x, y, axis_name, pos_contract=((), ()), pos_batch=((), ())):
|
||||
if not isinstance(axis_name, (list, tuple)):
|
||||
axis_name = (axis_name,)
|
||||
return pdot_p.bind(x, y, axis_name=axis_name,
|
||||
pos_contract=pos_contract, pos_batch=[(), ()])
|
||||
pos_contract=pos_contract, pos_batch=pos_batch)
|
||||
|
||||
|
||||
### parallel primitives
|
||||
@ -870,6 +870,7 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch):
|
||||
@pdot_p.def_abstract_eval
|
||||
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch):
|
||||
# TODO: avals with names, check inputs are mapped along axis_name, eliminate
|
||||
if not len(set(axis_name)) == len(axis_name): raise ValueError
|
||||
return lax.dot_general_p.abstract_eval(
|
||||
x, y, dimension_numbers=[pos_contract, pos_batch],
|
||||
precision=None, preferred_element_type=None)
|
||||
|
@ -14,11 +14,14 @@
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
from contextlib import contextmanager
|
||||
import functools
|
||||
import itertools
|
||||
import itertools as it
|
||||
import os
|
||||
import unittest
|
||||
from itertools import product, permutations
|
||||
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
|
||||
Any, Hashable, Iterable, Iterator, Union)
|
||||
from unittest import SkipTest, skip, skipIf
|
||||
|
||||
import numpy as np
|
||||
@ -33,7 +36,8 @@ from jax import vmap
|
||||
from jax import lax
|
||||
from jax.experimental.maps import Mesh, mesh, xmap
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.util import curry, unzip2
|
||||
from jax._src.util import curry, unzip2, split_list, prod
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax.interpreters import pxla
|
||||
|
||||
from jax.config import config
|
||||
@ -63,20 +67,20 @@ def tearDownModule():
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
@curry
|
||||
def with_mesh(named_shape, f):
|
||||
if not named_shape:
|
||||
return f
|
||||
def new_f(*args, **kwargs):
|
||||
axis_names, shape = unzip2(named_shape)
|
||||
size = np.prod(shape)
|
||||
local_devices = list(jax.local_devices())
|
||||
if len(local_devices) < size:
|
||||
raise SkipTest(f"Test requires {size} local devices")
|
||||
mesh_devices = np.array(local_devices[:size]).reshape(shape)
|
||||
with mesh(mesh_devices, axis_names):
|
||||
return f(*args, **kwargs)
|
||||
return new_f
|
||||
MeshSpec = List[Tuple[str, int]]
|
||||
|
||||
@contextmanager
|
||||
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
||||
"""Test utility for setting up meshes given mesh data from `schedules`."""
|
||||
# This is similar to the `with_mesh` function above, but isn't a decorator.
|
||||
axis_names, shape = unzip2(named_shape)
|
||||
size = prod(shape)
|
||||
local_devices = list(jax.local_devices())
|
||||
if len(local_devices) < size:
|
||||
raise SkipTest(f"Test requires {size} local devices")
|
||||
mesh_devices = np.array(local_devices[:size]).reshape(shape)
|
||||
with mesh(mesh_devices, axis_names):
|
||||
yield
|
||||
|
||||
|
||||
class XMapTest(jtu.JaxTestCase):
|
||||
@ -338,12 +342,222 @@ class XMapTestSPMD(XMapTest):
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
|
||||
|
||||
AxisIndices = Tuple[int, ...]
|
||||
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
|
||||
AxisNames = Tuple[str, ...]
|
||||
|
||||
class PdotTestSpec:
|
||||
# The axis indices stored by a PdotTestSpec are all positional indices
|
||||
# *before* taking mapping into account.
|
||||
map_cont: MatchedAxisIndices
|
||||
pos_cont: MatchedAxisIndices
|
||||
map_batch: MatchedAxisIndices
|
||||
pos_batch: MatchedAxisIndices
|
||||
all_names: AxisNames
|
||||
contract_names: AxisNames
|
||||
batch_names: AxisNames
|
||||
|
||||
def __init__(self, map_cont, pos_cont, map_batch, pos_batch):
|
||||
self.map_cont = map_cont
|
||||
self.pos_cont = pos_cont
|
||||
self.map_batch = map_batch
|
||||
self.pos_batch = pos_batch
|
||||
|
||||
names = gen_axis_names()
|
||||
self.contract_names = [next(names) for _ in range(len(map_cont[0]))]
|
||||
self.batch_names = [next(names) for _ in range(len(map_batch[0]))]
|
||||
self.all_names = self.contract_names + self.batch_names
|
||||
|
||||
@property
|
||||
def dot_general_dim_nums(self):
|
||||
lhs_contract = (*self.map_cont[0], *self.pos_cont[0])
|
||||
rhs_contract = (*self.map_cont[1], *self.pos_cont[1])
|
||||
lhs_batch = (*self.map_batch[0], *self.pos_batch[0])
|
||||
rhs_batch = (*self.map_batch[1], *self.pos_batch[1])
|
||||
return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
|
||||
|
||||
@property
|
||||
def pos_contract_after_mapping(self):
|
||||
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_cont[0]]
|
||||
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_cont[1]]
|
||||
return (lhs, rhs)
|
||||
|
||||
@property
|
||||
def pos_batch_after_mapping(self):
|
||||
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_batch[0]]
|
||||
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_batch[1]]
|
||||
return (lhs, rhs)
|
||||
|
||||
@property
|
||||
def _lhs_mapped(self):
|
||||
return {*self.map_cont[0], *self.map_batch[0]}
|
||||
|
||||
@property
|
||||
def _rhs_mapped(self):
|
||||
return {*self.map_cont[1], *self.map_batch[1]}
|
||||
|
||||
@property
|
||||
def lhs_in_axes(self):
|
||||
axis_indices = [*self.map_cont[0], *self.map_batch[0]]
|
||||
return dict(zip(axis_indices, self.all_names))
|
||||
|
||||
@property
|
||||
def rhs_in_axes(self):
|
||||
axis_indices = [*self.map_cont[1], *self.map_batch[1]]
|
||||
return dict(zip(axis_indices, self.all_names))
|
||||
|
||||
def all_pdot_specs(lhs_shape, rhs_shape):
|
||||
for matching in axis_matchings(lhs_shape, rhs_shape):
|
||||
for lists in partitions(matching, 4):
|
||||
yield PdotTestSpec(*map(unzip2, lists))
|
||||
|
||||
def axis_matchings(lhs_shape, rhs_shape):
|
||||
def helper(start, exc1, exc2):
|
||||
yield ()
|
||||
for i in range(start, len(lhs_shape)):
|
||||
d1 = lhs_shape[i]
|
||||
if i not in exc1:
|
||||
for j, d2 in enumerate(rhs_shape):
|
||||
if d1 == d2 and j not in exc2:
|
||||
for matches in helper(i + 1, exc1 | {i}, exc2 | {j}):
|
||||
yield ((i, j), *matches)
|
||||
return helper(0, set(), set())
|
||||
|
||||
def partitions(s, k):
|
||||
for indices in product(range(k), repeat=len(s)):
|
||||
outs = [[] for _ in range(k)]
|
||||
for i, elt in zip(indices, s):
|
||||
outs[i].append(elt)
|
||||
yield outs
|
||||
|
||||
def powerset(s):
|
||||
s = list(s)
|
||||
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
|
||||
|
||||
def gen_axis_names():
|
||||
names = 'ijkl'
|
||||
for n in it.count(1):
|
||||
for chars in product(names, repeat=n):
|
||||
yield ''.join(chars)
|
||||
|
||||
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
|
||||
|
||||
def schedules(sizes: Dict[str, int]
|
||||
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
|
||||
"""Test utility generating xmap parallel schedules from logical names & sizes.
|
||||
|
||||
Args:
|
||||
sizes: dict mapping logical axis name to its corresponding size.
|
||||
|
||||
Returns:
|
||||
A generator producing finitely many values, where each value is a pair in
|
||||
which the first element is a value suitable for xmap's axis_resources
|
||||
argument and the second element is a list of pairs with the first element
|
||||
representing a generated physical mesh axis name and the second element
|
||||
representing a corresponding generated mesh axis size. The generated mesh
|
||||
names/sizes can be used to define a physical mesh in tests.
|
||||
|
||||
This function doesn't generate schedules which map distinct logical axis names
|
||||
to the same parallel resource name. It only generates parallel resources; the
|
||||
rest are implicitly left for vectorization. Parallel resource names are
|
||||
generated by prepending an 'r', 'r1', or 'r2' to the corresponding logical
|
||||
name.
|
||||
|
||||
Exa,mples:
|
||||
>>> for sched in schedules({'i': 2, 'j': 4}):
|
||||
... print(sched)
|
||||
({}, [])
|
||||
({'i': 'ri'}, [('ri', 1)])
|
||||
({'i': 'ri'}, [('ri', 2)])
|
||||
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 1)])
|
||||
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 2)])
|
||||
({'i': ('r1i', 'r2i')}, [('r1i', 2), ('r2i', 1)])
|
||||
({'j': 'rj'}, [('rj', 1)])
|
||||
({'j': 'rj'}, [('rj', 2)])
|
||||
({'j': 'rj'}, [('rj', 4)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 1)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 2)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 4)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 1)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 2)])
|
||||
({'j': ('r1j', 'r2j')}, [('r1j', 4), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 1)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 2)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 4)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 1)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 2)])
|
||||
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 4)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 2)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 4)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 2)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 4), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 2)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 4)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 1)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 2)])
|
||||
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 4), ('r2j', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 2)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 2), ('r2i', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 2)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 2), ('r2i', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 1)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 2)])
|
||||
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 2), ('r2i', 1)])
|
||||
"""
|
||||
def divisors(n: int) -> List[int]:
|
||||
return [m for m in range(1, n + 1) if not n % m]
|
||||
|
||||
def divisors2(n: int) -> Iterator[Tuple[int, int]]:
|
||||
for k1 in divisors(n):
|
||||
for k2 in divisors(n // k1):
|
||||
yield (k1, k2)
|
||||
|
||||
# choose a subset of logical axis names to map to parallel resources
|
||||
for names in powerset(sizes):
|
||||
# partition that set of logical axis names into two subsets: one subset to
|
||||
# map to one parallel resource axis and a second subset to map to two
|
||||
# parallel resource axes.
|
||||
for names1, names2 in partitions(names, 2):
|
||||
# to avoid generating too many complex cases, we skip generating cases
|
||||
# where more than one logical axis name is to be mapped to two parallel
|
||||
# resource axes. comment out this line to generate more complex tests.
|
||||
if len(names2) > 1: continue
|
||||
# make up parallel resource axis names for each logical axis
|
||||
axis_resources1 = ((name, 'r' + name) for name in names1)
|
||||
axis_resources2 = ((name, ('r1' + name, 'r2' + name)) for name in names2)
|
||||
axis_resources = dict(it.chain(axis_resources1, axis_resources2))
|
||||
# make up sizes for each resource axis, where the size must divide the
|
||||
# corresponding logical axis
|
||||
for mesh_sizes1 in product(*(divisors(sizes[n]) for n in names1)):
|
||||
for mesh_sizes2 in product(*(divisors2(sizes[n]) for n in names2)):
|
||||
mesh_data1 = (('r' + name, size) for name, size in zip(names1, mesh_sizes1))
|
||||
mesh_data2 = (pair for name, (size1, size2) in zip(names2, mesh_sizes2)
|
||||
for pair in [('r1' + name, size1), ('r2' + name, size2)])
|
||||
mesh_data = list(it.chain(mesh_data1, mesh_data2))
|
||||
yield axis_resources, mesh_data
|
||||
|
||||
def schedules_from_pdot_spec(
|
||||
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
|
||||
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
|
||||
logical_sizes = {
|
||||
name: shape[ax]
|
||||
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
|
||||
(rhs_shape, spec.rhs_in_axes)]
|
||||
for ax, name in in_axes.items()}
|
||||
yield from schedules(logical_sizes)
|
||||
|
||||
|
||||
class PDotTests(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('r1', 2)])
|
||||
@ -402,6 +616,45 @@ class PDotTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{next(test_counter)}",
|
||||
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
|
||||
"axis_resources": axis_resources, "mesh_data": mesh_data}
|
||||
for test_counter in [it.count()]
|
||||
for lhs_shape, rhs_shape in product(
|
||||
[(2,), (2, 4, 2, 1)],
|
||||
repeat=2)
|
||||
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
|
||||
for axis_resources, mesh_data in schedules_from_pdot_spec(
|
||||
pdot_spec, lhs_shape, rhs_shape)))
|
||||
@ignore_xmap_warning()
|
||||
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
|
||||
mesh_data):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
lhs = rng(lhs_shape, np.float32)
|
||||
rhs = rng(rhs_shape, np.float32)
|
||||
|
||||
def pdot_fun(x, y):
|
||||
# print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
|
||||
# f' axis_name={contract_names},\n'
|
||||
# f' pos_contract={spec.pos_contract_after_mapping}\n'
|
||||
# f' pos_batch={spec.pos_batch_after_mapping})')
|
||||
return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
|
||||
pos_batch=pdot_spec.pos_batch_after_mapping,
|
||||
pos_contract=pdot_spec.pos_contract_after_mapping)
|
||||
|
||||
fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
|
||||
out_axes=[*pdot_spec.batch_names, ...],
|
||||
axis_resources=axis_resources)
|
||||
|
||||
with with_mesh(mesh_data):
|
||||
result = fun(lhs, rhs)
|
||||
|
||||
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(result, expected, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user