Cleaning up eager pmap implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
commit fe040cc01e
parent a7f760d9ed
@@ -113,12 +113,12 @@ zip, unsafe_zip = safe_zip, zip
 FLAGS = flags.FLAGS

 flags.DEFINE_bool(
-    "experimental_cpp_jit", bool_env("JAX_CPP_JIT", False),
+    "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
     "A flag enabling the C++ jax.jit fast path."
     "Set this to `False` only if it crashes otherwise and report "
     "the error to the jax-team.")
 flags.DEFINE_bool(
-    "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", False),
+    "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
     "A flag enabling the C++ jax.pmap fast path. Until the default "
     "is switched to True, the feature is not supported and possibly broken "
     "(e.g. it may use unreleased code from jaxlib.")
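The two fast-path defaults above are read through bool_env, so they can be overridden from the environment without editing code. A minimal sketch (not part of the diff; the chosen values are illustrative only), assuming the variables are set before jax is imported, since the flag definitions read the environment at import time:

import os
os.environ["JAX_CPP_JIT"] = "false"   # opt out of the C++ jax.jit fast path
os.environ["JAX_CPP_PMAP"] = "false"  # opt out of the C++ jax.pmap fast path

import jax
import jax.numpy as jnp

# Dispatch still works; it simply goes through the Python path instead.
print(jax.jit(jnp.sin)(1.0))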
@@ -897,6 +897,13 @@ config.define_bool_state(
     default=False,
     help='Enable using a cache for lowering subjaxprs.')

+# TODO(sharadmv,mattjj): set default to True, then remove
+config.define_bool_state(
+    name='jax_eager_pmap',
+    default=False,
+    upgrade=True,
+    help='Enable eager-mode pmap when jax_disable_jit is activated.')
+
 @contextlib.contextmanager
 def explicit_device_put_scope() -> Iterator[None]:
   """Indicates that the current context is an explicit device_put*() call."""
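As its help string says, the new jax_eager_pmap state only takes effect while jit is disabled. A minimal usage sketch (not part of the diff), assuming a jax build that includes the state defined above:

import jax
import jax.numpy as jnp

jax.config.update('jax_eager_pmap', True)

def f(x):
  # Collectives still resolve against the pmapped axis name.
  return jax.lax.psum(x, 'i') + x

with jax.disable_jit():
  # With both jax_disable_jit and jax_eager_pmap active, this pmap call is
  # dispatched through the eager path instead of being compiled.
  print(jax.pmap(f, axis_name='i')(jnp.arange(jax.local_device_count())))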
@@ -907,7 +907,9 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
                   out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                   donated_invars: Sequence[bool],
                   global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
-  if config.jax_disable_jit:
+  if (config.jax_disable_jit and config.jax_eager_pmap and
+      global_axis_size is None and not any(d for d in donated_invars) and
+      not all(g is not None for g in global_arg_shapes)):
     return _emap_impl(fun, *args, backend=backend, axis_name=axis_name,
                       axis_size=axis_size, global_axis_size=global_axis_size,
                       devices=devices, name=name, in_axes=in_axes,
@@ -929,6 +931,10 @@ def xla_pmap_impl(fun: lu.WrappedFun, *args,
                            ("fingerprint", fingerprint))
   return compiled_fun(*args)

+class EmapInfo(NamedTuple):
+  backend: Optional[str]
+  devices: Optional[Sequence[Any]]
+
 def _emap_impl(fun: lu.WrappedFun, *args,
                backend: Optional[str],
                axis_name: core.AxisName,
@@ -940,76 +946,64 @@ def _emap_impl(fun: lu.WrappedFun, *args,
                out_axes_thunk: Callable[[], Sequence[Optional[int]]],
                donated_invars: Sequence[bool],
                global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
-  if global_axis_size is not None: raise NotImplementedError
-  pci = ParallelCallableInfo(
-      name, backend, axis_name, axis_size,
-      global_axis_size, devices, None, None, ())
-  # if devices is not None:
-  #   if len(devices) == 0:
-  #     raise ValueError("'devices' argument to pmap must be non-empty, or None.")
-  #   if len(devices) != axis_size:
-  #     raise ValueError(
-  #       f"Leading axis size of input to pmapped function must equal the "
-  #       f"number of local devices passed to pmap. Got axis_size="
-  #       f"{axis_size}, num_local_devices={len(devices)}.")
-  # else:
-  #   devices = xb.devices(backend=backend)[:axis_size]
-  #   if len(devices) != axis_size:
-  #     msg = ("compiling computation that requires {} logical devices, but only {} XLA "
-  #            "devices are available (num_replicas={}, num_partitions={})")
-  #     raise ValueError(msg.format(axis_size,
-  #                                 xb.device_count(backend),
-  #                                 None,
-  #                                 None))
-  sharded_args = []
-  shard_axes = []
-  for arg, in_axis in zip(args, in_axes):
-    if in_axis == 0:
-      sharded_args.append(arg)
-      shard_axes.append({axis_name: 0})
-    elif in_axis is None:
-      sharded_args.append(arg)
-      shard_axes.append({})
-    else:
-      perm = list(range(arg.ndim))
-      a = perm.pop(in_axis)
-      perm.insert(0, a)
-      sharded_args.append(arg.transpose(perm))
-      shard_axes.append({axis_name: 0})
-  with core.new_base_main(MapTrace, pci=pci) as main:
+  # TODO(sharadmv,mattjj): implement these cases
+  if any(d for d in donated_invars):
+    raise NotImplementedError("Buffer donation not supported in eager pmap.")
+  if any(g is not None for g in global_arg_shapes):
+    raise NotImplementedError("Global arg shapes not supported in eager pmap.")
+  if global_axis_size is not None:
+    raise NotImplementedError("Non-default global_axis_size not supported in "
+                              "eager pmap.")
+
+  emap_info = EmapInfo(backend, devices)
+  shard_axes = [{} if in_axis is None else {axis_name: in_axis} for in_axis in in_axes]
+  with core.new_base_main(MapTrace, emap_info=emap_info) as main:
     with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
       t = main.with_cur_sublevel()
       tracers = [
-          MapTracer(t, arg, s) for arg, s in zip(sharded_args, shard_axes)]
+          MapTracer(t, arg, s) for arg, s in zip(args, shard_axes)]
       ans = fun.call_wrapped(*tracers)
       out_tracers = map(t.full_raise, ans)
       outvals, out_axes_src = unzip2((t.val, t.shard_axes) for t in out_tracers)
   del main
   out_axes = out_axes_thunk()

   # This next bit is like matchaxis in batching.py (for the end of a vmap)
   new_outvals = []
   for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
-    with jax._src.config.disable_jit(False):
+    with jax.disable_jit(False):
       out = jax.pmap(lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)),
-                     out_axes=out_axis)(np.arange(axis_size), outval)
+                     out_axes=out_axis, devices=devices, backend=backend)(
+          np.arange(axis_size), outval)
     new_outvals.append(out)
   return new_outvals

+def _map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
+  # In order to do a multi-map (a simultaneous map over several axes), we will
+  # nest several maps. Each time we do a map, we "remove" an input axis so we
+  # need to update the remaining map axes. For example, if we are to map over
+  # the axes 0, 3, and 4, we make three calls to pmap with in_axes as 0, 2, 2.
+  return [None if i is None else
+          i - sum(j is not None and j < i for j in idx[:l])
+          for l, i in enumerate(idx)]
+
-def _map_indices_to_map_schedule(idx: Tuple[Optional[int], ...]):
-  return tuple(None if i is None else i - sum(j is not None and j < i for j in idx[:l]) for l, i in enumerate(idx))
+def _multi_pmap(f: Callable, info: EmapInfo, names: List[core.AxisName],
+                all_axes: List[Tuple[Optional[int], ...]]
+                ) -> Tuple[Callable, Dict[core.AxisName, int]]:
+  used_names = []
+  for i, name in reversed(list(enumerate(names))):
+    in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
+    if any(in_axis is not None for in_axis in in_axes):
+      f = jax.pmap(f, in_axes=in_axes, axis_name=name, out_axes=0,
+                   backend=info.backend, devices=info.devices)
+      used_names.append(name)
+  out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
+  return f, out_shard_axes

 class MapTrace(core.Trace):

-  def __init__(self, *args, pci):
+  def __init__(self, *args, emap_info):
     super().__init__(*args)
-    self.pci = pci
-
-  def _get_frames(self):
-    frames = [f for f in core.thread_local_state.trace_state.axis_env
-              if f.main_trace is self.main]
-    return frames
+    self.emap_info = emap_info

   def pure(self, val):
     return MapTracer(self, val, {})
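The scheduling comment in _map_schedule is easiest to see on a concrete input. The standalone sketch below (not part of the diff; map_schedule is a local copy of the comprehension shown above) checks the example from the comment: mapping over axes 0, 3 and 4 of one argument becomes three nested pmaps with in_axes 0, 2, 2, because each enclosing map consumes one axis and shifts the remaining ones down.

from typing import List, Optional, Tuple

def map_schedule(idx: Tuple[Optional[int], ...]) -> List[Optional[int]]:
  # Same comprehension as _map_schedule: each mapped axis is shifted down by
  # the number of earlier mapped axes that have already been consumed.
  return [None if i is None else
          i - sum(j is not None and j < i for j in idx[:l])
          for l, i in enumerate(idx)]

# Axes 0, 3 and 4 of one argument -> pmap in_axes of 0, 2, 2.
assert map_schedule((0, 3, 4)) == [0, 2, 2]
# An unmapped entry stays None and does not shift the others.
assert map_schedule((None, 1, 3)) == [None, 1, 2]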
@@ -1018,24 +1012,15 @@ class MapTrace(core.Trace):
     return MapTracer(self, tracer.val, tracer.shard_axes)

   def process_primitive(self, primitive, tracers, params):
-    vals = [t.val for t in tracers]
-    names = [f.name for f in self._get_frames()]
-    f = lambda *args: primitive.bind(*args, **params)
-    used_names = []
-    all_axes = []
-    for t in tracers:
-      arg_axes = tuple(t.shard_axes.get(name, None) for name in names)
-      arg_axes = _map_indices_to_map_schedule(arg_axes)
-      all_axes.append(arg_axes)
-    for i, name in reversed(list(enumerate(names))):
-      in_axes = tuple(arg_axis[i] for arg_axis in all_axes)
-      if any(in_axis is not None for in_axis in in_axes):
-        f = jax.pmap(f, in_axes=in_axes, axis_name=name,
-                     devices=self.main.payload['pci'].devices)
-        used_names.append(name)
-    with core.eval_context(), jax._src.config.disable_jit(False):
-      outvals = f(*vals)
-    out_shard_axes = {name: i for i, name in enumerate(reversed(used_names))}
+    info = self.main.payload["emap_info"]
+    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    names = [f.name for f in core.thread_local_state.trace_state.axis_env
+             if f.main_trace is self.main]
+    all_axes = [_map_schedule(map(s.get, names)) for s in shard_axes]
+    f_mapped, out_shard_axes = _multi_pmap(partial(primitive.bind, **params),
+                                           info, names, all_axes)
+    with core.eval_context(), jax.disable_jit(False):
+      outvals = f_mapped(*vals)
     if primitive.multiple_results:
       return [MapTracer(self, val, out_shard_axes) for val in outvals]
     return MapTracer(self, outvals, out_shard_axes)
@@ -1049,55 +1034,26 @@ class MapTrace(core.Trace):
   def process_map(self, call_primitive, fun, tracers, params):
     if params['devices'] is not None:
       raise ValueError("Nested pmap with explicit devices argument.")
-    if config.jax_disable_jit:
-      axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
-          params["in_axes"], params["out_axes_thunk"], params["axis_size"])
-      invals = [t.val for t in tracers]
-      shard_axes_src = [t.shard_axes for t in tracers]
-      shard_axes = []
-      for inval, in_axis, shard_axis_src in zip(invals, in_axes, shard_axes_src):
-        new_shard_axis_src = dict(shard_axis_src)
-        if in_axis is not None:
-          idx = [i for i in range(inval.ndim) if i not in shard_axis_src.values()]
-          new_idx = idx[in_axis]
-          new_shard_axis_src = {axis_name: new_idx, **shard_axis_src}
-        shard_axes.append(new_shard_axis_src)
-      with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
-        t = self.main.with_cur_sublevel()
-        in_tracers = [MapTracer(t, val, shard_axis) for val, shard_axis in
-                      zip(invals, shard_axes)]
-        ans = fun.call_wrapped(*in_tracers)
-        out_axes_dest = out_axes_thunk()
-        out_tracers = map(t.full_raise, ans)
-        outvals, shard_axes_src = util.unzip2([(t.val, t.shard_axes) for t in
-                                               out_tracers])
-      new_out_tracers = map(
-          partial(self._make_output_tracer, axis_name, axis_size), outvals,
-          shard_axes_src, out_axes_dest)
-      return new_out_tracers
-    else:
+    if not config.jax_disable_jit:
       fake_primitive = types.SimpleNamespace(
           multiple_results=True, bind=partial(call_primitive.bind, fun))
       return self.process_primitive(fake_primitive, tracers, params)
-
-  def _make_output_tracer(self, axis_name, axis_size, val, shard_axis_src,
-                          dst_annotation):
-    shard_axis_out = dict(shard_axis_src)
-    src = shard_axis_out.pop(axis_name, None)
-    dst = annotation_to_flat(np.ndim(val), shard_axis_out.values(),
-                             src, dst_annotation)
-    with core.eval_context():
-      if src == dst:
-        outval = val
-      elif type(src) == type(dst) == int:
-        outval = batching.moveaxis(val, src, dst)
-        shard_axis_out = moveaxis(np.ndim(val), shard_axis_src, src, dst)
-      elif src is None and dst is not None:
-        outval = batching.broadcast(val, axis_size, dst)
-        shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
-      else:
-        assert False
-    return MapTracer(self, outval, shard_axis_out)
+    axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
+        params["in_axes"], params["out_axes_thunk"], params["axis_size"])
+    vals, shard_axes = unzip2([(t.val, t.shard_axes) for t in tracers])
+    shard_axes = [{axis_name: _annot_to_flat(np.ndim(v), s.values(), ax), **s}
+                  if ax is not None else s
+                  for v, ax, s in zip(vals, in_axes, shard_axes)]
+    with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, self.main):
+      t = self.main.with_cur_sublevel()
+      in_tracers = map(partial(MapTracer, t), vals, shard_axes)
+      ans = fun.call_wrapped(*in_tracers)
+      out_tracers = map(t.full_raise, ans)
+      out, outaxes = unzip2((t.val, t.shard_axes) for t in out_tracers)
+      del t, in_tracers, ans, out_tracers
+    out, outaxes = unzip2(_match_annot(axis_name, axis_size, v, s, dst)
+                          for v, s, dst in zip(out, outaxes, out_axes_thunk()))
+    return map(partial(MapTracer, self), out, outaxes)

   def process_axis_index(self, frame):
     fake_primitive = types.SimpleNamespace(
@@ -1107,19 +1063,35 @@ class MapTrace(core.Trace):
     dummy_tracer = MapTracer(self, range, {frame.name: 0})
     return self.process_primitive(fake_primitive, (dummy_tracer,), {})

-def annotation_to_flat(ndim: int, mapped_axes: Sequence[int],
-                       src_flat: Optional[int], dst_annotation: Optional[int]
-                       ) -> Optional[int]:
-  if dst_annotation is None:
-    return None
-  ndim_ = ndim - len(mapped_axes) + (src_flat is None)
-  dst_annotation = batching.canonicalize_axis(dst_annotation, ndim_)
-  idx = [i for i in range(ndim + (src_flat is None)) if i not in mapped_axes]
-  out = idx[dst_annotation]
-  return out
+def _annot_to_flat(ndim: int, mapped_axes: Iterable[int],
+                   annotation: Optional[int]) -> Optional[int]:
+  if annotation is None: return None
+  mapped_axes_ = set(mapped_axes)
+  return [i for i in range(ndim) if i not in mapped_axes_][annotation]

-def moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
-             src: int, dst: int):
+def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any,
+                 shard_axis_src: Dict[core.AxisName, int],
+                 dst_annotation: Optional[int]
+                 ) -> Tuple[Any, Dict[core.AxisName, int]]:
+  shard_axis_out = dict(shard_axis_src)
+  src = shard_axis_out.pop(axis_name, None)
+  dst = _annot_to_flat(np.ndim(val) + (src is None), shard_axis_out.values(),
+                       dst_annotation)
+  with core.eval_context():
+    if src == dst:
+      outval = val
+    elif type(src) == type(dst) == int:
+      outval = batching.moveaxis(val, src, dst)
+      shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst)
+    elif src is None and dst is not None:
+      outval = batching.broadcast(val, axis_size, dst)
+      shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()}
+    else:
+      raise NotImplementedError
+  return outval, shard_axis_out
+
+def _moveaxis(ndim: int, shard_axes: Dict[core.AxisName, int],
+              src: int, dst: int) -> Dict[core.AxisName, int]:
   lst: List[Optional[core.AxisName]] = [None] * ndim
   for k, v in shard_axes.items():
     lst[v] = k
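The new _annot_to_flat converts a user-facing axis annotation, which counts only the axes not already claimed by named maps, into a position in the full array. A standalone sketch (not part of the diff; annot_to_flat is a local copy of the one-liner above) with a small worked case:

from typing import Iterable, Optional

def annot_to_flat(ndim: int, mapped_axes: Iterable[int],
                  annotation: Optional[int]) -> Optional[int]:
  # Same logic as _annot_to_flat: skip the axes already taken by named maps
  # and index the remaining (unmapped) axes with the annotation.
  if annotation is None: return None
  mapped_axes_ = set(mapped_axes)
  return [i for i in range(ndim) if i not in mapped_axes_][annotation]

# A rank-3 value whose axis 0 is already mapped: the unmapped axes are (1, 2),
# so an annotation of 1 refers to flat axis 2 of the full array.
assert annot_to_flat(3, {0}, 1) == 2
assert annot_to_flat(3, {0}, None) is None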
matt.py (deleted file, 188 lines)
@@ -1,188 +0,0 @@
-import pdb, sys, traceback
-def info(type, value, tb):
-  traceback.print_exception(type, value, tb)
-  pdb.pm()
-sys.excepthook = info
-
-import os
-import functools
-import re
-
-import jax
-import jax.numpy as jnp
-from jax.config import config
-import numpy as np
-
-config.update("jax_new_checkpoint", True)
-config.update("jax_traceback_filtering", "off")
-config.update("jax_platform_name", "cpu")
-
-def set_host_device_count(n):
-  xla_flags = os.getenv("XLA_FLAGS", "")
-  xla_flags = re.sub(
-      r"--xla_force_host_platform_device_count=\S+", "", xla_flags
-  ).split()
-  os.environ["XLA_FLAGS"] = " ".join(
-      ["--xla_force_host_platform_device_count={}".format(n)] + xla_flags
-  )
-
-set_host_device_count(8)
-
-# @functools.partial(jax.pmap, axis_name="foo")
-# def f(x):
-#   def cond(i):
-#     return i < 10
-#   def body(i):
-#     return i + 1
-#   return jax.lax.while_loop(cond, body, x)
-
-# @functools.partial(jax.pmap, axis_name="foo")
-# def f(x):
-#   def cond(i):
-#     return i < 10
-#   def body(i):
-#     return i + 1
-#   return jax.lax.while_loop(cond, body, x)
-
-# print(f(jnp.arange(4.)))
-
-# with jax.disable_jit():
-#   print(f(jnp.arange(4.)))
-
-
-# @jax.pmap
-# def g(x):
-#   with jax._src.config.disable_jit(False):
-#     return jax.jit(jnp.sin)(x)
-
-# print(g(jnp.arange(4.)))
-
-# with jax.disable_jit():
-#   print(g(jnp.arange(4.)))
-
-# @functools.partial(jax.pmap, in_axes=(0, None, 0, None), axis_name='i')
-# @functools.partial(jax.pmap, in_axes=(None, 0, 0, None), axis_name='j')
-# def f(x, y, z, w):
-#   return jax.lax.axis_index(['i', 'j']) + x * y + z + w
-
-# print(f(jnp.arange(4.), jnp.arange(2.), jnp.arange(8.).reshape((4, 2)), 100.))
-
-# with jax.disable_jit():
-#   print(f(jnp.arange(4.), jnp.arange(2.), jnp.arange(8.).reshape((4, 2)), 100.))
-
-device_count = jax.device_count()
-
-# @functools.partial(jax.pmap, axis_name='i')
-# def f(x):
-#   @functools.partial(jax.pmap, axis_name='j')
-#   def g(y):
-#     a = jax.lax.psum(1, 'i')
-#     b = jax.lax.psum(1, 'j')
-#     c = jax.lax.psum(1, ('i', 'j'))
-#     return a, b, c
-#   return g(x)
-
-# import numpy as np
-# shape = (device_count, 1, 4)
-# x = jnp.arange(np.prod(shape)).reshape(shape)
-# a, b, c = f(x)
-# print(a)
-# print(b)
-# print(c)
-
-# with jax.disable_jit():
-#   a, b, c = f(x)
-#   print(a)
-#   print(b)
-#   print(c)
-
-# f = lambda axis: jax.pmap(jax.pmap(lambda x: x + jax.lax.axis_index(axis), 'j'), 'i')
-# x = jnp.ones((2, 2), dtype='int32')
-# print(f('i')(x))
-# print(f('j')(x))
-
-# with jax.disable_jit():
-#   print(f('i')(x))
-#   print(f('j')(x))
-
-
-# def f(key):
-#   key = jax.random.fold_in(key, jax.lax.axis_index('i'))
-#   return jax.random.bernoulli(key, p=0.5)
-
-# keys = jax.random.split(jax.random.PRNGKey(0), len(jax.devices()))
-
-# print(jax.pmap(jax.remat(f), axis_name='i')(keys))
-
-# with jax.disable_jit():
-#   print(jax.pmap(jax.remat(f), axis_name='i')(keys))
-
-
-# jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
-# with jax.disable_jit():
-#   jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
-
-# jax.pmap(lambda x: x)(jnp.zeros(jax.device_count() + 1))
-# with jax.disable_jit():
-#   jax.pmap(jax.pmap(jnp.square))(jnp.arange(16).reshape((4, 4)))
-
-# f = jax.pmap(lambda x: jax.pmap(lambda x: x)(x))
-# x = jnp.ones((jax.device_count(), 2, 10))
-# f(x)
-# with jax.disable_jit():
-#   print(f(x))
-
-f = jax.pmap(jax.pmap(lambda x: 3))
-shape = (2, jax.device_count() // 2, 3)
-x = jnp.arange(np.prod(shape)).reshape(shape)
-print(f(x))
-with jax.disable_jit():
-  print(f(x))
-# TODO:
-# * [x] process_call
-#   * jit-of-emap = pmap (already, b/c we're only changing the pmap impl)
-#   * emap-of-jit = our processs_call rule should act same as initial style HOPs
-#   * emap-of-core.call = do a subtrace like thing where we turn around and stay
-#     in python
-# * [ ] collectives
-# * [ ] testing
-# * [ ] nesting (process_map, sublift, etc)
-# * [ ] shadowing of names
-
-# * delete process_call and core.call, just have an xla_call rule
-#   * no call updaters!
-#   * blocked on delete old remat
-#   * first just make xla_call have its own rule
-#   * then make it INITIAL STYLE
-#     * make it take closed jaxprs, so we can delete core.closed_call
-#   * delete all updaters
-
-
-
-
-
-
-
-# brainstorming process_map
-# assert map_primitive is xla_pmap_p
-# backend, axis_size, axis_name = (
-#     params['backend'], params['axis_size'], params['axis_name'])
-# if config.jax_disable_jit:
-#   shape = [f.size for f in self._get_frames()]
-#   devices = xb.devices(backend=backend)[:prod(shape) * axis_size]
-#   breakpoint()
-#   # def reshard(x: jnp.ndarray, devices: Array[devices]):
-#   #   assert x.ndim == devices.ndim
-#   # e.g . x.shape = (4, 3, 2)
-#   #      devices.shape = (4, 1, 2)
-#   # reshard(x, devices.reshape(4, 1, 2))
-#   sharded_args = [jax.device_put_sharded(list(x), devices) for x in args]
-#   with core.new_sublevel(), core.extend_axis_env(axis_name, axis_size, main):
-#     t = main.with_cur_sublevel()
-#     shard_axes = {axis_name: 0}
-#     tracers = [MapTracer(t, arg, shard_axes) for arg in sharded_args]
-#     ans = fun.call_wrapped(*tracers)
-#     out_tracers = map(t.full_raise, ans)
-#     outvals = [t.val for t in out_tracers]
-#     return outvals
-# else:
-#   breakpoint()
@@ -727,39 +727,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
     self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)

-  @parameterized.named_parameters(
-      {"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""),
-       "device_mesh_shape": device_mesh_shape}
-      for device_mesh_shape in [(1, 2)])
-  def testNestedWithClosure2(self, device_mesh_shape):
-    mesh_shape = self._getMeshShape(device_mesh_shape)
-
-    @partial(self.pmap, axis_name='i')
-    def test_fun(x):
-      y = jnp.sum(jnp.sin(x))
-
-      @partial(self.pmap, axis_name='j')
-      def g(z):
-        return jnp.exp(x.sum())
-      return grad(lambda w: jnp.sum(g(w)))(x)
-
-    @vmap
-    def baseline_fun(x):
-      y = jnp.sum(jnp.sin(x))
-
-      @vmap
-      def g(z):
-        return jnp.exp(x.sum())
-      return grad(lambda w: jnp.sum(g(w)))(x)
-
-    shape = mesh_shape
-    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
-    print(x.shape)
-
-    ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
-    expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
-    self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)
-
   def testShardedDeviceArrays(self):
     f = lambda x: 2 * x
     f = self.pmap(f, axis_name='i')
@@ -3087,8 +3054,6 @@ class ArrayPmapTest(jtu.JaxTestCase):
     self.assertArraysEqual(out2, input_data)

   def test_pmap_array_in_axes_out_axes(self):
-    if config.jax_disable_jit:
-      raise SkipTest("test hangs with disable_jit") # TODO(mattjj,sharadmv,yashkatariya)
     dc = jax.device_count()
     input_shape = (dc, 2)
     a1, input_data = create_input_array_for_pmap(input_shape, in_axes=0)
@@ -3180,14 +3145,17 @@ class ArrayVmapPmapCollectivesTest(ArrayPmapMixin, VmapPmapCollectivesTest):
 class ArrayPmapWithDevicesTest(ArrayPmapMixin, PmapWithDevicesTest):
   pass

-class EagerPmapMixin: # i hate mixins
+class EagerPmapMixin:

   def setUp(self):
     super().setUp()
+    self.eager_pmap_enabled = config.jax_eager_pmap
     self.jit_disabled = config.jax_disable_jit
     config.update('jax_disable_jit', True)
+    config.update('jax_eager_pmap', True)

   def tearDown(self):
+    config.update('jax_eager_pmap', self.eager_pmap_enabled)
     config.update('jax_disable_jit', self.jit_disabled)
     super().tearDown()

@@ -3203,6 +3171,9 @@ class EagerPmapWithDevicesTest(EagerPmapMixin, PmapWithDevicesTest):
 class EagerVmapOfPmapTest(EagerPmapMixin, VmapOfPmapTest):
   pass

+class EagerArrayPmapTest(EagerPmapMixin, ArrayPmapTest):
+  pass
+

 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
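The EagerPmapMixin pattern above (save both config values in setUp, flip them on, restore them in tearDown) is also handy outside the test suite. A hedged sketch, not part of the diff: eager_pmap is my own helper name, and it assumes a jax build that defines the jax_eager_pmap state; it packages the same config.update calls as a context manager.

import contextlib
import jax

@contextlib.contextmanager
def eager_pmap():
  # Mirror the mixin: remember the current settings, enable eager pmap along
  # with disable_jit, and restore both values on exit.
  prev_eager = jax.config.jax_eager_pmap
  prev_disable = jax.config.jax_disable_jit
  jax.config.update('jax_eager_pmap', True)
  jax.config.update('jax_disable_jit', True)
  try:
    yield
  finally:
    jax.config.update('jax_eager_pmap', prev_eager)
    jax.config.update('jax_disable_jit', prev_disable)

# Example use: run a pmapped function op by op for debugging.
# with eager_pmap():
#   jax.pmap(lambda x: x * 2)(jnp.arange(jax.local_device_count()))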