del serial_pmap, simpler papply, add parallelize

The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.
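As a reminder of what the converted tests exercise, here is a minimal sketch of
driving soft_pmap directly on an SPMD function, mirroring the updated tests
below (soft_pmap is imported from jax.api as in those tests; the sketch assumes
the mapped axis size divides evenly over the available devices/chunks, since
soft_pmap raises ValueError otherwise):

    import numpy as onp
    from jax import lax
    from jax.api import soft_pmap

    # SPMD all-reduce over the named axis 'i', as in the former
    # SerialPmapTest.testReduceSum
    def f(x):
      return lax.psum(x, 'i')

    ans = soft_pmap(f, axis_name='i')(onp.ones(4))   # expected: 4 * onp.ones(4)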

The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
only needed for testing now, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.
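To make the intended composition concrete, here is a rough sketch of calling
_parallelize, taken from the testNormalize case added below (_parallelize is
still a private, underscore-prefixed helper in this commit, so treat the exact
spelling as provisional rather than a public API):

    import numpy as onp
    from jax.api import _parallelize

    def normalize(x):
      return x / x.sum(0)

    x = onp.arange(4.)
    ans = _parallelize(normalize)(x)   # same result as normalize(x); the sum(0)
                                       # becomes a psum collective over shards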

This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
  doesn't need to perform communication)
* fixed misc bugs encountered along the way
* cherry-picked a few lines from the frostig@ branch, namely the fixed
  broadcasting_papply rule and the plumbing of the `size` argument to papply
  rules
* removed the psplit and psplit_like primitives, replacing them with calls
  to all_to_all where needed (sketched just below)
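
For that last item, a rough sketch of the all_to_all collective that now stands
in for psplit (the signature, import path, and size requirement come from the
lax_parallel and papply-rule changes below; the two-device setup and shapes are
illustrative assumptions):

    import numpy as onp
    from jax import pmap
    from jax.lax.lax_parallel import all_to_all  # same import the papply rule uses

    # Inside a mapped function, all_to_all(x, 'i', split_axis, concat_axis)
    # re-maps the computation: the per-shard axis `split_axis` (whose size must
    # equal the size of the mapped axis 'i') becomes the new mapped axis, and
    # the old mapped axis is materialized at position `concat_axis`.
    def reshard(x):                      # per-device x: shape (2, 3)
      return all_to_all(x, 'i', split_axis=0, concat_axis=0)

    # f = pmap(reshard, axis_name='i')   # assumes 2 devices are available
    # f(onp.ones((2, 2, 3)))             # per-device output keeps shape (2, 3)
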
Matthew Johnson 2019-06-23 20:01:53 -07:00
parent e36613da90
commit d64188bcb6
6 changed files with 156 additions and 377 deletions


@@ -599,7 +599,7 @@ def pmap(fun, axis_name=None):
class _TempAxisName(object):
def __repr__(self):
return '<temp axis {}>'.format(hex(id(self)))
return '<axis {}>'.format(hex(id(self)))
def _pmap_axis_size(args):
leaves, _ = tree_flatten(args)
@@ -639,8 +639,8 @@ def soft_pmap(fun, axis_name=None):
return pmap(fun, axis_name)(*args) # can map directly onto hardware
elif leftover:
raise ValueError
num_chunks = axis_size // chunk_size
f = lu.wrap_init(fun)
in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
@@ -680,37 +680,46 @@ def _reshape_merge(ans):
return merge(batching.get_aval(ans), ans)
def _serial_pmap(fun, axis_name=None, in_axes=0, out_axes=0):
"""Vectorizing pseudo-map for single-program multiple-data (SPMD) functions."""
axis_name = _TempAxisName() if axis_name is None else axis_name
def map_fun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
in_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
out_flat = parallel.serial_pmap(jaxtree_fun, axis_name, in_flat, in_axes_, out_axes)
return build_tree(out_tree(), out_flat)
return map_fun
def _papply(fun, axis_size, in_axes=0, out_axes=0):
"""Apply a function using parallel computation by sharding inputs."""
axis_name = parallel.newvar()
def _papply(fun):
# This function is for testing purposes.
axis_name = _TempAxisName()
def papply_fun(*args, **kwargs):
axis_size = _pmap_axis_size(args)
f = lu.wrap_init(fun, kwargs)
in_axes_ = in_axes if isinstance(in_axes, (list, tuple)) else (in_axes,) * len(args)
args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, axis_size,
in_axes_, out_axes)
out_flat = parallel.papply(jaxtree_fun, axis_name, args_flat, axis_size)
return build_tree(out_tree(), out_flat)
return papply_fun, axis_name
def _parallelize(fun):
axis_name = _TempAxisName()
def pfun(*args):
axis_size = _pmap_axis_size(args)
chunk_size, leftover = divmod(axis_size, pxla.unmapped_device_count())
if chunk_size == 0 and leftover:
return pmap(fun, axis_name)(*args) # can map directly onto hardware
elif leftover:
raise ValueError
num_chunks = axis_size // chunk_size
f = lu.wrap_init(fun)
args_flat, in_trees = unzip2(map(pytree_to_jaxtupletree, args))
args_flat = map(partial(_reshape_split, num_chunks), args_flat)
f, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
f, out_axis = parallel.papply_transform(f, axis_name, axis_size)
f = pxla.split_axis(f, axis_name, chunk_size)
out = pxla.xla_pmap(f, *args_flat, axis_name=axis_name, axis_size=num_chunks)
out = parallel.match_axis(0, out_axis(), _reshape_merge(out))
return build_tree(out_tree(), out)
return pfun
def jvp(fun, primals, tangents):
"""Computes a (forward-mode) Jacobian-vector product of `fun`.


@@ -458,7 +458,7 @@ def moveaxis2(src, dst, x):
return _moveaxis2(src, dst, x, get_aval(x))
def _moveaxis2(src, dst, x, aval):
if type(aval) is JaxTuple:
if type(aval) is AbstractTuple:
return core.pack(map(partial(_moveaxis2, src, dst), x, aval))
else:
perm = [i for i in range(onp.ndim(x)) if i != src]
@@ -480,7 +480,7 @@ def _broadcast2(size, axis, x, aval):
def _promote_aval_rank(n, batched, aval):
assert isinstance(aval, core.AbstractValue)
if isinstance(aval, AbstractTuple):
if type(aval) is AbstractTuple:
t = type(batched)
if t is tuple:
return AbstractTuple(map(partial(_promote_aval_rank, n), batched, aval))


@@ -39,180 +39,48 @@ zip = safe_zip
def identity(x): return x
### serial_pmap is like pmap but executes in a single-machine vectorized way
### papply
def serial_pmap(fun, name, in_vals, in_axes, out_axis_target):
sizes = reduce(set.union, map(batching.dimsize, in_axes, in_vals))
if not sizes:
return fun.call_wrapped(*in_vals)
elif len(sizes) == 1:
fun, out_axis = serial_pmap_transform(fun, name, in_axes)
out_val = fun.call_wrapped(*in_vals)
return batching.moveaxis(sizes.pop(), out_axis_target, out_axis(), out_val)
else:
raise TypeError("got inconsistent map dimension sizes: {}".format(sizes))
def papply(fun, name, in_vals, axis_size):
# this function is for testing purposes, so we drop the out_axis
fun, _ = papply_transform(fun, name, axis_size)
return fun.call_wrapped(*in_vals)
@lu.transformation_with_aux
def serial_pmap_transform(name, axes, *vals):
with new_master(SerialPmapTrace) as master:
trace = SerialPmapTrace(master, core.cur_sublevel())
in_tracers = map(partial(SerialPmapTracer, trace, name), vals, axes)
def papply_transform(name, axis_size, *args):
with new_master(PapplyTrace) as master:
trace = PapplyTrace(master, core.cur_sublevel())
in_tracers = map(partial(PapplyTracer, trace, name, axis_size, axis=0), args)
ans = yield in_tracers, {}
out_tracer = trace.full_raise(ans)
out_val, out_axis = out_tracer.val, out_tracer.axis
del master, out_tracer
yield out_val, out_axis
@lu.transformation_with_aux
def serial_pmap_subtrace(master, name, axes, *vals):
trace = SerialPmapTrace(master, core.cur_sublevel())
ans = yield map(partial(SerialPmapTracer, trace, name), vals, axes), {}
out_tracer = trace.full_raise(ans)
out_val, out_axis = out_tracer.val, out_tracer.axis
yield out_val, out_axis
class SerialPmapTracer(Tracer):
def __init__(self, trace, name, val, axis):
self.trace = trace
self.name = name
self.val = val
self.axis = axis
@property
def aval(self):
batched_aval = batching.get_aval(self.val)
return batching.remove_batch_dim_from_aval(self.axis, batched_aval)
def unpack(self):
t = type(self.axis)
if t is tuple:
axes = self.axis
elif t is int:
axes = [self.axis] * len(self.val)
elif t is type(None):
return tuple(self.val)
else:
raise TypeError(t)
return map(partial(SerialPmapTracer, self.trace, self.name), self.val, axes)
def full_lower(self):
if self.axis is None:
return core.full_lower(self.val)
else:
return self
class SerialPmapTrace(Trace):
def pure(self, val):
return SerialPmapTracer(self, None, val, None)
def lift(self, val):
return SerialPmapTracer(self, None, val, None)
def sublift(self, val):
return SerialPmapTracer(self, val.name, val.val, val.axis)
def process_primitive(self, primitive, tracers, params):
names_in, vals_in, axes_in = unzip3((t.name, t.val, t.axis) for t in tracers)
if all(axis is None for axis in axes_in):
return primitive.bind(*vals_in, **params)
else:
name = next(name for name in names_in if name is not None) # all same
if primitive in serial_pmap_primitive_rules:
# if it's a pmap collective primitive, do something special
if name == params['axis_name']:
# if the name matches this tracer's name, apply the pmap rule
rule = serial_pmap_primitive_rules[primitive]
params = {k: params[k] for k in params if k != 'axis_name'}
val_out, axis_out = rule(vals_in, axes_in, **params)
return SerialPmapTracer(self, name, val_out, axis_out)
else:
# if not, bind the primitive so that any other pmap tracers can see it,
# assuming an axis equal to that of the first operand
val_out = primitive.bind(*vals_in, **params)
return SerialPmapTracer(self, name, val_out, axes_in[0])
else:
# if it's not a pmap collective primitive, act just like vmap
rule = batching.get_primitive_batcher(primitive)
val_out, axis_out = rule(vals_in, axes_in, **params)
return SerialPmapTracer(self, name, val_out, axis_out)
def process_call(self, call_primitive, f, tracers, params):
names, vals, axes = unzip3((t.name, t.val, t.axis) for t in tracers)
if all(axis is None for axis in axes):
return call_primitive.bind(f, *vals, **params)
else:
name = next(name for name in names if name is not None) # all same
f, axis_out = serial_pmap_subtrace(f, self.master, name, axes)
val_out = call_primitive.bind(f, *vals, **params)
return SerialPmapTracer(self, name, val_out, axis_out())
def post_process_call(self, call_primitive, out_tracer, params):
name, val, axis = out_tracer.name, out_tracer.val, out_tracer.axis
master = self.master
def todo(x):
trace = SerialPmapTrace(master, core.cur_sublevel())
return SerialPmapTracer(trace, name, x, axis)
return val, todo
def pack(self, tracers):
vals = core.pack([t.val for t in tracers])
axis = tuple(t.axis for t in tracers)
name = next(t.name for t in tracers if t.name)
return SerialPmapTracer(self, name, vals, axis)
serial_pmap_primitive_rules = {}
### papply
newvar = pe.gensym('_axis')
def papply(fun, name, in_vals, axis_size, in_axes, out_axis):
out_val = papply_transform(fun).call_wrapped(
name, in_vals, axis_size, in_axes, out_axis)
return out_val
def ensure_axis(dst, src, x):
aval = batching.get_aval(x)
if type(aval) is core.AbstractTuple:
if type(src) is tuple and type(dst) is tuple:
return core.pack(map(ensure_axis, dst, src, x))
elif type(src) is tuple:
return core.pack(map(partial(ensure_axis, dst), src, x))
elif type(dst) is tuple:
srcs = (src,) * len(dst)
return core.pack(map(ensure_axis, dst, srcs, x))
else:
return core.pack(map(partial(ensure_axis, dst, src), x))
elif isinstance(aval, ShapedArray):
if src == dst:
return x
elif src is None:
warnings.warn('split output axis requested for an array with no split')
return x
else:
perm = list(range(x.ndim))
perm[src] = dst
perm[dst] = src
return x.transpose(perm)
def match_axis(src, dst, x):
assert type(src) is int
if src == dst:
return x
else:
raise TypeError(type(aval))
return _match_axis(src, dst, x, core.get_aval(x))
@lu.transformation
def papply_transform(name, args, axis_size, in_axes, out_axis):
with new_master(PapplyTrace) as master:
trace = PapplyTrace(master, core.cur_sublevel())
in_tracers = map(partial(PapplyTracer, trace, name, axis_size), args, in_axes)
out_tracer = yield in_tracers, {}
out_tracer = trace.full_raise(out_tracer)
out_tracer = ensure_axis(out_axis, out_tracer.axis, out_tracer)
out_val = out_tracer.val
del master, out_tracer
yield out_val
def _match_axis(src, dst, x, aval):
if type(aval) is core.AbstractTuple:
if type(dst) is tuple:
return core.pack(map(partial(_match_axis, src), dst, x, aval))
else:
return core.pack(map(partial(_match_axis, src, dst), x, aval))
elif isinstance(aval, ShapedArray):
if type(dst) is int:
perm = [i for i in range(x.ndim) if i != src]
perm.insert(dst, src)
return x.transpose(perm)
elif dst is None:
return x[src]
else:
raise TypeError(dst)
else:
raise TypeError(aval)
class PapplyTracer(Tracer):
def __init__(self, trace, name, axis_size, val, axis):
@@ -315,7 +183,8 @@ def broadcasting_papply(prim, name, size, vals, axes, **params):
elif xdim == ydim:
return prim.bind(x, y, **params), xdim
else:
x = psplit(x, axis_name, ydim, xdim)
from jax.lax.lax_parallel import all_to_all # TODO circular deps
x = all_to_all(x, name, ydim - int(xdim <= ydim), xdim)
return prim.bind(x, y, **params), ydim


@@ -2966,6 +2966,7 @@ ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
parallel.papply_primitive_rules[gather_p] = _gather_papply_rule
class ScatterDimensionNumbers(collections.namedtuple(
"ScatterDimensionNumbers",
["update_window_dims", "inserted_window_dims",


@@ -159,14 +159,11 @@ def all_to_all(x, axis_name, split_axis, concat_axis):
"""
if psum(1, axis_name) != x.shape[split_axis]:
msg = ("all_to_all requires the size of the mapped axis axis_name to equal "
"x.shape[split_axis], but they are {} and {} respectively.")
"x.shape[split_axis], but they are {} and {} respectively.")
raise ValueError(msg.format(psum(1, axis_name), x.shape[split_axis]))
return all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis,
axis_name=axis_name)
def psplit_like(x, y, axis_name):
"""Ensure the named mapped axis of ``x`` aligns with that of ``y``."""
return psplit_like_p.bind(x, y, axis_name=axis_name)
def pcollect(x, axis_name):
return pcollect_p.bind(x, axis_name=axis_name)
@@ -181,11 +178,6 @@ def standard_pmap_primitive(name):
return prim
def _allreduce_serial_pmap_rule(reducer, vals, axes):
val, = vals
axis, = axes
return reducer(val, [axis]), None
def _allreduce_split_axis_rule(prim, reducer, vals, which_mapped, axis_name):
assert tuple(which_mapped) == (True,)
x, = vals
@@ -199,8 +191,6 @@ def _allreduce_translation_rule(prim, c, val, replica_groups):
psum_p = standard_pmap_primitive('psum')
parallel.defreducer(lax.reduce_sum_p, psum_p)
parallel.serial_pmap_primitive_rules[psum_p] = \
partial(_allreduce_serial_pmap_rule, lax._reduce_sum)
pxla.split_axis_rules[psum_p] = \
partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
pxla.parallel_translation_rules[psum_p] = \
@@ -211,18 +201,18 @@ ad.deflinear(psum_p, lambda t, axis_name: [t])
pmax_p = standard_pmap_primitive('pmax')
parallel.defreducer(lax.reduce_max_p, pmax_p)
parallel.serial_pmap_primitive_rules[pmax_p] = \
partial(_allreduce_serial_pmap_rule, lax._reduce_max)
pxla.parallel_translation_rules[pmax_p] = \
partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)
pmin_p = standard_pmap_primitive('pmin')
parallel.defreducer(lax.reduce_min_p, pmin_p)
parallel.serial_pmap_primitive_rules[pmin_p] = \
partial(_allreduce_serial_pmap_rule, lax._reduce_min)
pxla.parallel_translation_rules[pmin_p] = \
partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
def _ppermute_translation_rule(c, x, replica_groups, perm):
@@ -268,35 +258,11 @@ def _moveaxis(src, dst, x):
perm.insert(dst, src)
return lax.transpose(x, perm)
all_to_all_p = standard_pmap_primitive('all_to_all')
pxla.parallel_translation_rules[all_to_all_p] = _all_to_all_translation_rule
pxla.split_axis_rules[all_to_all_p] = _all_to_all_split_axis_rule
def _psplit_like_serial_pmap_rule(vals, axes):
x, y = vals
xaxis, yaxis = axes
if xaxis is not None and x.shape[xaxis] != x.shape[yaxis]:
raise ValueError(
"psplit_like is a non-square re-split along {} and {} of {}".format(
xaxis, yaxis, x.shape))
return x, yaxis
psplit_like_p = standard_pmap_primitive('psplit_like')
psplit_like_p.def_abstract_eval(
lambda x, y, *args, **kwargs: ShapedArray(y.shape, x.dtype))
parallel.serial_pmap_primitive_rules[psplit_like_p] = _psplit_like_serial_pmap_rule
def _pcollect_serial_pmap_rule(vals, axes):
x, = vals
return x, None
pcollect_p = standard_pmap_primitive('pcollect')
parallel.serial_pmap_primitive_rules[pcollect_p] = _pcollect_serial_pmap_rule
### papply rules
# TODO(skye): it would be nice if we could put these with their corresponding
# primitives, but that currently causes circular dependencies. More refactoring
@@ -407,35 +373,24 @@ def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
def _transpose_papply_rule(name, size, vals, dims, permutation):
x, = vals
xdim, = dims
perm = list(permutation)
if perm[xdim] == xdim:
perm = [i - 1 if i > xdim else i for i in perm if i != xdim]
x = lax.transpose(x, perm)
out_dim = xdim
else:
in_dim, = [i for i in range(len(perm)) if perm[i] == xdim]
out_dim = perm[xdim]
perm[in_dim] = out_dim
perm[out_dim] = in_dim
perm = perm[:xdim] + perm[xdim + 1:]
perm = [i - 1 if i > xdim else i for i in perm]
x = lax.transpose(x, perm)
x = pswapaxes(x, name, in_dim)
return x, xdim
local_perm = [i if i < xdim else i - 1 for i in permutation if i != xdim]
return lax.transpose(x, local_perm), permutation.index(xdim)
def _select_papply_rule(name, size, vals, dims):
dimset = set([d for d in dims if d is not None])
dimset = {d for d in dims if d is not None}
if len(dimset) != 1:
raise NotImplementedError(
'papply of select with operands split along different dimensions')
like_val, like_dim = [(v, d) for v, d in zip(vals, dims) if d is not None][0]
dim, = dimset
def normalize_split(val, dim):
return psplit_like(val, like_val, name) if dim is None else val
def drop(x, d):
if d is None:
return lax.dynamic_index_in_dim(x, axis_index(name), dim, False)
else:
return x
vals = [normalize_split(v, d) for v, d in zip(vals, dims)]
return lax.select_p.bind(*vals), like_dim
return lax.select_p.bind(*map(drop, vals, dims)), dim
def _add_jaxvals_papply_rule(name, size, vals, dims):


@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from unittest import SkipTest
import numpy as onp
@@ -25,80 +26,30 @@ from absl.testing import parameterized
import jax.numpy as np
from jax import test_util as jtu
from jax import lax
from jax.api import _serial_pmap, _papply, jit, make_jaxpr
from jax.api import _papply, _parallelize, soft_pmap, jit, make_jaxpr
from jax.linear_util import wrap_init
from jax.util import prod
from jax.config import config
config.parse_flags_with_absl()
class SerialPmapTest(jtu.JaxTestCase):
def testConstantFunction(self):
f = lambda x: 3
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
expected = 3 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceSum(self):
f = lambda x: lax.psum(x, 'i')
ans = _serial_pmap(f, axis_name='i')(onp.ones(4))
expected = 4 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceMax(self):
f = lambda x: lax.pmax(x, 'i')
ans = _serial_pmap(f, axis_name='i')(onp.arange(4))
expected = 3 * onp.ones(4)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplit(self):
f = lambda x: lax.psplit(x, 'i', 2, 0)
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplitLike(self):
f = lambda x, y: lax.psplit_like(x, y, 'i')
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = _serial_pmap(f, axis_name='i', in_axes=(None, 2), out_axes=2)(arg, arg)
expected = arg
self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
f = lambda x: x - np.log(lax.psum(np.exp(x), 'i'))
x = onp.log(onp.arange(1., 10., dtype=onp.float32))
ans = _serial_pmap(f, axis_name='i')(x)
expected = x - onp.log(onp.sum(onp.exp(x)))
self.assertAllClose(ans, expected, check_dtypes=False)
def testNested(self):
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
x = onp.ones((2, 2))
ans1 = _serial_pmap(_serial_pmap(f, 'i'), 'j')(x)
ans2 = _serial_pmap(_serial_pmap(f, 'j'), 'i')(x)
expected = 4 * onp.ones((2, 2))
self.assertAllClose(ans1, expected, check_dtypes=False)
self.assertAllClose(ans2, expected, check_dtypes=False)
class PapplyTest(jtu.JaxTestCase):
def testIdentity(self):
pfun, axis_name = _papply(lambda x: x, 3)
pfun, axis_name = _papply(lambda x: x)
ans = pfun(onp.arange(3))
expected = onp.arange(3)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMap(self):
pfun, axis_name = _papply(np.sin, 3)
pfun, axis_name = _papply(np.sin)
ans = pfun(onp.arange(3.))
expected = onp.sin(onp.arange(3.))
self.assertAllClose(ans, expected, check_dtypes=False)
def testSum(self):
pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)
pfun, axis_name = _papply(lambda x: np.sum(x, axis=0))
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
@@ -106,12 +57,12 @@ class PapplyTest(jtu.JaxTestCase):
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
ans = _serial_pmap(pfun, axis_name)(arg)[0]
ans = soft_pmap(pfun, axis_name)(arg)[0]
expected = onp.sum(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMax(self):
pfun, axis_name = _papply(lambda x: np.max(x, axis=0), 5)
pfun, axis_name = _papply(lambda x: np.max(x, axis=0))
jaxpr = make_jaxpr(pfun)(onp.ones(3))
expected_jaxpr = make_jaxpr(
@@ -119,30 +70,20 @@ class PapplyTest(jtu.JaxTestCase):
assert repr(jaxpr) == repr(expected_jaxpr)
arg = onp.arange(15.).reshape((5, 3))
ans = _serial_pmap(pfun, axis_name)(arg)[0]
ans = soft_pmap(pfun, axis_name)(arg)[0]
expected = onp.max(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testSelect(self):
pfun, axis_name = _papply(lax.select, 5,
in_axes=(None, 0, None))
p = onp.arange(15).reshape((5, 3)) % 4 == 1
t = onp.ones((5, 3))
f = onp.zeros((5, 3))
jaxpr = make_jaxpr(pfun)(p, t[0], f)
def expected_spmd(p, t, f):
return lax.select(
lax.psplit_like(p, t, axis_name),
t,
lax.psplit_like(f, t, axis_name))
def fun(t):
return lax.select(p, t, f)
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
assert repr(jaxpr) == repr(expected_jaxpr)
ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
expected = lax.select(p, t, f)
t = onp.ones((5, 3))
ans = soft_pmap(*_papply(fun))(t)
expected = fun(t)
self.assertAllClose(ans, expected, check_dtypes=True)
def testLogSoftmax(self):
@@ -151,14 +92,14 @@ class PapplyTest(jtu.JaxTestCase):
def fun(x):
return x - np.log(np.sum(np.exp(x)))
pfun, axis_name = _papply(fun, 5)
pfun, axis_name = _papply(fun)
jaxpr = make_jaxpr(pfun)(onp.zeros(5))
expected_jaxpr = make_jaxpr(
lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
assert repr(jaxpr) == repr(expected_jaxpr)
ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
expected = fun(onp.arange(1., 5.))
self.assertAllClose(ans, expected, check_dtypes=False)
@@ -166,8 +107,8 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.array([[1, 2, 3], [4, 5, 6]])
expected = x + x
pfun, axis_name = _papply(np.add, 2)
ans = _serial_pmap(pfun, axis_name)(x, x)
pfun, axis_name = _papply(np.add)
ans = soft_pmap(pfun, axis_name)(x, x)
self.assertAllClose(ans, expected, check_dtypes=True)
def testAddBroadcasting(self):
@@ -179,39 +120,47 @@ class PapplyTest(jtu.JaxTestCase):
x = onp.array([[1, 2], [3, 4]])
expected = x + 3
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
pfun, axis_name = _papply(fun)
ans = soft_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=True)
def testTranspose(self):
class ParallelizeTest(jtu.JaxTestCase):
def testNormalize(self):
def f(x):
return x / x.sum(0)
x = onp.arange(4.)
expected = f(x)
ans = _parallelize(f)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = make_jaxpr(_parallelize(f))(x)
self.assertIn('psum', repr(jaxpr))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "testTranspose_shape={}_perm={}"
.format(shape, perm),
"shape": shape, "perm": perm}
for shape in [
(2, 2),
(3, 3),
(2, 2, 2),
(2, 3, 4),
(2, 3, 2)
]
for perm in itertools.permutations(list(range(len(shape))))
))
def testTranspose(self, shape, perm):
def fun(x):
return x.T
return lax.transpose(x, perm)
xs = [
onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
]
for x in xs:
expected = x.T
pfun, axis_name = _papply(fun, x.shape[0])
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeWithOddPermutation(self):
def fun(x):
return np.transpose(x, (2, 0, 1))
xs = [
onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2)),
onp.reshape(onp.arange(27., dtype=onp.float32), (3, 3, 3)),
]
for x in xs:
expected = np.transpose(x, (2, 0, 1))
pfun, axis_name = _papply(fun, x.shape[0])
ans = _serial_pmap(pfun, axis_name)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
x = onp.arange(prod(shape)).reshape(shape)
expected = fun(x)
ans = _parallelize(fun)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeAndAddRank2(self):
@@ -219,10 +168,8 @@ class PapplyTest(jtu.JaxTestCase):
return x + x.T
x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))
expected = x + x.T
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
expected = fun(x)
ans = _parallelize(fun)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeAndAddRank3(self):
@@ -231,28 +178,26 @@ class PapplyTest(jtu.JaxTestCase):
return x + x.T
x = onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2))
expected = x + x.T
pfun, axis_name = _papply(fun, 2)
ans = _serial_pmap(pfun, axis_name)(x)
expected = fun(x)
ans = _parallelize(fun)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testDot(self):
return SkipTest("test doesn't pass yet") # TODO(frostig)
# def testDot(self):
# return SkipTest("test doesn't pass yet") # TODO(frostig)
def fun(x, y):
return lax.dot(x, y)
xs = [
onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
]
in_axes_combos = [(0, 0), (0, 1)] # [(1, 0)]
for in_axes in in_axes_combos:
for x in xs:
expected = fun(x, x)
pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes)
ans = _serial_pmap(pfun, axis_name)(x, x)
self.assertAllClose(ans, expected, check_dtypes=False)
# def fun(x, y):
# return lax.dot(x, y)
# xs = [
# onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)),
# onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)),
# ]
# in_axes_combos = [(0, 0), (0, 1)] # [(1, 0)]
# for in_axes in in_axes_combos:
# for x in xs:
# expected = fun(x, x)
# pfun, axis_name = _papply(fun)
# ans = soft_pmap(pfun, axis_name)(x, x)
# self.assertAllClose(ans, expected, check_dtypes=False)
if __name__ == '__main__':