Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #11854 from mattjj:sharads-cool-stuff
PiperOrigin-RevId: 467766379
commit b75969c5a1
@@ -660,7 +660,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,


 @contextmanager
-def disable_jit():
+def disable_jit(disable: bool = True):
   """Context manager that disables :py:func:`jit` behavior under its dynamic context.

   For debugging it is useful to have a mechanism that disables :py:func:`jit`
@@ -704,7 +704,7 @@ def disable_jit():
     Value of y is [2 4 6]
     [5 7 9]
   """
-  with _disable_jit(True):
+  with _disable_jit(disable):
     yield
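Note on the API change above: disable_jit now takes a boolean, so jax.disable_jit(False) can re-enable jit inside an outer disabled region. A minimal usage sketch (not part of the diff, assuming a JAX build that includes this change):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      return x + 1

    with jax.disable_jit():          # run eagerly, e.g. for debugging
      with jax.disable_jit(False):   # opt back in to compilation for this block
        y = f(jnp.arange(3))

The eager pmap implementation added below relies on exactly this pattern to run its per-primitive pmaps compiled.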
@@ -897,6 +897,13 @@ config.define_bool_state(
     default=False,
     help='Enable using a cache for lowering subjaxprs.')

+# TODO(sharadmv,mattjj): set default to True, then remove
+config.define_bool_state(
+    name='jax_eager_pmap',
+    default=False,
+    upgrade=True,
+    help='Enable eager-mode pmap when jax_disable_jit is activated.')
+
 @contextlib.contextmanager
 def explicit_device_put_scope() -> Iterator[None]:
   """Indicates that the current context is an explicit device_put*() call."""
@@ -914,10 +914,11 @@ def new_main(trace_type: Type[Trace],
     if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)

 @contextmanager
-def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
+def new_base_main(trace_type: Type[Trace],
+                  **payload) -> Generator[MainTrace, None, None]:
   # See comments in https://github.com/google/jax/pull/3370
   stack = thread_local_state.trace_state.trace_stack
-  main = MainTrace(0, trace_type)
+  main = MainTrace(0, trace_type, **payload)
   prev_dynamic, stack.dynamic = stack.dynamic, main
   prev_base, stack.stack[0] = stack.stack[0], main
   _update_thread_local_jit_state(stack.dynamic)
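A sketch of the payload plumbing this change enables (hypothetical usage, inferred from the MapTrace code later in this diff): keyword arguments passed to new_base_main are stored on the MainTrace and handed to the trace class when a concrete trace is built.

    from jax import core
    from jax.interpreters import pxla

    info = pxla.EmapInfo(backend=None, devices=None)
    with core.new_base_main(pxla.MapTrace, emap_info=info) as main:
      t = main.with_cur_sublevel()   # constructs MapTrace(..., emap_info=info)
      assert t.emap_info is main.payload["emap_info"]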
@@ -37,13 +37,15 @@ import dataclasses
 from functools import partial, lru_cache
 import itertools as it
 import operator as op
+import sys
+import threading
+import types
 from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
                     Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
                     TYPE_CHECKING)
-import sys

 from absl import logging

 import numpy as np

 import jax
@@ -61,6 +63,7 @@ from jax.tree_util import tree_flatten, tree_map
 from jax._src import abstract_arrays
 from jax._src import api_util
 from jax._src import device_array
+from jax._src import dtypes
 from jax._src import source_info_util
 from jax._src import util
 from jax._src import dispatch
@@ -77,7 +80,8 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
                            new_name_stack, wrap_name, assert_unreachable,
-                           tuple_insert, tuple_delete, distributed_debug_log)
+                           tuple_insert, tuple_delete, distributed_debug_log,
+                           split_dict, unzip2)

 if TYPE_CHECKING:
   from jax.experimental.sharding import MeshPspecSharding, XLACompatibleSharding
@@ -903,6 +907,15 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
                   out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                   donated_invars: Sequence[bool],
                   global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
+  if (config.jax_disable_jit and config.jax_eager_pmap and
+      global_axis_size is None and not any(d for d in donated_invars) and
+      not all(g is not None for g in global_arg_shapes)):
+    return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
+                      axis_size=axis_size, global_axis_size=global_axis_size,
+                      devices=devices, name=name, in_axes=in_axes,
+                      out_axes_thunk=out_axes_thunk,
+                      donated_invars=donated_invars,
+                      global_arg_shapes=global_arg_shapes)
   abstract_args = unsafe_map(xla.abstractify, args)
   compiled_fun, fingerprint = parallel_callable(
       fun, backend, axis_name, axis_size, global_axis_size, devices, name,
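A minimal sketch of when the new branch fires (not from the diff): jit disabled, eager pmap enabled, no donated buffers, no global_arg_shapes, and the default global_axis_size, in which case pmap dispatches to _emap_impl instead of compiling.

    import jax
    import jax.numpy as jnp
    from jax.config import config

    config.update('jax_disable_jit', True)
    config.update('jax_eager_pmap', True)

    n = jax.local_device_count()
    out = jax.pmap(lambda x: 2 * x)(jnp.arange(n))   # runs through _emap_impl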
@@ -918,6 +931,197 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
                         ("fingerprint", fingerprint))
   return compiled_fun(*args)

+class EmapInfo(NamedTuple):
+  backend: Optional[str]
+  devices: Optional[Sequence[Any]]
+
+def _emap_impl(fun: lu.WrappedFun, *args,
+               backend: Optional[str],
+               axis_name: core.AxisName,
+               axis_size: int,
+               global_axis_size: Optional[int],
+               devices: Optional[Sequence[Any]],
+               name: str,
+               in_axes: Sequence[Optional[int]],
+               out_axes_thunk: Callable[[], Sequence[Optional[int]]],
+               donated_invars: Sequence[bool],
+               global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
+  # TODO(sharadmv,mattjj): implement these cases
+  if any(d for d in donated_invars):
+    raise NotImplementedError("Buffer donation not supported in eager pmap.")
+  if any(g is not None for g in global_arg_shapes):
+    raise NotImplementedError("Global arg shapes not supported in eager pmap.")
+  if global_axis_size is not None:
+    raise NotImplementedError("Non-default global_axis_size not supported in "
+                              "eager pmap.")
+
+  emap_info = EmapInfo(backend, devices)
+  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
+  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
+    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
+      t = main.with_cur_sublevel()
+      tracers = [
+          MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
+      ans = fun.call_wrapped(*tracers)
+      out_tracers = map(t.full_raise, ans)
+      outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
+    del main
+  out_axes = out_axes_thunk()
+
+  new_outvals = []
+  for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
+    with jax.disable_jit(False):
+      out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
+                     out_axes=out_axis, devices=devices, backend=backend)(
+                         np.arange(axis_size), outval)
+    new_outvals.append(out)
+  return new_outvals
+
+def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
+  # In order to do a multi-map (a simultaneous map over several axes), we will
+  # nest several maps. Each time we do a map, we "remove" an input axis so we
+  # need to update the remaining map axes. For example, if we are to map over
+  # the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
+  return [None if i is None else
+          i - sum(j is not None and j < i for j in idx[:l])
+          for l, i in enumerate(idx)]
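A worked example of the schedule computed above (standalone copy for illustration, not part of the diff): each completed map removes one input axis, so later in_axes shift down by the number of already-mapped axes in front of them.

    from typing import List, Optional, Tuple

    def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
      return [None if i is None else
              i - sum(j is not None and j < i for j in idx[:l])
              for l, i in enumerate(idx)]

    assert _map_schedule((0, 3, 4)) == [0, 2, 2]   # the example from the comment
    assert _map_schedule((None, 1)) == [None, 1]   # unmapped axes stay None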
+
+def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
+                all_axes: List[Tuple[Optional[int], ...]]
+                ) -> Tuple[Callable, Dict[core.AxisName, int]]:
+  used_names = []
+  for i, name in reversed(list(enumerate(names))):
+    in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
+    if any(in_axis is not None for in_axis in in_axes):
+      f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
+                   backend=info.backend, devices=info.devices)
+      used_names.append(name)
+  out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
+  return f, out_shard_axes
+
+class MapTrace(core.Trace):
+
+  def __init__(self, *args, emap_info):
+    super().__init__(*args)
+    self.emap_info = emap_info
+
+  def pure(self, val):
+    return MapTracer(self, val, {})
+
+  def sublift(self, tracer):
+    return MapTracer(self, tracer.val, tracer.shard_axes)
+
+  def process_primitive(self, primitive, tracers, params):
+    info = self.main.payload["emap_info"]
+    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    names = [f.name for f in core.thread_local_state.trace_state.axis_env
+             if f.main_trace is self.main]
+    all_axes = [_map_schedule(map(s.get, names)) for s in shard_axes]
+    f_mapped, out_shard_axes = _multi_pmap(partial(primitive.bind, **params),
+                                           info, names, all_axes)
+    with core.eval_context(), jax.disable_jit(False):
+      outvals = f_mapped(*vals)
+    if primitive.multiple_results:
+      return [MapTracer(self, val, out_shard_axes) for val in outvals]
+    return MapTracer(self, outvals, out_shard_axes)
+
+  def process_call(self, call_primitive, fun, tracers, params):
+    if call_primitive is not xla.xla_call_p: raise NotImplementedError
+    fake_primitive = types.SimpleNamespace(
+        multiple_results=True, bind=partial(call_primitive.bind, fun))
+    return self.process_primitive(fake_primitive, tracers, params)
+
+  def process_map(self, call_primitive, fun, tracers, params):
+    if params['devices'] is not None:
+      raise ValueError("Nested pmap with explicit devices argument.")
+    if not config.jax_disable_jit:
+      fake_primitive = types.SimpleNamespace(
+          multiple_results=True, bind=partial(call_primitive.bind, fun))
+      return self.process_primitive(fake_primitive, tracers, params)
+    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
+        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
+    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
+                  if ax is not None else s
+                  for v, ax, s in zip(vals, in_axes, shard_axes)]
+    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
+      t = self.main.with_cur_sublevel()
+      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
+      ans = fun.call_wrapped(*in_tracers)
+      out_tracers = map(t.full_raise, ans)
+      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
+      del t, in_tracers, ans, out_tracers
+    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
+                          for v, s, dst in zip(out, outaxes, out_axes_thunk()))
+    return map(partial(MapTracer, self), out, outaxes)
+
+  def process_axis_index(self, frame):
+    fake_primitive = types.SimpleNamespace(
+        multiple_results=False, bind=lambda _: jax.lax.axis_index(frame.name))
+    with core.eval_context():
+      range = jax.lax.iota(np.int32, frame.size)
+    dummy_tracer = MapTracer(self, range, {frame.name: 0})
+    return self.process_primitive(fake_primitive, (dummy_tracer,), {})
+
+def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
+                   annotation: Optional[int]) -> Optional[int]:
+  if annotation is None: return None
+  mapped_axes_ = set(mapped_axes)
+  return [i for i in range(ndim) if i not in mapped_axes_][annotation]
+
+def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
+                 shard_axis_src: Dict[core.AxisName, int],
+                 dst_annotation: Optional[int]
+                 ) -> Tuple[Any, Dict[core.AxisName, int]]:
+  shard_axis_out = dict(shard_axis_src)
+  src = shard_axis_out.pop(axis_name, None)
+  dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
+                       dst_annotation)
+  with core.eval_context():
+    if src == dst:
+      outval = val
+    elif type(src) == type(dst) == int:
+      outval = batching.moveaxis(val, src, dst)
+      shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
+    elif src is None and dst is not None:
+      outval = batching.broadcast(val, axis_size, dst)
+      shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
+    else:
+      raise NotImplementedError
+  return outval, shard_axis_out
+
+def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
+              src: int, dst: int) -> Dict[core.AxisName, int]:
+  lst: List[Optional[core.AxisName]] = [None] * ndim
+  for k, v in shard_axes.items():
+    lst[v] = k
+  name = lst.pop(src)
+  lst.insert(dst - (src < dst), name)
+  return {name: i for i, name in enumerate(lst) if name is not None}
+
+class MapTracer(core.Tracer):
+  __slots__ = ["val", "shard_axes"]
+
+  def __init__(self, trace: MapTrace, val, shard_axes: Dict[core.AxisName, int]):
+    self._trace = trace
+    self.val = val
+    self.shard_axes = shard_axes
+    assert all(val < self.val.ndim for val in self.shard_axes.values())
+
+  @property
+  def aval(self):
+    aval = xla.abstractify(self.val)
+    shard_axes = dict(self.shard_axes)
+    for axis_idx in sorted(shard_axes.values())[::-1]:
+      aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
+    return aval
+
+  def full_lower(self):
+    return self
+
+  def __str__(self):
+    named_axes = [f"{k}={v}" for k, v in self.shard_axes.items()]
+    return f"{self.val}{{{','.join(named_axes)}}}"
+
 @lu.cache
 def parallel_callable(fun: lu.WrappedFun,
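End-to-end sketch of the machinery above (not part of the diff; assumes at least one local device and the two config flags enabled as in the earlier sketches): under the eager path each primitive is executed immediately by MapTrace, so collectives and axis_index behave as usual while intermediate values stay inspectable.

    import jax
    import jax.numpy as jnp
    from jax.config import config

    config.update('jax_disable_jit', True)
    config.update('jax_eager_pmap', True)

    n = jax.local_device_count()

    def f(x):
      # psum and axis_index go through MapTrace.process_primitive /
      # process_axis_index one primitive at a time.
      return jax.lax.psum(x, 'i') + jax.lax.axis_index('i')

    print(jax.pmap(f, axis_name='i')(jnp.arange(n)))
    # e.g. with n == 2: [1 2]  (psum of [0, 1] plus the per-device index)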
@@ -1156,15 +1156,15 @@ class PythonPmapTest(jtu.JaxTestCase):
   def testDeviceCountError(self):
     device_count = jax.device_count()

-    f = self.pmap(lambda x: x)
+    f = self.pmap(lambda x: 2 * x)
     x = jnp.arange(device_count + 1)
     self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))

-    f = self.pmap(lambda x: x)
+    f = self.pmap(lambda x: 2 * x)
     x = np.ones((device_count + 1, 10))
     self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))

-    f = self.pmap(lambda x: self.pmap(lambda x: x)(x))
+    f = self.pmap(lambda x: self.pmap(lambda x: 2 * x)(x))
     x = np.ones((device_count, 2, 10))
     self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
@@ -1178,11 +1178,12 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = np.repeat(3, device_count)
     self.assertAllClose(ans, expected, check_dtypes=False)

-    f = self.pmap(lambda x: (x, 3))
-    x = np.arange(device_count)
-    with jtu.assert_num_jit_and_pmap_compilations(1):
-      _, ans = f(x)
-    self.assertAllClose(ans, expected, check_dtypes=False)
+    if not config.jax_disable_jit:
+      f = self.pmap(lambda x: (x, 3))
+      x = np.arange(device_count)
+      with jtu.assert_num_jit_and_pmap_compilations(1):
+        _, ans = f(x)
+      self.assertAllClose(ans, expected, check_dtypes=False)

   def testPmapConstantDevices(self):
     if jax.device_count() == 1:
@@ -1203,7 +1204,9 @@ class PythonPmapTest(jtu.JaxTestCase):
       bufs = ans._arrays
     else:
       bufs = ans.device_buffers
-    self.assertEqual([b.device() for b in bufs], devices)
+    # TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
+    if not config.jax_disable_jit:
+      self.assertEqual([b.device() for b in bufs], devices)

   def testPmapConstantError(self):
     device_count = jax.device_count()
@@ -1286,6 +1289,8 @@ class PythonPmapTest(jtu.JaxTestCase):
                      [b.device() for b in expected_sharded_bufs])

   def testNestedPmapConstantError(self):
+    if config.jax_disable_jit:
+      raise SkipTest("error test doesn't apply with disable_jit")
     f = self.pmap(self.pmap(lambda x: 3))
     shape = (2, jax.device_count() // 2 + 1, 3)
     x = jnp.arange(prod(shape)).reshape(shape)
@@ -1702,7 +1707,7 @@ class PythonPmapTest(jtu.JaxTestCase):
   def testCompositionWithJitTwice(self):
     @jit
     def f(x):
-      y = 2 * x
+      y = jnp.float32(2) * x

       @jit
       def g(z):
@@ -1710,7 +1715,7 @@ class PythonPmapTest(jtu.JaxTestCase):

       return g(x)

-    f(np.arange(1.).reshape((1, 1)))  # doesn't crash
+    f(np.arange(1., dtype='float32').reshape((1, 1)))  # doesn't crash

   @ignore_jit_of_pmap_warning()
   def testIssue1065(self):
@@ -1904,7 +1909,7 @@ class PythonPmapTest(jtu.JaxTestCase):
   def testJitOfPmapWarningMessage(self):
     device_count = jax.device_count()

-    if device_count == 1:
+    if device_count == 1 or config.jax_disable_jit:
       raise SkipTest("test requires at least two devices")

     def foo(x): return x
@@ -2068,6 +2073,8 @@ class PythonPmapTest(jtu.JaxTestCase):
   def test_grad_of_pmap_compilation_caching(self, axis_size):
     if len(jax.local_devices()) < axis_size:
       raise SkipTest("too few devices for test")
+    if config.jax_disable_jit:
+      raise SkipTest("caching doesn't apply with jit disabled")

     @jax.pmap
     def f(x):
@@ -2474,6 +2481,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
       f(jnp.ones(jax.device_count() + 1))

   def testBadAxisSizeErrorNested(self):
+    if config.jax_disable_jit:
+      raise SkipTest("error doesn't apply when jit is disabled")
     f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
                   axis_name='j'),
              axis_name='i',
@@ -2487,6 +2496,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
   def testNestedPmaps(self):
     if jax.device_count() % 2 != 0:
       raise SkipTest
+    if config.jax_disable_jit:
+      raise SkipTest("disable_jit requires num devices to equal axis size")

     # Devices specified in outer pmap are OK
     @partial(pmap, axis_name='i', devices=jax.devices())
@@ -2504,6 +2515,8 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
   def testNestedPmapsBools(self):
     if jax.device_count() % 2 != 0:
       raise SkipTest
+    if config.jax_disable_jit:
+      raise SkipTest("disable_jit requires num devices to equal axis size")

     # Devices specified in outer pmap are OK
     @partial(pmap, axis_name='i', devices=jax.devices())
@@ -3132,6 +3145,35 @@ class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
 class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
   pass

+class EagerPmapMixin:
+
+  def setUp(self):
+    super().setUp()
+    self.eager_pmap_enabled = config.jax_eager_pmap
+    self.jit_disabled = config.jax_disable_jit
+    config.update('jax_disable_jit', True)
+    config.update('jax_eager_pmap', True)
+
+  def tearDown(self):
+    config.update('jax_eager_pmap', self.eager_pmap_enabled)
+    config.update('jax_disable_jit', self.jit_disabled)
+    super().tearDown()
+
+class EagerPythonPmapTest(EagerPmapMixin, PythonPmapTest):
+  pass
+
+class EagerCppPmapTest(EagerPmapMixin, CppPmapTest):
+  pass
+
+class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
+  pass
+
+class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
+  pass
+
+class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
+  pass
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())