mirror of https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
del serial_pmap, simpler papply, add parallelize
The serial_pmap transformation was a placeholder and is now replaced by soft_pmap. The papply tests that used serial_pmap now use soft_pmap, which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes) that won't be needed by parallelize, so those are removed. It is also now only needed for testing, since parallelize (which essentially composes a soft_pmap with a papply) is likely to be the primary user-facing API.

This commit adds the parallelize transformation and some tests for it, including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it doesn't need to perform communication)
* fixed misc bugs encountered along the way
* a few lines cherry-picked from frostig@ branch, namely the fixed broadcasting_papply rule and plumbing the `size` argument to papply rules
* removed the psplit and psplit_like primitives and replaced them with calls to all_to_all where needed
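As a usage sketch (mirroring the new ParallelizeTest.testNormalize test in the diff below; note that _parallelize is still a private, underscore-prefixed API in this commit):

    import numpy as onp
    from jax.api import _parallelize, make_jaxpr

    def f(x):
      return x / x.sum(0)

    x = onp.arange(4.)
    ans = _parallelize(f)(x)  # same values as f(x), computed on sharded inputs
    # the papply transform inserts a collective for the x.sum(0) reduction, so
    # the traced jaxpr contains a psum:
    jaxpr = make_jaxpr(_parallelize(f))(x)
    assert 'psum' in repr(jaxpr)

Under the hood, _parallelize papply-transforms f to shard its inputs along a fresh axis name, splits that axis into device-sized chunks via pxla.split_axis, runs the result under xla_pmap, and falls back to plain pmap when there are fewer elements than devices.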
This commit is contained in:
parent e36613da90
commit d64188bcb6
jax/api.py (55 changed lines)
@@ -599,7 +599,7 @@ def pmap(fun, axis_name=None):

 class _TempAxisName(object):
   def __repr__(self):
-    return '<temp axis {}>'.format(hex(id(self)))
+    return '<axis {}>'.format(hex(id(self)))

 def _pmap_axis_size(args):
   leaves, _ = tree_flatten(args)
@@ -639,8 +639,8 @@ def soft_pmap(fun, axis_name=None):
       return pmap(fun, axis_name)(*args)  # can map directly onto hardware
     elif leftover:
       raise ValueError
-
     num_chunks = axis_size // chunk_size
+
     f = lu.wrap_init(fun)
     in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
@@ -680,37 +680,46 @@ def _reshape_merge(ans):
   return merge(batching.get_aval(ans), ans)


-def _serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
-  """Vectorizing pseudo-map for single-program multiple-data (SPMD) functions."""
-  axis_name = _TempAxisName() if axis_name is None else axis_name
-
-  def map_fun(*args, **kwargs):
-    f = lu.wrap_init(fun, kwargs)
-    in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
-    in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
-    jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
-    out_flat = parallel.serial_pmap(jaxtree_fun, axis_name, in_flat, in_axes_, out_axes)
-    return build_tree(out_tree(), out_flat)
-
-  return map_fun
-
-
-def _papply(fun, axis_size, in_axes=0, out_axes=0):
-  """Apply a function using parallel computation by sharding inputs."""
-  axis_name = parallel.newvar()
+def _papply(fun):
+  # This function is for testing purposes.
+  axis_name = _TempAxisName()

   def papply_fun(*args, **kwargs):
+    axis_size = _pmap_axis_size(args)
     f = lu.wrap_init(fun, kwargs)
-    in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
     args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
-    out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, axis_size,
-                               in_axes_, out_axes)
+    out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, axis_size)
     return build_tree(out_tree(), out_flat)

   return papply_fun, axis_name


+def _parallelize(fun):
+  axis_name = _TempAxisName()
+
+  def pfun(*args):
+    axis_size = _pmap_axis_size(args)
+    chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
+    if chunk_size == 0 and leftover:
+      return pmap(fun, axis_name)(*args)  # can map directly onto hardware
+    elif leftover:
+      raise ValueError
+    num_chunks = axis_size // chunk_size
+
+    f = lu.wrap_init(fun)
+    args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
+    args_flat = map(partial(_reshape_split, num_chunks), args_flat)
+    f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
+    f, out_axis = parallel.papply_transform(f, axis_name, axis_size)
+    f = pxla.split_axis(f, axis_name, chunk_size)
+    out = pxla.xla_pmap(f, *args_flat, axis_name=axis_name, axis_size=num_chunks)
+    out = parallel.match_axis(0, out_axis(), _reshape_merge(out))
+    return build_tree(out_tree(), out)
+
+  return pfun


 def jvp(fun, primals, tangents):
   """Computes a (forward-mode) Jacobian-vector product of `fun`.

@@ -458,7 +458,7 @@ def moveaxis2(src, dst, x):
   return _moveaxis2(src, dst, x, get_aval(x))

 def _moveaxis2(src, dst, x, aval):
-  if type(aval) is JaxTuple:
+  if type(aval) is AbstractTuple:
     return core.pack(map(partial(_moveaxis2, src, dst), x, aval))
   else:
     perm = [i for i in range(onp.ndim(x)) if i != src]
@@ -480,7 +480,7 @@ def _broadcast2(size, axis, x, aval):

 def _promote_aval_rank(n, batched, aval):
   assert isinstance(aval, core.AbstractValue)
-  if isinstance(aval, AbstractTuple):
+  if type(aval) is AbstractTuple:
     t = type(batched)
     if t is tuple:
       return AbstractTuple(map(partial(_promote_aval_rank, n), batched, aval))
@@ -39,180 +39,48 @@ zip = safe_zip
 def identity(x): return x


-### serial_pmap is like pmap but executes in a single-machine vectorized way
+### papply


-def serial_pmap(fun, name, in_vals, in_axes, out_axis_target):
-  sizes = reduce(set.union, map(batching.dimsize, in_axes, in_vals))
-  if not sizes:
-    return fun.call_wrapped(*in_vals)
-  elif len(sizes) == 1:
-    fun, out_axis = serial_pmap_transform(fun, name, in_axes)
-    out_val = fun.call_wrapped(*in_vals)
-    return batching.moveaxis(sizes.pop(), out_axis_target, out_axis(), out_val)
-  else:
-    raise TypeError("got inconsistent map dimension sizes: {}".format(sizes))
+def papply(fun, name, in_vals, axis_size):
+  # this function is for testing purposes, so we drop the out_axis
+  fun, _ = papply_transform(fun, name, axis_size)
+  return fun.call_wrapped(*in_vals)

 @lu.transformation_with_aux
-def serial_pmap_transform(name, axes, *vals):
-  with new_master(SerialPmapTrace) as master:
-    trace = SerialPmapTrace(master, core.cur_sublevel())
-    in_tracers = map(partial(SerialPmapTracer, trace, name), vals, axes)
+def papply_transform(name, axis_size, *args):
+  with new_master(PapplyTrace) as master:
+    trace = PapplyTrace(master, core.cur_sublevel())
+    in_tracers = map(partial(PapplyTracer, trace, name, axis_size, axis=0), args)
     ans = yield in_tracers, {}
     out_tracer = trace.full_raise(ans)
     out_val, out_axis = out_tracer.val, out_tracer.axis
     del master, out_tracer
   yield out_val, out_axis

-@lu.transformation_with_aux
-def serial_pmap_subtrace(master, name, axes, *vals):
-  trace = SerialPmapTrace(master, core.cur_sublevel())
-  ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes), {}
-  out_tracer = trace.full_raise(ans)
-  out_val, out_axis = out_tracer.val, out_tracer.axis
-  yield out_val, out_axis
-
-class SerialPmapTracer(Tracer):
-  def __init__(self, trace, name, val, axis):
-    self.trace = trace
-    self.name = name
-    self.val = val
-    self.axis = axis
-
-  @property
-  def aval(self):
-    batched_aval = batching.get_aval(self.val)
-    return batching.remove_batch_dim_from_aval(self.axis, batched_aval)
-
-  def unpack(self):
-    t = type(self.axis)
-    if t is tuple:
-      axes = self.axis
-    elif t is int:
-      axes = [self.axis] * len(self.val)
-    elif t is type(None):
-      return tuple(self.val)
-    else:
-      raise TypeError(t)
-    return map(partial(SerialPmapTracer, self.trace, self.name), self.val, axes)
-
-  def full_lower(self):
-    if self.axis is None:
-      return core.full_lower(self.val)
-    else:
-      return self
-
-class SerialPmapTrace(Trace):
-  def pure(self, val):
-    return SerialPmapTracer(self, None, val, None)
-
-  def lift(self, val):
-    return SerialPmapTracer(self, None, val, None)
-
-  def sublift(self, val):
-    return SerialPmapTracer(self, val.name, val.val, val.axis)
-
-  def process_primitive(self, primitive, tracers, params):
-    names_in, vals_in, axes_in = unzip3((t.name, t.val, t.axis) for t in tracers)
-    if all(axis is None for axis in axes_in):
-      return primitive.bind(*vals_in, **params)
-    else:
-      name = next(name for name in names_in if name is not None)  # all same
-      if primitive in serial_pmap_primitive_rules:
-        # if it's a pmap collective primitive, do something special
-        if name == params['axis_name']:
-          # if the name matches this tracer's name, apply the pmap rule
-          rule = serial_pmap_primitive_rules[primitive]
-          params = {k: params[k] for k in params if k != 'axis_name'}
-          val_out, axis_out = rule(vals_in, axes_in, **params)
-          return SerialPmapTracer(self, name, val_out, axis_out)
-        else:
-          # if not, bind the primitive so that any other pmap tracers can see it,
-          # assuming an axis equal to that of the first operand
-          val_out = primitive.bind(*vals_in, **params)
-          return SerialPmapTracer(self, name, val_out, axes_in[0])
-      else:
-        # if it's not a pmap collective primitive, act just like vmap
-        rule = batching.get_primitive_batcher(primitive)
-        val_out, axis_out = rule(vals_in, axes_in, **params)
-        return SerialPmapTracer(self, name, val_out, axis_out)
-
-  def process_call(self, call_primitive, f, tracers, params):
-    names, vals, axes = unzip3((t.name, t.val, t.axis) for t in tracers)
-    if all(axis is None for axis in axes):
-      return call_primitive.bind(f, *vals, **params)
-    else:
-      name = next(name for name in names if name is not None)  # all same
-      f, axis_out = serial_pmap_subtrace(f, self.master, name, axes)
-      val_out = call_primitive.bind(f, *vals, **params)
-      return SerialPmapTracer(self, name, val_out, axis_out())
-
-  def post_process_call(self, call_primitive, out_tracer, params):
-    name, val, axis = out_tracer.name, out_tracer.val, out_tracer.axis
-    master = self.master
-    def todo(x):
-      trace = SerialPmapTrace(master, core.cur_sublevel())
-      return SerialPmapTracer(trace, name, x, axis)
-
-    return val, todo
-
-  def pack(self, tracers):
-    vals = core.pack([t.val for t in tracers])
-    axis = tuple(t.axis for t in tracers)
-    name = next(t.name for t in tracers if t.name)
-    return SerialPmapTracer(self, name, vals, axis)
-
-
-serial_pmap_primitive_rules = {}
-
-
-### papply
-
-
-newvar = pe.gensym('_axis')
-
-def papply(fun, name, in_vals, axis_size, in_axes, out_axis):
-  out_val = papply_transform(fun).call_wrapped(
-      name, in_vals, axis_size, in_axes, out_axis)
-  return out_val
-
-def ensure_axis(dst, src, x):
-  aval = batching.get_aval(x)
-  if type(aval) is core.AbstractTuple:
-    if type(src) is tuple and type(dst) is tuple:
-      return core.pack(map(ensure_axis, dst, src, x))
-    elif type(src) is tuple:
-      return core.pack(map(partial(ensure_axis, dst), src, x))
-    elif type(dst) is tuple:
-      srcs = (src,) * len(dst)
-      return core.pack(map(ensure_axis, dst, srcs, x))
-    else:
-      return core.pack(map(partial(ensure_axis, dst, src), x))
-  elif isinstance(aval, ShapedArray):
-    if src == dst:
-      return x
-    elif src is None:
-      warnings.warn('split output axis requested for an array with no split')
-      return x
-    else:
-      perm = list(range(x.ndim))
-      perm[src] = dst
-      perm[dst] = src
-      return x.transpose(perm)
-  else:
-    raise TypeError(type(aval))
+def match_axis(src, dst, x):
+  assert type(src) is int
+  if src == dst:
+    return x
+  else:
+    return _match_axis(src, dst, x, core.get_aval(x))

-@lu.transformation
-def papply_transform(name, args, axis_size, in_axes, out_axis):
-  with new_master(PapplyTrace) as master:
-    trace = PapplyTrace(master, core.cur_sublevel())
-    in_tracers = map(partial(PapplyTracer, trace, name, axis_size), args, in_axes)
-    out_tracer = yield in_tracers, {}
-    out_tracer = trace.full_raise(out_tracer)
-    out_tracer = ensure_axis(out_axis, out_tracer.axis, out_tracer)
-    out_val = out_tracer.val
-    del master, out_tracer
-  yield out_val
+def _match_axis(src, dst, x, aval):
+  if type(aval) is core.AbstractTuple:
+    if type(dst) is tuple:
+      return core.pack(map(partial(_match_axis, src), dst, x, aval))
+    else:
+      return core.pack(map(partial(_match_axis, src, dst), x, aval))
+  elif isinstance(aval, ShapedArray):
+    if type(dst) is int:
+      perm = [i for i in range(x.ndim) if i != src]
+      perm.insert(dst, src)
+      return x.transpose(perm)
+    elif dst is None:
+      return x[src]
+    else:
+      raise TypeError(dst)
+  else:
+    raise TypeError(aval)

 class PapplyTracer(Tracer):
   def __init__(self, trace, name, axis_size, val, axis):
@@ -315,7 +183,8 @@ def broadcasting_papply(prim, name, size, vals, axes, **params):
   elif xdim == ydim:
     return prim.bind(x, y, **params), xdim
   else:
-    x = psplit(x, axis_name, ydim, xdim)
+    from jax.lax.lax_parallel import all_to_all  # TODO circular deps
+    x = all_to_all(x, name, ydim - int(xdim <= ydim), xdim)
     return prim.bind(x, y, **params), ydim


@@ -2966,6 +2966,7 @@ ad.primitive_transposes[gather_p] = _gather_transpose_rule
 batching.primitive_batchers[gather_p] = _gather_batching_rule
+parallel.papply_primitive_rules[gather_p] = _gather_papply_rule


 class ScatterDimensionNumbers(collections.namedtuple(
     "ScatterDimensionNumbers",
     ["update_window_dims", "inserted_window_dims",
@@ -159,14 +159,11 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
   """
   if psum(1, axis_name) != x.shape[split_axis]:
     msg = ("all_to_all requires the size of the mapped axis axis_name to equal "
-        "x.shape[split_axis], but they are {} and {} respectively.")
+           "x.shape[split_axis], but they are {} and {} respectively.")
     raise ValueError(msg.format(psum(1, axis_name), x.shape[split_axis]))
   return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                            axis_name=axis_name)

-def psplit_like(x, y, axis_name):
-  """Ensure the named mapped axis of ``x`` aligns with that of ``y``."""
-  return psplit_like_p.bind(x, y, axis_name=axis_name)
-
 def pcollect(x, axis_name):
   return pcollect_p.bind(x, axis_name=axis_name)
@@ -181,11 +178,6 @@ def standard_pmap_primitive(name):
   return prim


-def _allreduce_serial_pmap_rule(reducer, vals, axes):
-  val, = vals
-  axis, = axes
-  return reducer(val, [axis]), None
-
 def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name):
   assert tuple(which_mapped) == (True,)
   x, = vals
@@ -199,8 +191,6 @@ def _allreduce_translation_rule(prim, c, val, replica_groups):

 psum_p = standard_pmap_primitive('psum')
 parallel.defreducer(lax.reduce_sum_p, psum_p)
-parallel.serial_pmap_primitive_rules[psum_p] = \
-    partial(_allreduce_serial_pmap_rule, lax._reduce_sum)
 pxla.split_axis_rules[psum_p] = \
     partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
 pxla.parallel_translation_rules[psum_p] = \
@@ -211,18 +201,18 @@ ad.deflinear(psum_p, lambda t, axis_name: [t])

 pmax_p = standard_pmap_primitive('pmax')
 parallel.defreducer(lax.reduce_max_p, pmax_p)
-parallel.serial_pmap_primitive_rules[pmax_p] = \
-    partial(_allreduce_serial_pmap_rule, lax._reduce_max)
 pxla.parallel_translation_rules[pmax_p] = \
     partial(_allreduce_translation_rule, lax.max_p)
+pxla.split_axis_rules[pmax_p] = \
+    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)


 pmin_p = standard_pmap_primitive('pmin')
 parallel.defreducer(lax.reduce_min_p, pmin_p)
-parallel.serial_pmap_primitive_rules[pmin_p] = \
-    partial(_allreduce_serial_pmap_rule, lax._reduce_min)
 pxla.parallel_translation_rules[pmin_p] = \
     partial(_allreduce_translation_rule, lax.min_p)
+pxla.split_axis_rules[pmin_p] = \
+    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)


 def _ppermute_translation_rule(c, x, replica_groups, perm):
@@ -268,35 +258,11 @@ def _moveaxis(src, dst, x):
   perm.insert(dst, src)
   return lax.transpose(x, perm)


 all_to_all_p = standard_pmap_primitive('all_to_all')
 pxla.parallel_translation_rules[all_to_all_p] = _all_to_all_translation_rule
 pxla.split_axis_rules[all_to_all_p] = _all_to_all_split_axis_rule


-def _psplit_like_serial_pmap_rule(vals, axes):
-  x, y = vals
-  xaxis, yaxis = axes
-  if xaxis is not None and x.shape[xaxis] != x.shape[yaxis]:
-    raise ValueError(
-        "psplit_like is a non-square re-split along {} and {} of {}".format(
-            xaxis, yaxis, x.shape))
-  return x, yaxis
-
-psplit_like_p = standard_pmap_primitive('psplit_like')
-psplit_like_p.def_abstract_eval(
-    lambda x, y, *args, **kwargs: ShapedArray(y.shape, x.dtype))
-parallel.serial_pmap_primitive_rules[psplit_like_p] = _psplit_like_serial_pmap_rule
-
-
-def _pcollect_serial_pmap_rule(vals, axes):
-  x, = vals
-  return x, None
-
-pcollect_p = standard_pmap_primitive('pcollect')
-parallel.serial_pmap_primitive_rules[pcollect_p] = _pcollect_serial_pmap_rule
-
-
 ### papply rules
 # TODO(skye): it would be nice if we could put these with their corresponding
 # primitives, but that currently causes circular dependencies. More refactoring
@@ -407,35 +373,24 @@ def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
 def _transpose_papply_rule(name, size, vals, dims, permutation):
   x, = vals
   xdim, = dims
-  perm = list(permutation)
-  if perm[xdim] == xdim:
-    perm = [i - 1 if i > xdim else i for i in perm if i != xdim]
-    x = lax.transpose(x, perm)
-    out_dim = xdim
-  else:
-    in_dim, = [i for i in range(len(perm)) if perm[i] == xdim]
-    out_dim = perm[xdim]
-    perm[in_dim] = out_dim
-    perm[out_dim] = in_dim
-    perm = perm[:xdim] + perm[xdim + 1:]
-    perm = [i - 1 if i > xdim else i for i in perm]
-    x = lax.transpose(x, perm)
-    x = pswapaxes(x, name, in_dim)
-  return x, xdim
+  local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim]
+  return lax.transpose(x, local_perm), permutation.index(xdim)


 def _select_papply_rule(name, size, vals, dims):
-  dimset = set([d for d in dims if d is not None])
+  dimset = {d for d in dims if d is not None}
   if len(dimset) != 1:
     raise NotImplementedError(
         'papply of select with operands split along different dimensions')
-  like_val, like_dim = [(v, d) for v, d in zip(vals, dims) if d is not None][0]
+  dim, = dimset
+
+  def drop(x, d):
+    if d is None:
+      return lax.dynamic_index_in_dim(x, axis_index(name), dim, False)
+    else:
+      return x

-  def normalize_split(val, dim):
-    return psplit_like(val, like_val, name) if dim is None else val
-
-  vals = [normalize_split(v, d) for v, d in zip(vals, dims)]
-  return lax.select_p.bind(*vals), like_dim
+  return lax.select_p.bind(*map(drop, vals, dims)), dim


 def _add_jaxvals_papply_rule(name, size, vals, dims):
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import itertools
 from unittest import SkipTest

 import numpy as onp
@@ -25,80 +26,30 @@ from absl.testing import parameterized
 import jax.numpy as np
 from jax import test_util as jtu
 from jax import lax
-from jax.api import _serial_pmap, _papply, jit, make_jaxpr
+from jax.api import _papply, _parallelize, soft_pmap, jit, make_jaxpr
 from jax.linear_util import wrap_init
+from jax.util import prod

 from jax.config import config
 config.parse_flags_with_absl()


-class SerialPmapTest(jtu.JaxTestCase):
-
-  def testConstantFunction(self):
-    f = lambda x: 3
-    ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
-    expected = 3 * onp.ones(4)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testReduceSum(self):
-    f = lambda x: lax.psum(x, 'i')
-    ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
-    expected = 4 * onp.ones(4)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testReduceMax(self):
-    f = lambda x: lax.pmax(x, 'i')
-    ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
-    expected = 3 * onp.ones(4)
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testPsplit(self):
-    f = lambda x: lax.psplit(x, 'i', 2, 0)
-    arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
-    ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
-    expected = arg
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testPsplitLike(self):
-    f = lambda x, y: lax.psplit_like(x, y, 'i')
-    arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
-    ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
-    expected = arg
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testLogSoftmax(self):
-    f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
-    x = onp.log(onp.arange(1., 10., dtype=onp.float32))
-    ans = _serial_pmap(f, axis_name='i')(x)
-    expected = x - onp.log(onp.sum(onp.exp(x)))
-    self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testNested(self):
-    f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
-    x = onp.ones((2, 2))
-    ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
-    ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
-    expected = 4 * onp.ones((2, 2))
-    self.assertAllClose(ans1, expected, check_dtypes=False)
-    self.assertAllClose(ans2, expected, check_dtypes=False)
-
-
 class PapplyTest(jtu.JaxTestCase):

   def testIdentity(self):
-    pfun, axis_name = _papply(lambda x: x, 3)
+    pfun, axis_name = _papply(lambda x: x)
     ans = pfun(onp.arange(3))
     expected = onp.arange(3)
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testMap(self):
-    pfun, axis_name = _papply(np.sin, 3)
+    pfun, axis_name = _papply(np.sin)
     ans = pfun(onp.arange(3.))
     expected = onp.sin(onp.arange(3.))
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testSum(self):
-    pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)
+    pfun, axis_name = _papply(lambda x: np.sum(x, axis=0))

     jaxpr = make_jaxpr(pfun)(onp.ones(3))
     expected_jaxpr = make_jaxpr(
@@ -106,12 +57,12 @@ class PapplyTest(jtu.JaxTestCase):
     assert repr(jaxpr) == repr(expected_jaxpr)

     arg = onp.arange(15.).reshape((5, 3))
-    ans = _serial_pmap(pfun, axis_name)(arg)[0]
+    ans = soft_pmap(pfun, axis_name)(arg)[0]
     expected = onp.sum(arg, axis=0)
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testMax(self):
-    pfun, axis_name = _papply(lambda x: np.max(x, axis=0), 5)
+    pfun, axis_name = _papply(lambda x: np.max(x, axis=0))

     jaxpr = make_jaxpr(pfun)(onp.ones(3))
     expected_jaxpr = make_jaxpr(
@@ -119,30 +70,20 @@ class PapplyTest(jtu.JaxTestCase):
     assert repr(jaxpr) == repr(expected_jaxpr)

     arg = onp.arange(15.).reshape((5, 3))
-    ans = _serial_pmap(pfun, axis_name)(arg)[0]
+    ans = soft_pmap(pfun, axis_name)(arg)[0]
     expected = onp.max(arg, axis=0)
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testSelect(self):
-    pfun, axis_name = _papply(lax.select, 5,
-                              in_axes=(None, 0, None))
-
     p = onp.arange(15).reshape((5, 3)) % 4 == 1
-    t = onp.ones((5, 3))
     f = onp.zeros((5, 3))
-    jaxpr = make_jaxpr(pfun)(p, t[0], f)
-
-    def expected_spmd(p, t, f):
-      return lax.select(
-          lax.psplit_like(p, t, axis_name),
-          t,
-          lax.psplit_like(f, t, axis_name))
+
+    def fun(t):
+      return lax.select(p, t, f)

-    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
-    assert repr(jaxpr) == repr(expected_jaxpr)
-
-    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
-    expected = lax.select(p, t, f)
+    t = onp.ones((5, 3))
+    ans = soft_pmap(*_papply(fun))(t)
+    expected = fun(t)
     self.assertAllClose(ans, expected, check_dtypes=True)

   def testLogSoftmax(self):
@@ -151,14 +92,14 @@ class PapplyTest(jtu.JaxTestCase):
     def fun(x):
       return x - np.log(np.sum(np.exp(x)))

-    pfun, axis_name = _papply(fun, 5)
+    pfun, axis_name = _papply(fun)

     jaxpr = make_jaxpr(pfun)(onp.zeros(5))
     expected_jaxpr = make_jaxpr(
         lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
     assert repr(jaxpr) == repr(expected_jaxpr)

-    ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
+    ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
     expected = fun(onp.arange(1., 5.))
     self.assertAllClose(ans, expected, check_dtypes=False)

@@ -166,8 +107,8 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.array([[1, 2, 3], [4, 5, 6]])
     expected = x + x

-    pfun, axis_name = _papply(np.add, 2)
-    ans = _serial_pmap(pfun, axis_name)(x, x)
+    pfun, axis_name = _papply(np.add)
+    ans = soft_pmap(pfun, axis_name)(x, x)
     self.assertAllClose(ans, expected, check_dtypes=True)

   def testAddBroadcasting(self):
@@ -179,39 +120,47 @@ class PapplyTest(jtu.JaxTestCase):
     x = onp.array([[1, 2], [3, 4]])
     expected = x + 3

-    pfun, axis_name = _papply(fun, 2)
-    ans = _serial_pmap(pfun, axis_name)(x)
+    pfun, axis_name = _papply(fun)
+    ans = soft_pmap(pfun, axis_name)(x)
     self.assertAllClose(ans, expected, check_dtypes=True)

-  def testTranspose(self):
+
+class ParallelizeTest(jtu.JaxTestCase):
+
+  def testNormalize(self):
+    def f(x):
+      return x / x.sum(0)
+
+    x = onp.arange(4.)
+    expected = f(x)
+    ans = _parallelize(f)(x)
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
+    jaxpr = make_jaxpr(_parallelize(f))(x)
+    self.assertIn('psum', repr(jaxpr))
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "testTranspose_shape={}_perm={}"
+       .format(shape, perm),
+       "shape": shape, "perm": perm}
+      for shape in [
+          (2, 2),
+          (3, 3),
+          (2, 2, 2),
+          (2, 3, 4),
+          (2, 3, 2)
+      ]
+      for perm in itertools.permutations(list(range(len(shape))))
+      ))
+  def testTranspose(self, shape, perm):

     def fun(x):
-      return x.T
+      return lax.transpose(x, perm)

-    xs = [
-        onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
-        onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
-    ]
-    for x in xs:
-      expected = x.T
-      pfun, axis_name = _papply(fun, x.shape[0])
-      ans = _serial_pmap(pfun, axis_name)(x)
-      self.assertAllClose(ans, expected, check_dtypes=False)
-
-  def testTransposeWithOddPermutation(self):
-
-    def fun(x):
-      return np.transpose(x, (2, 0, 1))
-
-    xs = [
-        onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2)),
-        onp.reshape(onp.arange(27., dtype=onp.float32), (3, 3, 3)),
-    ]
-    for x in xs:
-      expected = np.transpose(x, (2, 0, 1))
-      pfun, axis_name = _papply(fun, x.shape[0])
-      ans = _serial_pmap(pfun, axis_name)(x)
-      self.assertAllClose(ans, expected, check_dtypes=False)
+    x = onp.arange(prod(shape)).reshape(shape)
+    expected = fun(x)
+    ans = _parallelize(fun)(x)
+    self.assertAllClose(ans, expected, check_dtypes=False)

   def testTransposeAndAddRank2(self):
@@ -219,10 +168,8 @@ class PapplyTest(jtu.JaxTestCase):
     return x + x.T

     x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))
-    expected = x + x.T
-
-    pfun, axis_name = _papply(fun, 2)
-    ans = _serial_pmap(pfun, axis_name)(x)
+    expected = fun(x)
+    ans = _parallelize(fun)(x)
     self.assertAllClose(ans, expected, check_dtypes=False)

   def testTransposeAndAddRank3(self):
@@ -231,28 +178,26 @@ class PapplyTest(jtu.JaxTestCase):
     return x + x.T

     x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
-    expected = x + x.T
-
-    pfun, axis_name = _papply(fun, 2)
-    ans = _serial_pmap(pfun, axis_name)(x)
+    expected = fun(x)
+    ans = _parallelize(fun)(x)
     self.assertAllClose(ans, expected, check_dtypes=False)

-  def testDot(self):
-    return SkipTest("test doesn't pass yet")  # TODO(frostig)
+  # def testDot(self):
+  #   return SkipTest("test doesn't pass yet")  # TODO(frostig)

-    def fun(x, y):
-      return lax.dot(x, y)
-    xs = [
-        onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
-        onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
-    ]
-    in_axes_combos = [(0, 0), (0, 1)]  # [(1, 0)]
-    for in_axes in in_axes_combos:
-      for x in xs:
-        expected = fun(x, x)
-        pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
-        ans = _serial_pmap(pfun, axis_name)(x, x)
-        self.assertAllClose(ans, expected, check_dtypes=False)
+  #   def fun(x, y):
+  #     return lax.dot(x, y)
+  #   xs = [
+  #       onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
+  #       onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
+  #   ]
+  #   in_axes_combos = [(0, 0), (0, 1)]  # [(1, 0)]
+  #   for in_axes in in_axes_combos:
+  #     for x in xs:
+  #       expected = fun(x, x)
+  #       pfun, axis_name = _papply(fun)
+  #       ans = soft_pmap(pfun, axis_name)(x, x)
+  #       self.assertAllClose(ans, expected, check_dtypes=False)


 if __name__ == '__main__':
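For comparison with _parallelize, the papply-for-testing path exercised in these tests pairs _papply with soft_pmap; a minimal sketch based on the updated testLogSoftmax above:

    import numpy as onp
    import jax.numpy as np
    from jax.api import _papply, soft_pmap

    def fun(x):
      return x - np.log(np.sum(np.exp(x)))  # log-softmax over the mapped axis

    pfun, axis_name = _papply(fun)   # shard fun's input over a fresh axis name
    ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
    # agrees with fun(onp.arange(1., 5.)) up to floating-point tolerance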