Merge pull request #11854 from mattjj:sharads-cool-stuff

PiperOrigin-RevId: 467766379
This commit is contained in:
jax authors 2022-08-15 14:47:29 -07:00
commit b75969c5a1
5 changed files with 272 additions and 18 deletions

View File

@ -660,7 +660,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
@contextmanager
def disable_jit():
def disable_jit(disable: bool = True):
"""Context manager that disables :py:func:`jit` behavior under its dynamic context.
For debugging it is useful to have a mechanism that disables :py:func:`jit`
@ -704,7 +704,7 @@ def disable_jit():
Value of y is [2 4 6]
[5 7 9]
"""
with _disable_jit(True):
with _disable_jit(disable):
yield

View File

@ -897,6 +897,13 @@ config.define_bool_state(
default=False,
help='Enable using a cache for lowering subjaxprs.')
# TODO(sharadmv,mattjj): set default to True, then remove
config.define_bool_state(
name='jax_eager_pmap',
default=False,
upgrade=True,
help='Enable eager-mode pmap when jax_disable_jit is activated.')
@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""

View File

@ -914,10 +914,11 @@ def new_main(trace_type: Type[Trace],
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
def new_base_main(trace_type: Type[Trace],
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type)
main = MainTrace(0, trace_type, **payload)
prev_dynamic, stack.dynamic = stack.dynamic, main
prev_base, stack.stack[0] = stack.stack[0], main
_update_thread_local_jit_state(stack.dynamic)

View File

@ -37,13 +37,15 @@ import dataclasses
from functools import partial, lru_cache
import itertools as it
import operator as op
import sys
import threading
import types
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
import sys
from absl import logging
import numpy as np
import jax
@ -61,6 +63,7 @@ from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src import dispatch
@ -77,7 +80,8 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
tuple_insert, tuple_delete, distributed_debug_log,
split_dict, unzip2)
if TYPE_CHECKING:
from jax.experimental.sharding import MeshPspecSharding, XLACompatibleSharding
@ -903,6 +907,15 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
if (config.jax_disable_jit and config.jax_eager_pmap and
global_axis_size is None and not any(d for d in donated_invars) and
not all(g is not None for g in global_arg_shapes)):
return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
axis_size=axis_size, global_axis_size=global_axis_size,
devices=devices, name=name, in_axes=in_axes,
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
global_arg_shapes=global_arg_shapes)
abstract_args = unsafe_map(xla.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
@ -918,6 +931,197 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
("fingerprint", fingerprint))
return compiled_fun(*args)
class EmapInfo(NamedTuple):
backend: Optional[str]
devices: Optional[Sequence[Any]]
def _emap_impl(fun: lu.WrappedFun, *args,
backend: Optional[str],
axis_name: core.AxisName,
axis_size: int,
global_axis_size: Optional[int],
devices: Optional[Sequence[Any]],
name: str,
in_axes: Sequence[Optional[int]],
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
if any(g is not None for g in global_arg_shapes):
raise NotImplementedError("Global arg shapes not supported in eager pmap.")
if global_axis_size is not None:
raise NotImplementedError("Non-default global_axis_size not supported in "
"eager pmap.")
emap_info = EmapInfo(backend, devices)
shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
with core.new_base_main(MapTrace, emap_info=emap_info) as main:
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
t = main.with_cur_sublevel()
tracers = [
MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
ans = fun.call_wrapped(*tracers)
out_tracers = map(t.full_raise, ans)
outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
del main
out_axes = out_axes_thunk()
new_outvals = []
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
out_axes=out_axis, devices=devices, backend=backend)(
np.arange(axis_size), outval)
new_outvals.append(out)
return new_outvals
def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
# In order to do a multi-map (a simultaneous map over several axes), we will
# nest several maps. Each time we do a map, we "remove" an input axis so we
# need to update the remaining map axes. For example, if we are to map over
# the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
return [None if i is None else
i - sum(j is not None and j < i for j in idx[:l])
for l, i in enumerate(idx)]
def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
all_axes: List[Tuple[Optional[int], ...]]
) -> Tuple[Callable, Dict[core.AxisName, int]]:
used_names = []
for i, name in reversed(list(enumerate(names))):
in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
if any(in_axis is not None for in_axis in in_axes):
f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
backend=info.backend, devices=info.devices)
used_names.append(name)
out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
return f, out_shard_axes
class MapTrace(core.Trace):
def __init__(self, *args, emap_info):
super().__init__(*args)
self.emap_info = emap_info
def pure(self, val):
return MapTracer(self, val, {})
def sublift(self, tracer):
return MapTracer(self, tracer.val, tracer.shard_axes)
def process_primitive(self, primitive, tracers, params):
info = self.main.payload["emap_info"]
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
names = [f.name for f in core.thread_local_state.trace_state.axis_env
if f.main_trace is self.main]
all_axes = [_map_schedule(map(s.get, names)) for s in shard_axes]
f_mapped, out_shard_axes = _multi_pmap(partial(primitive.bind, **params),
info, names, all_axes)
with core.eval_context(), jax.disable_jit(False):
outvals = f_mapped(*vals)
if primitive.multiple_results:
return [MapTracer(self, val, out_shard_axes) for val in outvals]
return MapTracer(self, outvals, out_shard_axes)
def process_call(self, call_primitive, fun, tracers, params):
if call_primitive is not xla.xla_call_p: raise NotImplementedError
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(call_primitive.bind, fun))
return self.process_primitive(fake_primitive, tracers, params)
def process_map(self, call_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if not config.jax_disable_jit:
fake_primitive = types.SimpleNamespace(
multiple_results=True, bind=partial(call_primitive.bind, fun))
return self.process_primitive(fake_primitive, tracers, params)
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
params["in_axes"], params["out_axes_thunk"], params["axis_size"])
vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
if ax is not None else s
for v, ax, s in zip(vals, in_axes, shard_axes)]
with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
t = self.main.with_cur_sublevel()
in_tracers = map(partial(MapTracer, t), vals, shard_axes)
ans = fun.call_wrapped(*in_tracers)
out_tracers = map(t.full_raise, ans)
out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
del t, in_tracers, ans, out_tracers
out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
for v, s, dst in zip(out, outaxes, out_axes_thunk()))
return map(partial(MapTracer, self), out, outaxes)
def process_axis_index(self, frame):
fake_primitive = types.SimpleNamespace(
multiple_results=False, bind=lambda _: jax.lax.axis_index(frame.name))
with core.eval_context():
range = jax.lax.iota(np.int32, frame.size)
dummy_tracer = MapTracer(self, range, {frame.name: 0})
return self.process_primitive(fake_primitive, (dummy_tracer,), {})
def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
annotation: Optional[int]) -> Optional[int]:
if annotation is None: return None
mapped_axes_ = set(mapped_axes)
return [i for i in range(ndim) if i not in mapped_axes_][annotation]
def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
shard_axis_src: Dict[core.AxisName, int],
dst_annotation: Optional[int]
) -> Tuple[Any, Dict[core.AxisName, int]]:
shard_axis_out = dict(shard_axis_src)
src = shard_axis_out.pop(axis_name, None)
dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
dst_annotation)
with core.eval_context():
if src == dst:
outval = val
elif type(src) == type(dst) == int:
outval = batching.moveaxis(val, src, dst)
shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
elif src is None and dst is not None:
outval = batching.broadcast(val, axis_size, dst)
shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
else:
raise NotImplementedError
return outval, shard_axis_out
def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
src: int, dst: int) -> Dict[core.AxisName, int]:
lst: List[Optional[core.AxisName]] = [None] * ndim
for k, v in shard_axes.items():
lst[v] = k
name = lst.pop(src)
lst.insert(dst - (src < dst), name)
return {name: i for i, name in enumerate(lst) if name is not None}
class MapTracer(core.Tracer):
__slots__ = ["val", "shard_axes"]
def __init__(self, trace: MapTrace, val, shard_axes: Dict[core.AxisName, int]):
self._trace = trace
self.val = val
self.shard_axes = shard_axes
assert all(val < self.val.ndim for val in self.shard_axes.values())
@property
def aval(self):
aval = xla.abstractify(self.val)
shard_axes = dict(self.shard_axes)
for axis_idx in sorted(shard_axes.values())[::-1]:
aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
return aval
def full_lower(self):
return self
def __str__(self):
named_axes = [f"{k}={v}" for k, v in self.shard_axes.items()]
return f"{self.val}{{{','.join(named_axes)}}}"
@lu.cache
def parallel_callable(fun: lu.WrappedFun,

View File

@ -1156,15 +1156,15 @@ class PythonPmapTest(jtu.JaxTestCase):
def testDeviceCountError(self):
device_count = jax.device_count()
f = self.pmap(lambda x: x)
f = self.pmap(lambda x: 2 * x)
x = jnp.arange(device_count + 1)
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
f = self.pmap(lambda x: x)
f = self.pmap(lambda x: 2 * x)
x = np.ones((device_count + 1, 10))
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
f = self.pmap(lambda x: self.pmap(lambda x: x)(x))
f = self.pmap(lambda x: self.pmap(lambda x: 2 * x)(x))
x = np.ones((device_count, 2, 10))
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
@ -1178,11 +1178,12 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = np.repeat(3, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
f = self.pmap(lambda x: (x, 3))
x = np.arange(device_count)
with jtu.assert_num_jit_and_pmap_compilations(1):
_, ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
if not config.jax_disable_jit:
f = self.pmap(lambda x: (x, 3))
x = np.arange(device_count)
with jtu.assert_num_jit_and_pmap_compilations(1):
_, ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapConstantDevices(self):
if jax.device_count() == 1:
@ -1203,7 +1204,9 @@ class PythonPmapTest(jtu.JaxTestCase):
bufs = ans._arrays
else:
bufs = ans.device_buffers
self.assertEqual([b.device() for b in bufs], devices)
# TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
if not config.jax_disable_jit:
self.assertEqual([b.device() for b in bufs], devices)
def testPmapConstantError(self):
device_count = jax.device_count()
@ -1286,6 +1289,8 @@ class PythonPmapTest(jtu.JaxTestCase):
[b.device() for b in expected_sharded_bufs])
def testNestedPmapConstantError(self):
if config.jax_disable_jit:
raise SkipTest("error test doesn't apply with disable_jit")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
@ -1702,7 +1707,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testCompositionWithJitTwice(self):
@jit
def f(x):
y = 2 * x
y = jnp.float32(2) * x
@jit
def g(z):
@ -1710,7 +1715,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return g(x)
f(np.arange(1.).reshape((1, 1))) # doesn't crash
f(np.arange(1., dtype='float32').reshape((1, 1))) # doesn't crash
@ignore_jit_of_pmap_warning()
def testIssue1065(self):
@ -1904,7 +1909,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapWarningMessage(self):
device_count = jax.device_count()
if device_count == 1:
if device_count == 1 or config.jax_disable_jit:
raise SkipTest("test requires at least two devices")
def foo(x): return x
@ -2068,6 +2073,8 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_grad_of_pmap_compilation_caching(self, axis_size):
if len(jax.local_devices()) < axis_size:
raise SkipTest("too few devices for test")
if config.jax_disable_jit:
raise SkipTest("caching doesn't apply with jit disabled")
@jax.pmap
def f(x):
@ -2474,6 +2481,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f(jnp.ones(jax.device_count() + 1))
def testBadAxisSizeErrorNested(self):
if config.jax_disable_jit:
raise SkipTest("error doesn't apply when jit is disabled")
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
axis_name='j'),
axis_name='i',
@ -2487,6 +2496,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmaps(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
raise SkipTest("disable_jit requires num devices to equal axis size")
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=jax.devices())
@ -2504,6 +2515,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmapsBools(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
raise SkipTest("disable_jit requires num devices to equal axis size")
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=jax.devices())
@ -3132,6 +3145,35 @@ class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
pass
class EagerPmapMixin:
def setUp(self):
super().setUp()
self.eager_pmap_enabled = config.jax_eager_pmap
self.jit_disabled = config.jax_disable_jit
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)
def tearDown(self):
config.update('jax_eager_pmap', self.eager_pmap_enabled)
config.update('jax_disable_jit', self.jit_disabled)
super().tearDown()
class EagerPythonPmapTest(EagerPmapMixin, PythonPmapTest):
pass
class EagerCppPmapTest(EagerPmapMixin, CppPmapTest):
pass
class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
pass
class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
pass
class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
pass
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())