Merge pull request #5527 from google:remove-soft-pmap

PiperOrigin-RevId: 354328027
jax authors 2021-01-28 09:29:57 -08:00
commit 543dcb37e3
8 changed files with 55 additions and 232 deletions


@@ -42,7 +42,7 @@ pytype_library(
"*_test.py",
"**/*_test.py",
],
),
) + ["experimental/maps.py"], # until xmap is moved out of experimental
srcs_version = "PY3",
deps = [
"//third_party/py/jax/jaxlib:_pocketfft",


@@ -69,7 +69,6 @@ from .api import (
shapecheck,
ShapedArray,
ShapeDtypeStruct,
soft_pmap,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
@@ -86,6 +85,7 @@ from .api import (
xla, # TODO(phawkins): update users to avoid this.
xla_computation,
)
from .experimental.maps import soft_pmap
from .version import __version__
# These submodules are separate because they are in an import cycle with
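
As a quick sanity check of the re-export above, the top-level name and the experimental module now resolve to the same xmap-backed function. A minimal sketch against this revision:

    import jax
    from jax.experimental.maps import soft_pmap

    # jax/__init__.py re-exports the experimental implementation, so the two
    # names refer to the same function object.
    assert jax.soft_pmap is soft_pmap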


@@ -484,15 +484,6 @@ class XeinsumSpecParser:
### parallel primitives
def _allreduce_soft_pmap_rule(prim, reducer, vals, mapped, chunk_size,
*, axis_name, axis_index_groups):
if axis_index_groups is not None:
raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
reduced_vals = [reducer(x, [0]) if m else x for x, m in zip(vals, mapped)]
outs = prim.bind(*reduced_vals, axis_name=axis_name,
axis_index_groups=axis_index_groups)
return outs, (False,) * len(vals)
# This is only used for collectives that do not include the vmapped axis name,
# which is why the rule is so simple.
def _collective_batcher(prim, args, dims, **params):
@@ -593,8 +584,6 @@ def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
@@ -957,14 +946,8 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
assert not vals and not mapped
idx = axis_index(axis_name) # type: ignore
return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True
axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)
@@ -991,6 +974,7 @@ def _axis_index_bind(*, axis_name):
axis_index_p.def_custom_bind(_axis_index_bind)
def _process_axis_index(self, frame):
assert frame.size is not None
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
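
For reference, the deleted _axis_index_soft_pmap_rule computed per-device index chunks with the arithmetic below. A standalone NumPy illustration of that formula, with hypothetical sizes:

    import numpy as np

    num_devices, chunk_size = 2, 3   # axis_size = 6 split across 2 devices
    for device_idx in range(num_devices):
        # Each device sees a contiguous chunk of the mapped-axis indices.
        local_idx = device_idx * chunk_size + np.arange(chunk_size, dtype=np.int32)
        print(device_idx, local_idx)  # 0 -> [0 1 2], 1 -> [3 4 5]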


@@ -1577,37 +1577,6 @@ def pmap(
return f_pmapped
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
) -> Callable:
if not config.omnistaging_enabled:
raise NotImplementedError("soft_pmap requires omnistaging.")
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
if any(axis != 0 for axis in tree_leaves(in_axes)):
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
@wraps(fun)
@api_boundary
def f_pmapped(*args, **kwargs):
f = lu.wrap_init(fun)
args_flat, in_tree = tree_flatten((args, kwargs))
in_axes_flat = flatten_axes("soft_pmap in_axes", in_tree, (in_axes, 0))
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
for arg in args_flat: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
# See note about out_axes_thunk in pmap for the explanation of why we choose this key
out_axes_thunk = HashableFunction(
lambda: tuple(flatten_axes("soft_pmap out_axes", out_tree(), 0)),
closure=())
outs = pxla.soft_pmap(flat_fun, *args_flat, axis_name=axis_name,
axis_size=axis_size, in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk)
return tree_unflatten(out_tree(), outs)
return f_pmapped
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
_check_callable(fun)
unique_ids = masking.UniqueIds()
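
The removed wrapper above leaned on the flatten/unflatten round trip from jax.tree_util. A self-contained sketch of that pattern, using only public helpers:

    from jax.tree_util import tree_flatten, tree_unflatten

    args = ({'a': 1, 'b': 2}, [3, 4])
    leaves, treedef = tree_flatten(args)       # leaves == [1, 2, 3, 4]
    rebuilt = tree_unflatten(treedef, leaves)  # reconstructs the original pytree
    assert rebuilt == args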


@@ -26,7 +26,8 @@ from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from ..api import _check_callable, _check_arg
from ..tree_util import tree_flatten, tree_unflatten, all_leaves
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves,
_replace_nones, tree_map, tree_leaves)
from ..api_util import flatten_fun_nokwargs, flatten_axes
from ..interpreters import partial_eval as pe
from ..interpreters import pxla
@@ -839,3 +840,27 @@ def subst_eqn_axis_names(eqn, axis_subst: Dict[AxisName, Tuple[AxisName]]):
axis_names = (axis_names,)
new_axis_names = sum((axis_subst.get(name, (name,)) for name in axis_names), ())
return eqn._replace(params=dict(eqn.params, axis_name=new_axis_names))
# -------- soft_pmap --------
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
) -> Callable:
warn("soft_pmap is an experimental feature and probably has bugs!")
_check_callable(fun)
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
if any(axis != 0 for axis in tree_leaves(in_axes)):
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
proxy = object()
in_axes = _replace_nones(proxy, in_axes)
in_axes = tree_map(lambda i: {i: axis_name} if i is not proxy else {}, in_axes)
@wraps(fun)
def f_pmapped(*args, **kwargs):
mesh_devices = np.array(xb.local_devices())
with mesh(mesh_devices, ['devices']):
return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
axis_resources={axis_name: 'devices'})(*args, **kwargs)
return f_pmapped
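
A minimal usage sketch of the new xmap-backed soft_pmap, mirroring the batch-matmul tests further down; it assumes omnistaging is enabled and that the mapped leading axis size is a multiple of the local device count:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import soft_pmap

    n = 4 * jax.local_device_count()          # mapped axis larger than device count
    xs = np.arange(n * 2 * 3.).reshape(n, 2, 3)
    ys = np.arange(n * 3 * 4.).reshape(n, 3, 4)
    zs = soft_pmap(jnp.dot)(xs, ys)           # maps jnp.dot over the leading axis
    assert zs.shape == (n, 2, 4)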


@@ -13,7 +13,7 @@
# limitations under the License.
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence
import jax
from ..config import config
@@ -28,16 +28,24 @@ from . import partial_eval as pe
map = safe_map
def batch(fun: lu.WrappedFun, axis_name, axis_size, in_dims, out_dim_dests,
BatchDim = Optional[int]
BatchDims = Sequence[BatchDim]
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
) -> lu.WrappedFun:
# analogue of `jvp` in ad.py
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
fun, out_dims_thunk = batch_subtrace(fun)
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
axis_size, out_dims_thunk, out_dim_dests)
axis_size, in_dims, out_dims_thunk, out_dim_dests)
@lu.transformation
def batchfun(axis_name, axis_size, in_dims, *in_vals):
# analogue of `jvpfun` in ad.py
if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
in_dims = in_dims() if callable(in_dims) else in_dims
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
and not isinstance(core.get_aval(x), core.AbstractUnit) # non-omnistaging
@@ -60,7 +68,9 @@ def batch_subtrace(main, in_dims, *in_vals):
yield out_vals, out_dims
@lu.transformation
def _match_axes(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
if axis_size is None:
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
out_vals = yield in_vals, {}
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_dims = out_dims_thunk()
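
The axis-size inference that _match_axes now performs when axis_size is None can be pictured with this standalone mock-up (plain NumPy, not the jax internals themselves):

    import numpy as np

    not_mapped = None                    # stand-in for batching.not_mapped
    in_vals = [np.zeros((5, 3)), np.zeros((2, 5)), np.zeros(7)]
    in_dims = [0, 1, not_mapped]         # which axis of each value is batched

    sizes = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
    axis_size, = sizes                   # single-element unpack: all mapped sizes must agree
    print(axis_size)                     # 5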


@@ -45,14 +45,13 @@ from .. import core
from .. import linear_util as lu
from .. import lazy
from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray, Var, Literal
from ..core import ConcreteArray, ShapedArray
from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, taggedtuple, curry)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..tree_util import tree_flatten, tree_map
from .batching import broadcast, not_mapped, moveaxis
from . import batching
from . import partial_eval as pe
from . import xla
@@ -1563,7 +1562,6 @@ def vtile(f_flat,
_forbidden_primitives = {
'xla_pmap': 'pmap',
'soft_pmap': 'soft_pmap',
'sharded_call': 'sharded_jit',
}
def _sanitize_mesh_jaxpr(jaxpr):
@@ -1597,166 +1595,6 @@ def mesh_sharding_specs(axis_sizes, axis_names):
return ShardingSpec(sharding, mesh_mapping)
return mk_sharding_spec
# ------------------- soft_pmap -------------------
def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun = _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk,
*abstract_args)
return compiled_fun(*args)
@lu.cache
def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):
mapped_avals = [core.mapped_aval(axis_size, in_axis, aval) if in_axis is not None else aval
for in_axis, aval in safe_zip(in_axes, avals)]
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
out_axes = out_axes_thunk()
assert all(out_axis == 0 for out_axis in out_axes)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
num_devices = xb.local_device_count()
chunk_size, ragged = divmod(axis_size, num_devices)
if ragged:
msg = f"number of devices {num_devices} must divide axis size {axis_size}"
raise NotImplementedError(msg)
jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes,
axis_name, axis_size, chunk_size)
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
if jaxpr_replicas != 1: raise NotImplementedError
tuple_args = len(avals) > 100 # pass long arg lists as tuple for TPU
c = xb.make_computation_builder("soft_pmap_{}".format(fun.__name__))
xla_consts = map(partial(xb.constant, c), consts)
chunked_avals = [core.unmapped_aval(chunk_size, in_axis, aval) if in_axis is not None else aval
for in_axis, aval in safe_zip(in_axes, mapped_avals)]
xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args)
axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
'soft_pmap', *xla_args)
built = c.Build(xops.Tuple(c, out_nodes))
compile_options = xb.get_compile_options(
num_replicas=num_devices, num_partitions=1, device_assignment=None)
compile_options.tuple_arguments = tuple_args
backend = xb.get_backend(None)
compiled = xla.backend_compile(backend, built, compile_options)
input_specs = [
ShardingSpec(
sharding=tuple_insert((_UNSHARDED_INSTANCE,) *
(aval.ndim - 1), in_axis, Chunked(num_devices)),
mesh_mapping=[ShardedAxis(0)])
if in_axis is not None else ShardingSpec(
sharding=[_UNSHARDED_INSTANCE] * aval.ndim,
mesh_mapping=[Replicated(num_devices)])
for aval, in_axis in safe_zip(avals, in_axes)
]
input_indices = [spec and spec_to_indices(aval.shape, spec)
for aval, spec in safe_zip(avals, input_specs)]
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
handle_outs = soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals)
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size):
assert all(in_axis is None or in_axis == 0 for in_axis in in_axes), in_axes
mapped_invars = [in_axis is not None for in_axis in in_axes]
fun = partial(_soft_pmap_interp, chunk_size, jaxpr, consts, mapped_invars)
in_avals = [core.unmapped_aval(chunk_size, in_axis, v.aval) if in_axis is not None else v.aval
for v, in_axis in safe_zip(jaxpr.invars, in_axes)]
with core.extend_axis_env(axis_name, axis_size, None):
return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args):
env: Dict[Var, Tuple[Any, bool]] = {}
def read(atom: Union[Var, Literal]) -> Tuple[Any, bool]:
if isinstance(atom, Literal):
return (atom.val, False)
else:
return env[atom]
def write(v: Var, val: Any, mapped: bool) -> None:
env[v] = (val, mapped)
write(core.unitvar, core.unit, False)
map(write, jaxpr.constvars, consts, (False,) * len(consts))
map(write, jaxpr.invars, args, mapped_invars)
for eqn in jaxpr.eqns:
in_vals, in_mapped = unzip2(map(read, eqn.invars))
if eqn.primitive in xla.parallel_translations:
rule = soft_pmap_rules[eqn.primitive]
out_vals, out_mapped = rule(in_vals, in_mapped, chunk_size, **eqn.params)
if not eqn.primitive.multiple_results:
out_vals, out_mapped = [out_vals], [out_mapped]
elif isinstance(eqn.primitive, core.CallPrimitive):
# we just inline here for convenience
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
out_vals = _soft_pmap_interp(chunk_size, call_jaxpr, (), in_mapped, *in_vals)
out_mapped = [True] * len(out_vals)
elif isinstance(eqn.primitive, core.MapPrimitive):
raise NotImplementedError # TODO
else:
if any(in_mapped):
rule = batching.get_primitive_batcher(eqn.primitive, None)
in_axes = [0 if m else batching.not_mapped for m in in_mapped]
out_vals, out_axes = rule(in_vals, in_axes, **eqn.params)
if not eqn.primitive.multiple_results:
out_vals, out_axes = [out_vals], [out_axes]
out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
for x, d in safe_zip(out_vals, out_axes)]
out_mapped = [d is not not_mapped for d in out_axes]
else:
out_vals = eqn.primitive.bind(*in_vals, **eqn.params)
if not eqn.primitive.multiple_results:
out_vals = [out_vals]
out_mapped = [False for _ in out_vals]
map(write, eqn.outvars, out_vals, out_mapped)
out_vals, out_mapped = unzip2(map(read, jaxpr.outvars))
out_vals = [out if mapped else broadcast(out, chunk_size, 0)
for out, mapped in safe_zip(out_vals, out_mapped)]
return out_vals
# TODO(mattjj): dedup w/ with other aval_to_result_handler via ShardingSpec
def soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals):
nouts = len(out_avals)
handlers = [soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval)
for aval in out_avals]
def handler(out_bufs):
buffers = [[result_to_populate] * num_devices for _ in range(nouts)]
for r, tuple_buf in enumerate(out_bufs):
for i, buf in enumerate(tuple_buf):
buffers[i][r] = buf
assert not any(buf is result_to_populate for bufs in buffers
for buf in bufs)
return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
return handler
def soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval):
axis_size = chunk_size * num_devices
if aval is core.abstract_unit:
return lambda _: core.unit
elif isinstance(aval, core.ShapedArray):
new_aval = aval.update(shape=(axis_size,) + aval.shape)
spec = ShardingSpec(
sharding=(Chunked(num_devices),) + (_UNSHARDED_INSTANCE,) * aval.ndim,
mesh_mapping=(ShardedAxis(0),))
return lambda bufs: ShardedDeviceArray(new_aval, spec, bufs)
else:
raise TypeError(aval)
soft_pmap_p = core.MapPrimitive('soft_pmap')
soft_pmap = soft_pmap_p.bind
soft_pmap_p.def_impl(soft_pmap_impl)
soft_pmap_rules: Dict[core.Primitive, Callable] = {}
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
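
The deleted _soft_pmap_callable divided the mapped axis evenly across local devices; here is the chunking arithmetic it relied on, as a plain-Python illustration with hypothetical numbers:

    num_devices = 4                       # e.g. xb.local_device_count()
    axis_size = 12                        # size of the soft-mapped axis
    chunk_size, ragged = divmod(axis_size, num_devices)
    if ragged:
        raise NotImplementedError(
            f"number of devices {num_devices} must divide axis size {axis_size}")
    print(chunk_size)                     # 3: each device handles a chunk of 3 elements at once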


@@ -35,8 +35,8 @@ from jax import tree_util
from jax import lax
from jax import random
from jax.core import ShapedArray
from jax.api import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax.lib import xla_bridge
from jax._src.util import prod, safe_map
from jax.interpreters import pxla
@@ -100,9 +100,6 @@ def tearDownModule():
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
ignore_soft_pmap_warning = partial(
jtu.ignore_warning, message="soft_pmap is an experimental.*")
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning, message=".*jit-of-pmap.*")
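
The ignore_xmap_warning decorator used below is not defined in this diff; presumably it lives alongside the filter above and follows the same pattern, e.g. (hypothetical message text):

    ignore_xmap_warning = partial(
        jtu.ignore_warning, message="xmap is an experimental.*")
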
@@ -1239,7 +1236,7 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(r, arr + 1)
self.assertEqual(len(r.device_buffers), 6)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1249,7 +1246,7 @@ class PmapTest(jtu.JaxTestCase):
expected = np.einsum('nij,njk->nik', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1259,7 +1256,7 @@ class PmapTest(jtu.JaxTestCase):
expected = np.einsum('nij,njk->nik', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1269,7 +1266,7 @@ class PmapTest(jtu.JaxTestCase):
expected = n * np.ones(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapPsum(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1279,7 +1276,7 @@ class PmapTest(jtu.JaxTestCase):
expected = np.ones(n) / n
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1289,7 +1286,7 @@ class PmapTest(jtu.JaxTestCase):
expected = 2 * np.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
@@ -1299,7 +1296,7 @@ class PmapTest(jtu.JaxTestCase):
expected = 3 * np.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapNested(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
@@ -1314,7 +1311,7 @@ class PmapTest(jtu.JaxTestCase):
expected = np.arange(n ** 2).reshape(n, n).T
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testGradOfSoftPmap(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
@@ -1327,7 +1324,7 @@ class PmapTest(jtu.JaxTestCase):
expected = np.repeat(np.arange(n)[:, None], n, axis=1)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_soft_pmap_warning()
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
device_count = xla_bridge.device_count()