Cleaning up eager pmap implementation

Co-authored-by: Matthew Johnson <mattjj@google.com>
Sharad Vikram 2022-08-11 10:49:56 -07:00
parent a7f760d9ed
commit fe040cc01e
5 changed files with 115 additions and 353 deletions


@@ -113,12 +113,12 @@ zip, unsafe_zip = safe_zip, zip
FLAGS = flags.FLAGS
flags.DEFINE_bool(
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", False),
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
"A flag enabling the C++ jax.jit fast path."
"Set this to `False` only if it crashes otherwise and report "
"the error to the jax-team.")
flags.DEFINE_bool(
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False),
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")


@@ -897,6 +897,13 @@ config.define_bool_state(
default=False,
help='Enable using a cache for lowering subjaxprs.')
# TODO(sharadmv,mattjj): set default to True, then remove
config.define_bool_state(
name='jax_eager_pmap',
default=False,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""

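Aside: a minimal sketch of how the new flag is exercised (standalone snippet, not part of the diff; assumes a stock jax install with at least one device):

from jax.config import config
import jax
import jax.numpy as jnp

config.update('jax_eager_pmap', True)   # opt in to the eager path defined above

with jax.disable_jit():
  # With both flags active, pmap runs eagerly instead of compiling.
  print(jax.pmap(lambda x: x * 2)(jnp.arange(jax.device_count())))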

@@ -907,7 +907,9 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
if config.jax_disable_jit:
if (config.jax_disable_jit and config.jax_eager_pmap and
global_axis_size is None and not any(d for d in donated_invars) and
not any(g is not None for g in global_arg_shapes)):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=name, in_axes=in_axes,
@@ -929,6 +931,10 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
("fingerprint", fingerprint))
return compiled_fun(*args)
class EmapInfo(NamedTuple):
backend: Optional[str]
devices: Optional[Sequence[Any]]
def _emap_impl(fun: lu.WrappedFun, *args,
backend: Optional[str],
axis_name: core.AxisName,
@@ -940,76 +946,64 @@ def _emap_impl(fun: lu.WrappedFun, *args,
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
if global_axis_size is not None: raise NotImplementedError
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size,
global_axis_size, devices, None, None, ())
# if devices is not None:
# if len(devices) == 0:
# raise ValueError("'devices' argument to pmap must be non-empty, or None.")
# if len(devices) != axis_size:
# raise ValueError(
# f"Leading axis size of input to pmapped function must equal the "
# f"number of local devices passed to pmap. Got axis_size="
# f"{axis_size}, num_local_devices={len(devices)}.")
# else:
# devices = xb.devices(backend=backend)[:axis_size]
# if len(devices) != axis_size:
# msg = ("compiling computation that requires {} logical devices, but only {} XLA "
# "devices are available (num_replicas={}, num_partitions={})")
# raise ValueError(msg.format(axis_size,
# xb.device_count(backend),
# None,
# None))
sharded_args = []
shard_axes = []
for arg, in_axis in zip(args, in_axes):
if in_axis == 0:
sharded_args.append(arg)
shard_axes.append({axis_name: 0})
elif in_axis is None:
sharded_args.append(arg)
shard_axes.append({})
else:
perm = list(range(arg.ndim))
a = perm.pop(in_axis)
perm.insert(0, a)
sharded_args.append(arg.transpose(perm))
shard_axes.append({axis_name: 0})
with core.new_base_main(MapTrace, pci=pci) as main:
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
if any(g is not None for g in global_arg_shapes):
raise NotImplementedError("Global arg shapes not supported in eager pmap.")
if global_axis_size is not None:
raise NotImplementedError("Non-default global_axis_size not supported in "
"eager pmap.")
emap_info = EmapInfo(backend, devices)
shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
with core.new_base_main(MapTrace, emap_info=emap_info) as main:
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
t = main.with_cur_sublevel()
tracers = [
MapTracer(t, arg, s) for arg, s in zip(sharded_args, shard_axes)]
MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
ans = fun.call_wrapped(*tracers)
out_tracers = map(t.full_raise, ans)
outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
del main
out_axes = out_axes_thunk()
# This next bit is like matchaxis in batching.py (for the end of a vmap)
new_outvals = []
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax._src.config.disable_jit(False):
with jax.disable_jit(False):
out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis)(np.arange(axis_size), outval)
out_axes=out_axis, devices=devices, backend=backend)(
np.arange(axis_size), outval)
new_outvals.append(out)
return new_outvals
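Aside (not part of the diff): the identity-pmap trick above only relocates the mapped axis to the requested out_axis. A rough standalone illustration, assuming axis_size equals the local device count:

import numpy as np
import jax
import jax.numpy as jnp

axis_size = jax.device_count()
outval = jnp.zeros((3, axis_size, 5))          # mapped axis currently at position 1
moved = jax.pmap(lambda _, x: x, in_axes=(0, 1), out_axes=0)(
    np.arange(axis_size), outval)
print(moved.shape)                             # (axis_size, 3, 5): mapped axis now leads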
def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
# the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
return [None if i is None else
i - sum(j is not None and j < i for j in idx[:l])
for l, i in enumerate(idx)]
def _map_indices_to_map_schedule(idx: Tuple[Optional[int], ...]):
return tuple(None if i is None else i - sum(j is not None and j < i for j in idx[:l]) for l, i in enumerate(idx))
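A quick worked check of the schedule described in the comment in _map_schedule above, with the computation inlined so the snippet stands alone:

# Mapping over flat input axes 0, 3 and 4: each earlier map removes one axis,
# so the successive in_axes become 0, 2, 2, as the comment says.
idx = (0, 3, 4)
schedule = [None if i is None else
            i - sum(j is not None and j < i for j in idx[:l])
            for l, i in enumerate(idx)]
assert schedule == [0, 2, 2]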
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
used_names = []
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
backend=info.backend, devices=info.devices)
used_names.append(name)
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
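Illustrative sketch (not part of the diff): when two axis names are both mapped, _multi_pmap builds the equivalent of a nested jax.pmap with the first name outermost; this assumes at least a 2x2 grid of local devices:

import jax
import jax.numpy as jnp

f = lambda x: x + 1
# names=['i', 'j']: the reversed loop wraps 'j' first (innermost) and 'i' last
# (outermost), so the result reports out_shard_axes == {'i': 0, 'j': 1}.
f_mapped = jax.pmap(jax.pmap(f, axis_name='j'), axis_name='i')
print(f_mapped(jnp.zeros((2, 2))).shape)   # (2, 2)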
class MapTrace(core.Trace):
def __init__(self, *args, pci):
def __init__(self, *args, emap_info):
super().__init__(*args)
self.pci = pci
def _get_frames(self):
frames = [f for f in core.thread_local_state.trace_state.axis_env
if f.main_trace is self.main]
return frames
self.emap_info = emap_info
def pure(self, val):
return MapTracer(self, val, {})
@@ -1018,24 +1012,15 @@ class MapTrace(core.Trace):
return MapTracer(self, tracer.val, tracer.shard_axes)
def process_primitive(self, primitive, tracers, params):
vals = [t.val for t in tracers]
names = [f.name for f in self._get_frames()]
f = lambda *args: primitive.bind(*args, **params)
used_names = []
all_axes = []
for t in tracers:
arg_axes = tuple(t.shard_axes.get(name, None) for name in names)
arg_axes = _map_indices_to_map_schedule(arg_axes)
all_axes.append(arg_axes)
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(f, in_axes=in_axes, axis_name=name,
devices=self.main.payload['pci'].devices)
used_names.append(name)
with core.eval_context(), jax._src.config.disable_jit(False):
outvals = f(*vals)
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
info = self.main.payload["emap_info"]
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
names = [f.name for f in core.thread_local_state.trace_state.axis_env
if f.main_trace is self.main]
all_axes = [_map_schedule(map(s.get, names)) for s in shard_axes]
f_mapped, out_shard_axes = _multi_pmap(partial(primitive.bind, **params),
info, names, all_axes)
with core.eval_context(), jax.disable_jit(False):
outvals = f_mapped(*vals)
if primitive.multiple_results:
return [MapTracer(self, val, out_shard_axes) for val in outvals]
return MapTracer(self, outvals, out_shard_axes)
@@ -1049,55 +1034,26 @@ class MapTrace(core.Trace):
def process_map(self, call_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if config.jax_disable_jit:
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
params["in_axes"], params["out_axes_thunk"], params["axis_size"])
invals = [t.val for t in tracers]
shard_axes_src = [t.shard_axes for t in tracers]
shard_axes = []
for inval, in_axis, shard_axis_src in zip(invals, in_axes, shard_axes_src):
new_shard_axis_src = dict(shard_axis_src)
if in_axis is not None:
idx = [i for i in range(inval.ndim) if i not in shard_axis_src.values()]
new_idx = idx[in_axis]
new_shard_axis_src = {axis_name: new_idx, **shard_axis_src}
shard_axes.append(new_shard_axis_src)
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
t = self.main.with_cur_sublevel()
in_tracers = [MapTracer(t, val, shard_axis) for val, shard_axis in
zip(invals, shard_axes)]
ans = fun.call_wrapped(*in_tracers)
out_axes_dest = out_axes_thunk()
out_tracers = map(t.full_raise, ans)
outvals, shard_axes_src = util.unzip2([(t.val, t.shard_axes) for t in
out_tracers])
new_out_tracers = map(
partial(self._make_output_tracer, axis_name, axis_size), outvals,
shard_axes_src, out_axes_dest)
return new_out_tracers
else:
if not config.jax_disable_jit:
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(call_primitive.bind, fun))
return self.process_primitive(fake_primitive, tracers, params)
def _make_output_tracer(self, axis_name, axis_size, val, shard_axis_src,
dst_annotation):
shard_axis_out = dict(shard_axis_src)
src = shard_axis_out.pop(axis_name, None)
dst = annotation_to_flat(np.ndim(val), shard_axis_out.values(),
src, dst_annotation)
with core.eval_context():
if src == dst:
outval = val
elif type(src) == type(dst) == int:
outval = batching.moveaxis(val, src, dst)
shard_axis_out = moveaxis(np.ndim(val), shard_axis_src, src, dst)
elif src is None and dst is not None:
outval = batching.broadcast(val, axis_size, dst)
shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
else:
assert False
return MapTracer(self, outval, shard_axis_out)
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
params["in_axes"], params["out_axes_thunk"], params["axis_size"])
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
if ax is not None else s
for v, ax, s in zip(vals, in_axes, shard_axes)]
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
t = self.main.with_cur_sublevel()
in_tracers = map(partial(MapTracer, t), vals, shard_axes)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(t.full_raise, ans)
out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
del t, in_tracers, ans, out_tracers
out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
for v, s, dst in zip(out, outaxes, out_axes_thunk()))
return map(partial(MapTracer, self), out, outaxes)
def process_axis_index(self, frame):
fake_primitive = types.SimpleNamespace(
@@ -1107,19 +1063,35 @@ class MapTrace(core.Trace):
dummy_tracer = MapTracer(self, range, {frame.name: 0})
return self.process_primitive(fake_primitive, (dummy_tracer,), {})
def annotation_to_flat(ndim: int, mapped_axes: Sequence[int],
src_flat: Optional[int], dst_annotation: Optional[int]
) -> Optional[int]:
if dst_annotation is None:
return None
ndim_ = ndim - len(mapped_axes) + (src_flat is None)
dst_annotation = batching.canonicalize_axis(dst_annotation, ndim_)
idx = [i for i in range(ndim + (src_flat is None)) if i not in mapped_axes]
out = idx[dst_annotation]
return out
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
annotation: Optional[int]) -> Optional[int]:
if annotation is None: return None
mapped_axes_ = set(mapped_axes)
return [i for i in range(ndim) if i not in mapped_axes_][annotation]
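For example: with ndim=3, mapped_axes={0} and annotation=1, the unmapped axes are [1, 2], so the flat position is 2. Inlining the same logic as a standalone check:

ndim, mapped_axes, annotation = 3, {0}, 1
unmapped = [i for i in range(ndim) if i not in mapped_axes]   # [1, 2]
assert unmapped[annotation] == 2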
def moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int):
def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
shard_axis_src: Dict[core.AxisName, int],
dst_annotation: Optional[int]
) -> Tuple[Any, Dict[core.AxisName, int]]:
shard_axis_out = dict(shard_axis_src)
src = shard_axis_out.pop(axis_name, None)
dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
dst_annotation)
with core.eval_context():
if src == dst:
outval = val
elif type(src) == type(dst) == int:
outval = batching.moveaxis(val, src, dst)
shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
elif src is None and dst is not None:
outval = batching.broadcast(val, axis_size, dst)
shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
else:
raise NotImplementedError
return outval, shard_axis_out
def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int) -> Dict[core.AxisName, int]:
lst: List[Optional[core.AxisName]] = [None] * ndim
for k, v in shard_axes.items():
lst[v] = k

matt.py (188 lines deleted)

@@ -1,188 +0,0 @@
import pdb, sys, traceback
def info(type, value, tb):
traceback.print_exception(type, value, tb)
pdb.pm()
sys.excepthook = info
import os
import functools
import re
import jax
import jax.numpy as jnp
from jax.config import config
import numpy as np
config.update("jax_new_checkpoint", True)
config.update("jax_traceback_filtering", "off")
config.update("jax_platform_name", "cpu")
def set_host_device_count(n):
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
).split()
os.environ["XLA_FLAGS"] = " ".join(
["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
)
set_host_device_count(8)
# @functools.partial(jax.pmap, axis_name="foo")
# def f(x):
# def cond(i):
# return i < 10
# def body(i):
# return i + 1
# return jax.lax.while_loop(cond, body, x)
# @functools.partial(jax.pmap, axis_name="foo")
# def f(x):
# def cond(i):
# return i < 10
# def body(i):
# return i + 1
# return jax.lax.while_loop(cond, body, x)
# print(f(jnp.arange(4.)))
# with jax.disable_jit():
# print(f(jnp.arange(4.)))
# @jax.pmap
# def g(x):
# with jax._src.config.disable_jit(False):
# return jax.jit(jnp.sin)(x)
# print(g(jnp.arange(4.)))
# with jax.disable_jit():
# print(g(jnp.arange(4.)))
# @functools.partial(jax.pmap, in_axes=(0, None, 0, None), axis_name='i')
# @functools.partial(jax.pmap, in_axes=(None, 0, 0, None), axis_name='j')
# def f(x, y, z, w):
# return jax.lax.axis_index(['i', 'j']) + x * y + z + w
# print(f(jnp.arange(4.), jnp.arange(2.), jnp.arange(8.).reshape((4, 2)), 100.))
# with jax.disable_jit():
# print(f(jnp.arange(4.), jnp.arange(2.), jnp.arange(8.).reshape((4, 2)), 100.))
device_count = jax.device_count()
# @functools.partial(jax.pmap, axis_name='i')
# def f(x):
# @functools.partial(jax.pmap, axis_name='j')
# def g(y):
# a = jax.lax.psum(1, 'i')
# b = jax.lax.psum(1, 'j')
# c = jax.lax.psum(1, ('i', 'j'))
# return a, b, c
# return g(x)
# import numpy as np
# shape = (device_count, 1, 4)
# x = jnp.arange(np.prod(shape)).reshape(shape)
# a, b, c = f(x)
# print(a)
# print(b)
# print(c)
# with jax.disable_jit():
# a, b, c = f(x)
# print(a)
# print(b)
# print(c)
# f = lambda axis: jax.pmap(jax.pmap(lambda x: x + jax.lax.axis_index(axis), 'j'), 'i')
# x = jnp.ones((2, 2), dtype='int32')
# print(f('i')(x))
# print(f('j')(x))
# with jax.disable_jit():
# print(f('i')(x))
# print(f('j')(x))
# def f(key):
# key = jax.random.fold_in(key, jax.lax.axis_index('i'))
# return jax.random.bernoulli(key, p=0.5)
# keys = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
# print(jax.pmap(jax.remat(f), axis_name='i')(keys))
# with jax.disable_jit():
# print(jax.pmap(jax.remat(f), axis_name='i')(keys))
# jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
# with jax.disable_jit():
# jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
# jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
# with jax.disable_jit():
# jax.pmap(jax.pmap(jnp.square))(jnp.arange(16).reshape((4, 4)))
# f = jax.pmap(lambda x: jax.pmap(lambda x: x)(x))
# x = jnp.ones((jax.device_count(), 2, 10))
# f(x)
# with jax.disable_jit():
# print(f(x))
f = jax.pmap(jax.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(np.prod(shape)).reshape(shape)
print(f(x))
with jax.disable_jit():
print(f(x))
# TODO:
# * [x] process_call
# * jit-of-emap = pmap (already, b/c we're only changing the pmap impl)
# * emap-of-jit = our process_call rule should act same as initial style HOPs
# * emap-of-core.call = do a subtrace like thing where we turn around and stay
# in python
# * [ ] collectives
# * [ ] testing
# * [ ] nesting (process_map, sublift, etc)
# * [ ] shadowing of names
# * delete process_call and core.call, just have an xla_call rule
# * no call updaters!
# * blocked on delete old remat
# * first just make xla_call have its own rule
# * then make it INITIAL STYLE
# * make it take closed jaxprs, so we can delete core.closed_call
# * delete all updaters
# brainstorming process_map
# assert map_primitive is xla_pmap_p
# backend, axis_size, axis_name = (
# params['backend'], params['axis_size'], params['axis_name'])
# if config.jax_disable_jit:
# shape = [f.size for f in self._get_frames()]
# devices = xb.devices(backend=backend)[:prod(shape) * axis_size]
# breakpoint()
# # def reshard(x: jnp.ndarray, devices: Array[devices]):
# # assert x.ndim == devices.ndim
# # e.g. x.shape = (4, 3, 2)
# # devices.shape = (4, 1, 2)
# # reshard(x, devices.reshape(4, 1, 2))
# sharded_args = [jax.device_put_sharded(list(x), devices) for x in args]
# with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
# t = main.with_cur_sublevel()
# shard_axes = {axis_name: 0}
# tracers = [MapTracer(t, arg, shard_axes) for arg in sharded_args]
# ans = fun.call_wrapped(*tracers)
# out_tracers = map(t.full_raise, ans)
# outvals = [t.val for t in out_tracers]
# return outvals
# else:
# breakpoint()


@@ -727,39 +727,6 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)
@parameterized.named_parameters(
{"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""),
"device_mesh_shape": device_mesh_shape}
for device_mesh_shape in [(1, 2)])
def testNestedWithClosure2(self, device_mesh_shape):
mesh_shape = self._getMeshShape(device_mesh_shape)
@partial(self.pmap, axis_name='i')
def test_fun(x):
y = jnp.sum(jnp.sin(x))
@partial(self.pmap, axis_name='j')
def g(z):
return jnp.exp(x.sum())
return grad(lambda w: jnp.sum(g(w)))(x)
@vmap
def baseline_fun(x):
y = jnp.sum(jnp.sin(x))
@vmap
def g(z):
return jnp.exp(x.sum())
return grad(lambda w: jnp.sum(g(w)))(x)
shape = mesh_shape
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
print(x.shape)
ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)
def testShardedDeviceArrays(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
@@ -3087,8 +3054,6 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertArraysEqual(out2, input_data)
def test_pmap_array_in_axes_out_axes(self):
if config.jax_disable_jit:
raise SkipTest("test hangs with disable_jit") # TODO(mattjj,sharadmv,yashkatariya)
dc = jax.device_count()
input_shape = (dc, 2)
a1, input_data = create_input_array_for_pmap(input_shape, in_axes=0)
@@ -3180,14 +3145,17 @@ class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
pass
class EagerPmapMixin: # i hate mixins
class EagerPmapMixin:
def setUp(self):
super().setUp()
self.eager_pmap_enabled = config.jax_eager_pmap
self.jit_disabled = config.jax_disable_jit
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)
def tearDown(self):
config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()
@@ -3203,6 +3171,9 @@ class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
pass
class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
pass
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())