Make xmap into a primitive

Add a compilation cache. Also make sure that it raises a clear error
when you try to use it with other transforms.

Also, add a bunch of checks to make sure that the arguments are valid.
This commit is contained in:
Adam Paszke 2020-11-19 18:22:35 +00:00
parent ad6ceeeaf5
commit c5433b0e4e
3 changed files with 271 additions and 128 deletions

View File

@ -16,6 +16,7 @@ import enum
import threading
import contextlib
import numpy as np
import itertools as it
from collections import namedtuple, OrderedDict
from typing import Callable, Iterable, List, Tuple, Optional, Dict, Any, Set
from warnings import warn
@ -26,29 +27,56 @@ from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from ..api import _mapped_axis_size, _check_callable, _check_arg
from ..tree_util import tree_flatten, tree_unflatten
from ..api_util import flatten_fun
from ..tree_util import tree_flatten, tree_unflatten, tree_leaves
from ..api_util import flatten_fun, flatten_fun_nokwargs, flatten_axes
from ..interpreters import partial_eval as pe
from ..interpreters import batching
from ..interpreters import pxla
from ..util import safe_map, safe_zip, curry
map = safe_map
map, unsafe_map = safe_map, map
zip = safe_zip
class FrozenDict: # dataclasses might remove some boilerplate here
def __init__(self, *args, **kwargs):
self.contents = dict(*args, **kwargs)
allowed_methods = {'items', 'values', 'keys', 'get'}
def __getattr__(self, name):
if name in self.allowed_methods:
return getattr(self.contents, name)
raise AttributeError(name)
def __iter__(self):
return self.contents.__iter__()
def __len__(self):
return self.contents.__len__()
def __getitem__(self, name):
return self.contents.__getitem__(name)
def __eq__(self, other):
return isinstance(other, FrozenDict) and self.contents == other.contents
def __hash__(self):
return hash(tuple(self.contents.items()))
# Multi-dimensional generalized map
# TODO: Use a more concrete type annotation (we need __eq__ and __hash__)
AxisName = Any
ResourceAxisName = Any
AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
Mesh = pxla.Mesh
# TODO: Support sequential mapping
class ResourceEnv(threading.local):
def __init__(self):
self.physical_mesh : Mesh = Mesh(np.empty((), dtype=object), ())
self.fake_resources : Dict[ResourceAxisName, int] = {}
class ResourceEnv:
__slots__ = ('physical_mesh', 'fake_resources')
physical_mesh: Mesh
fake_resources: FrozenDict
def __init__(self, physical_mesh: Mesh, fake_resources: FrozenDict):
super().__setattr__('physical_mesh', physical_mesh)
super().__setattr__('fake_resources', fake_resources)
@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
@ -68,26 +96,40 @@ class ResourceEnv(threading.local):
shape.update((name, size) for name, size in self.fake_resources.items())
return shape
# TODO: Make this thread local
thread_resource_env = ResourceEnv()
def __setattr__(self, name, value):
raise RuntimeError("ResourceEnv is immutable!")
def __delattr__(self):
raise RuntimeError("ResourceEnv is immutable!")
def __eq__(self, other):
return (type(other) is ResourceEnv and
self.physical_mesh == other.physical_mesh and
self.fake_resources == other.fake_resources)
def __hash__(self):
return hash((self.physical_mesh, self.fake_resources))
thread_resources = threading.local()
thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), FrozenDict())
@contextlib.contextmanager
def fake_resources(**axes):
old_axes = thread_resource_env.fake_resources
thread_resource_env.fake_resources = axes
old_env = thread_resources.env
thread_resources.env = ResourceEnv(old_env.physical_mesh, FrozenDict(axes))
try:
yield
finally:
thread_resource_env.axes = old_axes
thread_resources.env = old_env
@contextlib.contextmanager
def mesh(*args, **kwargs):
old = thread_resource_env.physical_mesh
thread_resource_env.physical_mesh = Mesh(*args, **kwargs)
old_env = thread_resources.env
thread_resources.env = ResourceEnv(Mesh(*args, **kwargs), old_env.fake_resources)
try:
yield
finally:
thread_resource_env.physical_mesh = old
thread_resources.env = old_env
_next_resource_id = 0
class UniqueResourceName:
@ -103,118 +145,200 @@ def fresh_resource_name():
# This is really a Dict[AxisName, int], but we don't define a
# pytree instance for it, so that it is treated as a leaf.
class AxisNamePos(dict):
class AxisNamePos(FrozenDict):
pass
A = AxisNamePos
# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
def xmap(fun: Callable,
in_axes, # PyTree[AxisNamePos]
out_axes, # PyTree[AxisNamePos],
schedule: Iterable[Tuple[AxisName, ResourceAxisName]]):
schedule: Iterable[Tuple[AxisName, ResourceAxisName]],
backend: Optional[str] = None):
warn("xmap is an experimental feature and probably has bugs!")
_check_callable(fun)
def fun_mapped(*args, **kwargs):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resource_env
frozen_schedule = tuple(tuple(x) for x in schedule)
args_flat, in_tree = tree_flatten((args, kwargs))
if isinstance(in_axes, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)
in_axes_entries = tree_leaves(in_axes)
out_axes_entries = tree_leaves(out_axes)
# Check that {in|out}_axes have the right types, and don't use the same positional axis twice
if not all(isinstance(x, A) for x in in_axes_entries):
raise TypeError(f"xmap in_axes must be AxisNamePos (A) instances or (nested) "
f"containers with those types as leaves, but got {in_axes}")
if not all(isinstance(x, A) for x in out_axes_entries):
raise TypeError(f"xmap out_axes must be AxisNamePos (A) instances or (nested) "
f"containers with those types as leaves, but got {in_axes}")
for x in in_axes_entries:
if len(set(x.values())) != len(x):
raise ValueError(f"Positional dimension indices should be unique within each "
f"in_axes dictionary, but one of the entries is: {x}")
for x in out_axes_entries:
if len(set(x.values())) != len(x):
raise ValueError(f"Positional dimension indices should be unique within each "
f"in_axes dictionary, but one of the entries is: {x}")
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
scheduled_axes = set(x[0] for x in frozen_schedule)
if scheduled_axes != in_axes_names:
raise ValueError("The set of axes names appearing in in_axes has to equal the "
"set of scheduled axes, but {in_axes_names} != {scheduled_axes}")
necessary_resources = set(x[1] for x in frozen_schedule if x[1] != 'vectorize')
if len(set(frozen_schedule)) != len(frozen_schedule):
raise ValueError(f"xmap schedule contains duplicate entries: {frozen_schedule}")
def fun_mapped(*args):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resources.env
available_resources = set(resource_env.shape.keys())
if necessary_resources > available_resources:
raise ValueError(f"In-scope resources are insufficient to execute the "
f"xmapped function. The missing resources are: "
f"{necessary_resources - available_resources}")
args_flat, in_tree = tree_flatten(args)
for arg in args_flat: _check_arg(arg)
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
# TODO: Check that:
# - every scheduled axis name appears in at least one input
# - every used resource axis name appears in the resource env
# - every axis name is scheduled to a single resource axis only once
# - every out axis has a distinct index
# - two axes mapped to the same resource never coincide (even inside f)
in_axes_flat, in_axes_tree = tree_flatten(in_axes)
# TODO: Verify that in_axes are equal, or better expand their prefix
# assert in_axes_tree == in_tree
in_axes_flat = flatten_axes("xmap in_axes", in_tree, in_axes)
out_axes_flat, out_axes_tree = tree_flatten(out_axes)
# TODO: Verify that out_axes are equal, or better expand their prefix
# assert out_axes_tree == out_tree
axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
jaxpr, out_tree = _trace_mapped_jaxpr(fun, args_flat, in_axes_flat, axis_sizes, in_tree)
# TODO: The order of maps should be derived from the schedule, not from the
# resource env. This doesn't really matter for as long as we only support
# vectorization and parallelization, but will be important for sequential.
# We should be able to do that by building a graph of dependencies between
# resources based on the order in which they appear within each axis.
# If it has cycles then we cannot realize it. Otherwise, if the DAG doesn't
# uniquely identify a linear order, we should use the order of entries in
# the schedule to break ties.
# Technically the order doesn't matter right now, but we use the ordered dict
# to at least limit the amount of non-determinism in this code.
fake_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
physical_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
vectorized: Dict[AxisName, ResourceAxisName] = OrderedDict()
axis_subst: Dict[AxisName, List[ResourceAxisName]] = {}
for axis, resource in schedule:
if resource == 'vectorize':
assert axis not in vectorized
resource = fresh_resource_name()
vectorized[axis] = resource
elif resource in resource_env.physical_resource_axes:
# TODO: Make sure that axis was not in the set?
physical_resource_map.setdefault(resource, set()).add(axis)
elif resource in resource_env.fake_resource_axes:
# TODO: Make sure that axis was not in the set?
fake_resource_map.setdefault(resource, set()).add(axis)
else:
raise ValueError(f"Mapping axis {axis} to an undefined resource axis {resource}. "
f"The resource axes currently in scope are: {resource_env.resource_axes}")
axis_subst.setdefault(axis, []).append(resource)
axis_subst_t = {axis: tuple(resources) for axis, resources in axis_subst.items()}
jaxpr = jaxpr.map_jaxpr(partial(subst_axis_names, axis_subst=axis_subst_t))
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f = hide_mapped_axes(f, in_axes_flat, out_axes_flat)
for naxis, raxis in vectorized.items():
map_in_axes = map(lambda spec: spec.get(naxis, None), in_axes_flat)
map_out_axes = map(lambda spec: spec.get(naxis, None), out_axes_flat)
f = vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=raxis)
resource_env_shape = resource_env.shape
for raxis, naxes in fake_resource_map.items():
map_in_axes = map(lambda spec: lookup_exactly_one_of(spec, naxes), in_axes_flat)
map_out_axes = map(lambda spec: lookup_exactly_one_of(spec, naxes), out_axes_flat)
map_size = resource_env_shape[raxis]
f = vtile(f, map_in_axes, map_out_axes, tile_size=map_size, axis_name=raxis)
if physical_resource_map:
submesh = resource_env.physical_mesh[tuple(physical_resource_map.keys())]
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
def to_mesh_axes(axes):
mesh_axes = {}
for paxis, naxes in physical_resource_map.items():
axis = lookup_exactly_one_of(axes, naxes)
if axis is None:
continue
mesh_axes[paxis] = axis
return mesh_axes
mesh_in_axes = map(to_mesh_axes, in_axes)
mesh_out_axes = map(to_mesh_axes, out_axes)
f = pxla.mesh_tiled_callable(*in_avals, fun=f,
transformed_name=f.__name__,
backend_name=None,
mesh=submesh,
in_axes=mesh_in_axes,
out_axes_thunk=lambda: mesh_out_axes)
flat_out = f(*args_flat)
else:
flat_out = f.call_wrapped(*args_flat)
return tree_unflatten(out_tree, flat_out)
out_flat = xmap_p.bind(
fun_flat, *args_flat,
in_axes=tuple(in_axes_flat),
out_axes=tuple(out_axes_flat),
axis_sizes=FrozenDict(axis_sizes),
schedule=frozen_schedule,
resource_env=resource_env,
backend=backend)
return tree_unflatten(out_tree(), out_flat)
return fun_mapped
def xmap_impl(fun: lu.WrappedFun, *args, in_axes, out_axes, axis_sizes, schedule, resource_env, backend):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
return make_xmap_callable(fun, in_axes, out_axes, axis_sizes, schedule,
resource_env, backend, *in_avals)(*args)
@lu.cache
def make_xmap_callable(fun: lu.WrappedFun,
in_axes, out_axes, axis_sizes,
schedule, resource_env, backend,
*in_avals):
mapped_in_avals = [_delete_aval_axes(aval, in_axes)
for aval, in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(axis_sizes.items()):
raw_jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
jaxpr = core.ClosedJaxpr(raw_jaxpr, consts)
# TODO: The order of maps should be derived from the schedule, not from the
# resource env. This doesn't really matter for as long as we only support
# vectorization and parallelization, but will be important for sequential.
# We should be able to do that by building a graph of dependencies between
# resources based on the order in which they appear within each axis.
# If it has cycles then we cannot realize it. Otherwise, if the DAG doesn't
# uniquely identify a linear order, we should use the order of entries in
# the schedule to break ties.
# Technically the order doesn't matter right now, but we use the ordered dict
# to at least limit the amount of non-determinism in this code.
fake_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
physical_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
vectorized: Dict[AxisName, ResourceAxisName] = OrderedDict()
axis_subst: Dict[AxisName, List[ResourceAxisName]] = {}
for axis, resource in schedule:
if resource == 'vectorize':
assert axis not in vectorized
resource = fresh_resource_name()
vectorized[axis] = resource
elif resource in resource_env.physical_resource_axes:
physical_resource_map.setdefault(resource, set()).add(axis)
elif resource in resource_env.fake_resource_axes:
fake_resource_map.setdefault(resource, set()).add(axis)
else:
raise ValueError(f"Mapping axis {axis} to an undefined resource axis {resource}. "
f"The resource axes currently in scope are: {resource_env.resource_axes}")
axis_subst.setdefault(axis, []).append(resource)
axis_subst_t = {axis: tuple(resources) for axis, resources in axis_subst.items()}
jaxpr = jaxpr.map_jaxpr(partial(subst_axis_names, axis_subst=axis_subst_t))
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))
for naxis, raxis in vectorized.items():
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=raxis)
resource_env_shape = resource_env.shape
for raxis, naxes in fake_resource_map.items():
map_in_axes = tuple(unsafe_map(lambda spec: lookup_exactly_one_of(spec, naxes), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: lookup_exactly_one_of(spec, naxes), out_axes))
map_size = resource_env_shape[raxis]
f = vtile(f, map_in_axes, map_out_axes, tile_size=map_size, axis_name=raxis)
if physical_resource_map:
submesh = resource_env.physical_mesh[tuple(physical_resource_map.keys())]
def to_mesh_axes(axes):
mesh_axes = {}
for paxis, naxes in physical_resource_map.items():
axis = lookup_exactly_one_of(axes, naxes)
if axis is None:
continue
mesh_axes[paxis] = axis
return A(mesh_axes)
mesh_in_axes = tuple(unsafe_map(to_mesh_axes, in_axes))
mesh_out_axes = tuple(unsafe_map(to_mesh_axes, out_axes))
return pxla.mesh_tiled_callable(f,
f.__name__,
backend,
submesh,
mesh_in_axes,
mesh_out_axes,
*in_avals)
else:
return f.call_wrapped
# xmap has a different set of parameters than pmap, so we make it its own primitive type
class XMapPrimitive(core.Primitive):
multiple_results = True
def __init__(self):
super().__init__('xmap')
self.def_impl(xmap_impl)
self.def_custom_bind(partial(core.call_bind, self))
def bind(self, fun, *args, **params):
assert len(params['in_axes']) == len(args)
return core.call_bind(self, fun, *args, **params)
def process(self, trace, fun, tracers, params):
return trace.process_xmap(self, fun, tracers, params)
def post_process(self, trace, out_tracers, params):
raise NotImplementedError
return trace.post_process_xmap(self, out_tracers, params)
xmap_p = XMapPrimitive()
core.EvalTrace.process_xmap = core.EvalTrace.process_call # type: ignore
def _delete_aval_axes(aval, axes: AxisNamePos):
assert isinstance(aval, core.ShapedArray)
shape = list(aval.shape)
@ -222,19 +346,7 @@ def _delete_aval_axes(aval, axes: AxisNamePos):
del shape[i]
return core.ShapedArray(tuple(shape), aval.dtype)
def _trace_mapped_jaxpr(fun,
args_flat,
in_axes_flat: List[AxisNamePos],
axis_sizes: Dict[AxisName, int],
in_tree):
fun_flat, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
mapped_avals = [_delete_aval_axes(aval, in_axes)
for aval, in_axes in zip(avals_flat, in_axes_flat)]
with core.extend_axis_env_nd(axis_sizes.items()):
jaxpr, _, consts = pe.trace_to_jaxpr_final(fun_flat, mapped_avals)
return core.ClosedJaxpr(jaxpr, consts), out_tree()
# TODO: pmap has some very fancy error messages for this function!
def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos]):
axis_sizes: Dict[AxisName, int] = {}
for arg, in_axes in zip(args_flat, in_axes_flat):
@ -286,7 +398,10 @@ def untile_axis(out, axis: Optional[int]):
return out.reshape(shape)
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat, in_axes_flat, out_axes_flat, tile_size: Optional[int], axis_name):
def vtile(f_flat,
in_axes_flat: Tuple[AxisNamePos, ...],
out_axes_flat: Tuple[AxisNamePos, ...],
tile_size: Optional[int], axis_name):
@lu.transformation
def _map_to_tile(*args_flat):
real_tile_size = tile_size

View File

@ -1223,14 +1223,13 @@ def untile_aval_nd(axis_sizes, out_axes: AxisNameMap, aval):
shape[axis] *= axis_sizes[name]
return ShapedArray(tuple(shape), aval.dtype)
# TODO(apaszke): Cache compilation
def mesh_tiled_callable(*in_avals,
fun: lu.WrappedFun,
def mesh_tiled_callable(fun: lu.WrappedFun,
transformed_name: str,
backend_name: Optional[str],
mesh: Mesh,
in_axes: Sequence[AxisNameMap],
out_axes_thunk: Callable[[], Sequence[AxisNameMap]]):
out_axes: Sequence[AxisNameMap],
*in_avals):
assert config.omnistaging_enabled
local_mesh = mesh.local_mesh
@ -1244,7 +1243,6 @@ def mesh_tiled_callable(*in_avals,
for aval_in_axes, aval in zip(in_axes, in_avals))
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, sharded_avals)
out_axes = out_axes_thunk()
assert len(out_axes) == len(out_sharded_avals)
# TODO(apaszke): What about outfeed?
# jaxpr = xla.apply_outfeed_rewriter(jaxpr)

View File

@ -31,7 +31,7 @@ from jax import vmap
from jax import lax
from jax.experimental.general_map import gmap, fake_resources, Mesh, mesh, xmap, A
from jax.lib import xla_bridge
from jax.util import curry
from jax.util import curry, unzip2
from jax.config import config
config.parse_flags_with_absl()
@ -86,6 +86,19 @@ def check_default_schedules(cond, fun):
{"testcase_name": "_" + name, "schedule": schedule}
for name, schedule in schedules)(fun)
@curry
def with_mesh(named_shape, f):
def new_f(*args, **kwargs):
axis_names, shape = unzip2(named_shape)
size = np.prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
return f(*args, **kwargs)
return new_f
class GmapTest(jtu.JaxTestCase):
@ -227,5 +240,22 @@ class GmapTest(jtu.JaxTestCase):
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)
@ignore_gmap_warning()
@with_mesh([('x', 2)])
def testXMapCompilationCache(self):
def f(x):
assert python_should_be_executing
return x * 2
fm = xmap(f,
in_axes=[A({'a': 0})],
out_axes=[A({'a': 0})],
schedule=[('a', 'x'), ('a', 'vectorize')])
x = np.arange(8).reshape((2, 2, 2))
python_should_be_executing = True
fm(x)
python_should_be_executing = False
fm(x)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())