Make xmap into a primitive
Add a compilation cache, make sure that a clear error is raised when xmap is used under other transforms, and add checks that validate its arguments.
This commit is contained in:
parent
ad6ceeeaf5
commit
c5433b0e4e
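For context, here is a minimal usage sketch of the xmap API that this commit turns into a primitive, adapted from the testXMapCompilationCache test added at the bottom of this diff. The function and array are illustrative, and running it assumes at least two local devices for the 'x' mesh axis:

import numpy as np
import jax
from jax.experimental.general_map import xmap, mesh, A

def double(x):
  return x * 2

# Arrange two local devices along a physical mesh axis named 'x'.
devices = np.array(jax.local_devices()[:2])
with mesh(devices, ('x',)):
  # The logical axis 'a' (dimension 0 of the argument) is split across the
  # mesh axis 'x'; the per-device remainder of 'a' is handled by 'vectorize'.
  fm = xmap(double,
            in_axes=[A({'a': 0})],
            out_axes=[A({'a': 0})],
            schedule=[('a', 'x'), ('a', 'vectorize')])
  y = fm(np.arange(8).reshape((2, 2, 2)))  # same values as double(x), computed across the mesh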
@@ -16,6 +16,7 @@ import enum
import threading
import contextlib
import numpy as np
import itertools as it
from collections import namedtuple, OrderedDict
from typing import Callable, Iterable, List, Tuple, Optional, Dict, Any, Set
from warnings import warn
@@ -26,29 +27,56 @@ from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from ..api import _mapped_axis_size, _check_callable, _check_arg
from ..tree_util import tree_flatten, tree_unflatten
from ..api_util import flatten_fun
from ..tree_util import tree_flatten, tree_unflatten, tree_leaves
from ..api_util import flatten_fun, flatten_fun_nokwargs, flatten_axes
from ..interpreters import partial_eval as pe
from ..interpreters import batching
from ..interpreters import pxla
from ..util import safe_map, safe_zip, curry

map = safe_map
map, unsafe_map = safe_map, map
zip = safe_zip

class FrozenDict: # dataclasses might remove some boilerplate here
def __init__(self, *args, **kwargs):
self.contents = dict(*args, **kwargs)

allowed_methods = {'items', 'values', 'keys', 'get'}
def __getattr__(self, name):
if name in self.allowed_methods:
return getattr(self.contents, name)
raise AttributeError(name)

def __iter__(self):
return self.contents.__iter__()

def __len__(self):
return self.contents.__len__()

def __getitem__(self, name):
return self.contents.__getitem__(name)

def __eq__(self, other):
return isinstance(other, FrozenDict) and self.contents == other.contents

def __hash__(self):
return hash(tuple(self.contents.items()))

# Multi-dimensional generalized map

# TODO: Use a more concrete type annotation (we need __eq__ and __hash__)
AxisName = Any
ResourceAxisName = Any

AxisName = core.AxisName
ResourceAxisName = AxisName # Different name just for documentation purposes
Mesh = pxla.Mesh

# TODO: Support sequential mapping
class ResourceEnv(threading.local):
def __init__(self):
self.physical_mesh : Mesh = Mesh(np.empty((), dtype=object), ())
self.fake_resources : Dict[ResourceAxisName, int] = {}
class ResourceEnv:
__slots__ = ('physical_mesh', 'fake_resources')
physical_mesh: Mesh
fake_resources: FrozenDict

def __init__(self, physical_mesh: Mesh, fake_resources: FrozenDict):
super().__setattr__('physical_mesh', physical_mesh)
super().__setattr__('fake_resources', fake_resources)

@property
def physical_resource_axes(self) -> Set[ResourceAxisName]:
@@ -68,26 +96,40 @@ class ResourceEnv(threading.local):
shape.update((name, size) for name, size in self.fake_resources.items())
return shape

# TODO: Make this thread local
thread_resource_env = ResourceEnv()
def __setattr__(self, name, value):
raise RuntimeError("ResourceEnv is immutable!")

def __delattr__(self, name):
raise RuntimeError("ResourceEnv is immutable!")

def __eq__(self, other):
return (type(other) is ResourceEnv and
self.physical_mesh == other.physical_mesh and
self.fake_resources == other.fake_resources)

def __hash__(self):
return hash((self.physical_mesh, self.fake_resources))

thread_resources = threading.local()
thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), FrozenDict())

@contextlib.contextmanager
def fake_resources(**axes):
old_axes = thread_resource_env.fake_resources
thread_resource_env.fake_resources = axes
old_env = thread_resources.env
thread_resources.env = ResourceEnv(old_env.physical_mesh, FrozenDict(axes))
try:
yield
finally:
thread_resource_env.axes = old_axes
thread_resources.env = old_env

@contextlib.contextmanager
def mesh(*args, **kwargs):
old = thread_resource_env.physical_mesh
thread_resource_env.physical_mesh = Mesh(*args, **kwargs)
old_env = thread_resources.env
thread_resources.env = ResourceEnv(Mesh(*args, **kwargs), old_env.fake_resources)
try:
yield
finally:
thread_resource_env.physical_mesh = old
thread_resources.env = old_env

_next_resource_id = 0
class UniqueResourceName:
@@ -103,118 +145,200 @@ def fresh_resource_name():

# This is really a Dict[AxisName, int], but we don't define a
# pytree instance for it, so that it is treated as a leaf.
class AxisNamePos(dict):
class AxisNamePos(FrozenDict):
pass

A = AxisNamePos


# TODO: Some syntactic sugar to make the API more usable in a single-axis case?
# TODO: Are the resource axes scoped lexically or dynamically? Dynamically for now!
def xmap(fun: Callable,
in_axes, # PyTree[AxisNamePos]
out_axes, # PyTree[AxisNamePos],
schedule: Iterable[Tuple[AxisName, ResourceAxisName]]):
schedule: Iterable[Tuple[AxisName, ResourceAxisName]],
backend: Optional[str] = None):
warn("xmap is an experimental feature and probably has bugs!")
_check_callable(fun)

def fun_mapped(*args, **kwargs):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resource_env
frozen_schedule = tuple(tuple(x) for x in schedule)

args_flat, in_tree = tree_flatten((args, kwargs))
if isinstance(in_axes, list):
# To be a tree prefix of the positional args tuple, in_axes can never be a
# list: if in_axes is not a leaf, it must be a tuple of trees. However,
# in cases like these users expect tuples and lists to be treated
# essentially interchangeably, so we canonicalize lists to tuples here
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axes = tuple(in_axes)

in_axes_entries = tree_leaves(in_axes)
out_axes_entries = tree_leaves(out_axes)
# Check that {in|out}_axes have the right types, and don't use the same positional axis twice
if not all(isinstance(x, A) for x in in_axes_entries):
raise TypeError(f"xmap in_axes must be AxisNamePos (A) instances or (nested) "
f"containers with those types as leaves, but got {in_axes}")
if not all(isinstance(x, A) for x in out_axes_entries):
raise TypeError(f"xmap out_axes must be AxisNamePos (A) instances or (nested) "
|
||||
f"containers with those types as leaves, but got {in_axes}")
|
||||
for x in in_axes_entries:
|
||||
if len(set(x.values())) != len(x):
|
||||
raise ValueError(f"Positional dimension indices should be unique within each "
|
||||
f"in_axes dictionary, but one of the entries is: {x}")
|
||||
for x in out_axes_entries:
|
||||
if len(set(x.values())) != len(x):
|
||||
raise ValueError(f"Positional dimension indices should be unique within each "
|
||||
f"in_axes dictionary, but one of the entries is: {x}")
|
||||
|
||||
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
|
||||
scheduled_axes = set(x[0] for x in frozen_schedule)
|
||||
if scheduled_axes != in_axes_names:
|
||||
raise ValueError("The set of axes names appearing in in_axes has to equal the "
|
||||
"set of scheduled axes, but {in_axes_names} != {scheduled_axes}")
|
||||
|
||||
necessary_resources = set(x[1] for x in frozen_schedule if x[1] != 'vectorize')
if len(set(frozen_schedule)) != len(frozen_schedule):
raise ValueError(f"xmap schedule contains duplicate entries: {frozen_schedule}")

def fun_mapped(*args):
# Putting this outside of fun_mapped would make resources lexically scoped
resource_env = thread_resources.env
available_resources = set(resource_env.shape.keys())

if necessary_resources > available_resources:
raise ValueError(f"In-scope resources are insufficient to execute the "
f"xmapped function. The missing resources are: "
f"{necessary_resources - available_resources}")

args_flat, in_tree = tree_flatten(args)
for arg in args_flat: _check_arg(arg)
fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
# TODO: Check that:
# - every scheduled axis name appears in at least one input
# - every used resource axis name appears in the resource env
# - every axis name is scheduled to a single resource axis only once
# - every out axis has a distinct index
# - two axes mapped to the same resource never coincide (even inside f)
in_axes_flat, in_axes_tree = tree_flatten(in_axes)
# TODO: Verify that in_axes are equal, or better expand their prefix
# assert in_axes_tree == in_tree
in_axes_flat = flatten_axes("xmap in_axes", in_tree, in_axes)
out_axes_flat, out_axes_tree = tree_flatten(out_axes)
# TODO: Verify that out_axes are equal, or better expand their prefix
# assert out_axes_tree == out_tree

axis_sizes = _get_axis_sizes(args_flat, in_axes_flat)
jaxpr, out_tree = _trace_mapped_jaxpr(fun, args_flat, in_axes_flat, axis_sizes, in_tree)

# TODO: The order of maps should be derived from the schedule, not from the
# resource env. This doesn't really matter for as long as we only support
# vectorization and parallelization, but will be important for sequential.
# We should be able to do that by building a graph of dependencies between
# resources based on the order in which they appear within each axis.
# If it has cycles then we cannot realize it. Otherwise, if the DAG doesn't
# uniquely identify a linear order, we should use the order of entries in
# the schedule to break ties.
# Technically the order doesn't matter right now, but we use the ordered dict
# to at least limit the amount of non-determinism in this code.
fake_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
physical_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
vectorized: Dict[AxisName, ResourceAxisName] = OrderedDict()

axis_subst: Dict[AxisName, List[ResourceAxisName]] = {}
for axis, resource in schedule:
if resource == 'vectorize':
assert axis not in vectorized
resource = fresh_resource_name()
vectorized[axis] = resource
elif resource in resource_env.physical_resource_axes:
# TODO: Make sure that axis was not in the set?
physical_resource_map.setdefault(resource, set()).add(axis)
elif resource in resource_env.fake_resource_axes:
# TODO: Make sure that axis was not in the set?
fake_resource_map.setdefault(resource, set()).add(axis)
else:
raise ValueError(f"Mapping axis {axis} to an undefined resource axis {resource}. "
f"The resource axes currently in scope are: {resource_env.resource_axes}")
axis_subst.setdefault(axis, []).append(resource)

axis_subst_t = {axis: tuple(resources) for axis, resources in axis_subst.items()}
jaxpr = jaxpr.map_jaxpr(partial(subst_axis_names, axis_subst=axis_subst_t))

f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f = hide_mapped_axes(f, in_axes_flat, out_axes_flat)

for naxis, raxis in vectorized.items():
map_in_axes = map(lambda spec: spec.get(naxis, None), in_axes_flat)
map_out_axes = map(lambda spec: spec.get(naxis, None), out_axes_flat)
f = vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=raxis)

resource_env_shape = resource_env.shape
for raxis, naxes in fake_resource_map.items():
map_in_axes = map(lambda spec: lookup_exactly_one_of(spec, naxes), in_axes_flat)
map_out_axes = map(lambda spec: lookup_exactly_one_of(spec, naxes), out_axes_flat)
map_size = resource_env_shape[raxis]
f = vtile(f, map_in_axes, map_out_axes, tile_size=map_size, axis_name=raxis)

if physical_resource_map:
submesh = resource_env.physical_mesh[tuple(physical_resource_map.keys())]
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
def to_mesh_axes(axes):
mesh_axes = {}
for paxis, naxes in physical_resource_map.items():
axis = lookup_exactly_one_of(axes, naxes)
if axis is None:
continue
mesh_axes[paxis] = axis
return mesh_axes
mesh_in_axes = map(to_mesh_axes, in_axes)
mesh_out_axes = map(to_mesh_axes, out_axes)
f = pxla.mesh_tiled_callable(*in_avals, fun=f,
transformed_name=f.__name__,
backend_name=None,
mesh=submesh,
in_axes=mesh_in_axes,
out_axes_thunk=lambda: mesh_out_axes)
flat_out = f(*args_flat)
else:
flat_out = f.call_wrapped(*args_flat)

return tree_unflatten(out_tree, flat_out)
out_flat = xmap_p.bind(
fun_flat, *args_flat,
in_axes=tuple(in_axes_flat),
out_axes=tuple(out_axes_flat),
axis_sizes=FrozenDict(axis_sizes),
schedule=frozen_schedule,
resource_env=resource_env,
backend=backend)
return tree_unflatten(out_tree(), out_flat)

return fun_mapped

def xmap_impl(fun: lu.WrappedFun, *args, in_axes, out_axes, axis_sizes, schedule, resource_env, backend):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
return make_xmap_callable(fun, in_axes, out_axes, axis_sizes, schedule,
resource_env, backend, *in_avals)(*args)

@lu.cache
def make_xmap_callable(fun: lu.WrappedFun,
in_axes, out_axes, axis_sizes,
schedule, resource_env, backend,
*in_avals):
mapped_in_avals = [_delete_aval_axes(aval, in_axes)
for aval, in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(axis_sizes.items()):
raw_jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
jaxpr = core.ClosedJaxpr(raw_jaxpr, consts)

# TODO: The order of maps should be derived from the schedule, not from the
# resource env. This doesn't really matter for as long as we only support
# vectorization and parallelization, but will be important for sequential.
# We should be able to do that by building a graph of dependencies between
# resources based on the order in which they appear within each axis.
# If it has cycles then we cannot realize it. Otherwise, if the DAG doesn't
# uniquely identify a linear order, we should use the order of entries in
# the schedule to break ties.
# Technically the order doesn't matter right now, but we use the ordered dict
# to at least limit the amount of non-determinism in this code.
fake_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
physical_resource_map: Dict[ResourceAxisName, Set[AxisName]] = OrderedDict()
vectorized: Dict[AxisName, ResourceAxisName] = OrderedDict()

axis_subst: Dict[AxisName, List[ResourceAxisName]] = {}
for axis, resource in schedule:
if resource == 'vectorize':
assert axis not in vectorized
resource = fresh_resource_name()
vectorized[axis] = resource
elif resource in resource_env.physical_resource_axes:
physical_resource_map.setdefault(resource, set()).add(axis)
elif resource in resource_env.fake_resource_axes:
fake_resource_map.setdefault(resource, set()).add(axis)
else:
raise ValueError(f"Mapping axis {axis} to an undefined resource axis {resource}. "
f"The resource axes currently in scope are: {resource_env.resource_axes}")
axis_subst.setdefault(axis, []).append(resource)

axis_subst_t = {axis: tuple(resources) for axis, resources in axis_subst.items()}
jaxpr = jaxpr.map_jaxpr(partial(subst_axis_names, axis_subst=axis_subst_t))

f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f = hide_mapped_axes(f, tuple(in_axes), tuple(out_axes))

for naxis, raxis in vectorized.items():
map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
f = vtile(f, map_in_axes, map_out_axes, tile_size=None, axis_name=raxis)

resource_env_shape = resource_env.shape
for raxis, naxes in fake_resource_map.items():
map_in_axes = tuple(unsafe_map(lambda spec: lookup_exactly_one_of(spec, naxes), in_axes))
map_out_axes = tuple(unsafe_map(lambda spec: lookup_exactly_one_of(spec, naxes), out_axes))
map_size = resource_env_shape[raxis]
f = vtile(f, map_in_axes, map_out_axes, tile_size=map_size, axis_name=raxis)

if physical_resource_map:
submesh = resource_env.physical_mesh[tuple(physical_resource_map.keys())]
def to_mesh_axes(axes):
mesh_axes = {}
for paxis, naxes in physical_resource_map.items():
axis = lookup_exactly_one_of(axes, naxes)
if axis is None:
continue
mesh_axes[paxis] = axis
return A(mesh_axes)
mesh_in_axes = tuple(unsafe_map(to_mesh_axes, in_axes))
mesh_out_axes = tuple(unsafe_map(to_mesh_axes, out_axes))
return pxla.mesh_tiled_callable(f,
f.__name__,
backend,
submesh,
mesh_in_axes,
mesh_out_axes,
*in_avals)
else:
return f.call_wrapped

# xmap has a different set of parameters than pmap, so we make it its own primitive type
class XMapPrimitive(core.Primitive):
multiple_results = True

def __init__(self):
super().__init__('xmap')
self.def_impl(xmap_impl)
self.def_custom_bind(partial(core.call_bind, self))

def bind(self, fun, *args, **params):
assert len(params['in_axes']) == len(args)
return core.call_bind(self, fun, *args, **params)

def process(self, trace, fun, tracers, params):
return trace.process_xmap(self, fun, tracers, params)

def post_process(self, trace, out_tracers, params):
raise NotImplementedError
return trace.post_process_xmap(self, out_tracers, params)

xmap_p = XMapPrimitive()
core.EvalTrace.process_xmap = core.EvalTrace.process_call # type: ignore

def _delete_aval_axes(aval, axes: AxisNamePos):
assert isinstance(aval, core.ShapedArray)
shape = list(aval.shape)
@@ -222,19 +346,7 @@ def _delete_aval_axes(aval, axes: AxisNamePos):
del shape[i]
return core.ShapedArray(tuple(shape), aval.dtype)

def _trace_mapped_jaxpr(fun,
args_flat,
in_axes_flat: List[AxisNamePos],
axis_sizes: Dict[AxisName, int],
in_tree):
fun_flat, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
mapped_avals = [_delete_aval_axes(aval, in_axes)
for aval, in_axes in zip(avals_flat, in_axes_flat)]
with core.extend_axis_env_nd(axis_sizes.items()):
jaxpr, _, consts = pe.trace_to_jaxpr_final(fun_flat, mapped_avals)
return core.ClosedJaxpr(jaxpr, consts), out_tree()

# TODO: pmap has some very fancy error messages for this function!
def _get_axis_sizes(args_flat: Iterable[Any], in_axes_flat: Iterable[AxisNamePos]):
axis_sizes: Dict[AxisName, int] = {}
for arg, in_axes in zip(args_flat, in_axes_flat):
@@ -286,7 +398,10 @@ def untile_axis(out, axis: Optional[int]):
return out.reshape(shape)

# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat, in_axes_flat, out_axes_flat, tile_size: Optional[int], axis_name):
def vtile(f_flat,
in_axes_flat: Tuple[AxisNamePos, ...],
out_axes_flat: Tuple[AxisNamePos, ...],
tile_size: Optional[int], axis_name):
@lu.transformation
def _map_to_tile(*args_flat):
real_tile_size = tile_size
@@ -1223,14 +1223,13 @@ def untile_aval_nd(axis_sizes, out_axes: AxisNameMap, aval):
shape[axis] *= axis_sizes[name]
return ShapedArray(tuple(shape), aval.dtype)

# TODO(apaszke): Cache compilation
def mesh_tiled_callable(*in_avals,
fun: lu.WrappedFun,
def mesh_tiled_callable(fun: lu.WrappedFun,
transformed_name: str,
backend_name: Optional[str],
mesh: Mesh,
in_axes: Sequence[AxisNameMap],
out_axes_thunk: Callable[[], Sequence[AxisNameMap]]):
out_axes: Sequence[AxisNameMap],
*in_avals):
assert config.omnistaging_enabled
local_mesh = mesh.local_mesh

@@ -1244,7 +1243,6 @@ def mesh_tiled_callable(*in_avals,
for aval_in_axes, aval in zip(in_axes, in_avals))
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, sharded_avals)
out_axes = out_axes_thunk()
assert len(out_axes) == len(out_sharded_avals)
# TODO(apaszke): What about outfeed?
# jaxpr = xla.apply_outfeed_rewriter(jaxpr)
@@ -31,7 +31,7 @@ from jax import vmap
from jax import lax
from jax.experimental.general_map import gmap, fake_resources, Mesh, mesh, xmap, A
from jax.lib import xla_bridge
from jax.util import curry
from jax.util import curry, unzip2

from jax.config import config
config.parse_flags_with_absl()
@@ -86,6 +86,19 @@ def check_default_schedules(cond, fun):
{"testcase_name": "_" + name, "schedule": schedule}
for name, schedule in schedules)(fun)

@curry
def with_mesh(named_shape, f):
def new_f(*args, **kwargs):
axis_names, shape = unzip2(named_shape)
size = np.prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
return f(*args, **kwargs)
return new_f


class GmapTest(jtu.JaxTestCase):

@@ -227,5 +240,22 @@ class GmapTest(jtu.JaxTestCase):
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)

@ignore_gmap_warning()
@with_mesh([('x', 2)])
def testXMapCompilationCache(self):
def f(x):
assert python_should_be_executing
return x * 2
fm = xmap(f,
in_axes=[A({'a': 0})],
out_axes=[A({'a': 0})],
schedule=[('a', 'x'), ('a', 'vectorize')])
x = np.arange(8).reshape((2, 2, 2))
python_should_be_executing = True
fm(x)
python_should_be_executing = False
fm(x)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
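A note on the new compilation cache: make_xmap_callable above is wrapped in @lu.cache, which memoizes on the transform parameters passed to xmap_p.bind, so those parameters are first converted to hashable values (the schedule becomes a tuple of tuples and axis_sizes becomes a FrozenDict). A minimal sketch of that idea, using functools.lru_cache as a stand-in for lu.cache and a stripped-down copy of the FrozenDict class defined in this diff (not an import from JAX):

import functools

class FrozenDict:
  # Stripped-down version of the FrozenDict above: an immutable,
  # hashable wrapper around a dict, usable as a cache key.
  def __init__(self, *args, **kwargs):
    self.contents = dict(*args, **kwargs)
  def __eq__(self, other):
    return isinstance(other, FrozenDict) and self.contents == other.contents
  def __hash__(self):
    return hash(tuple(self.contents.items()))

@functools.lru_cache()
def make_callable(schedule, axis_sizes):
  # Stand-in for make_xmap_callable: pretend tracing and compilation happen here.
  print("compiling for", schedule, axis_sizes.contents)
  return lambda x: x * 2

frozen_schedule = tuple(tuple(x) for x in [('a', 'x'), ('a', 'vectorize')])
f = make_callable(frozen_schedule, FrozenDict({'a': 2}))  # traces and compiles
g = make_callable(frozen_schedule, FrozenDict({'a': 2}))  # cache hit, no retracing
assert f is g
# Passing a plain dict or list here would fail with "unhashable type".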