Merge pull request #916 from google/parallelize

parallelization work-in-progress
Roy Frostig 2019-06-24 16:08:14 -07:00 committed by GitHub
commit 33b01733a9
4 changed files with 204 additions and 79 deletions


@@ -253,7 +253,7 @@ class PapplyTrace(Trace):
name = next(n for n in names if n is not None)
size = next(t.axis_size for t in tracers if t.axis_size is not None)
rule = papply_primitive_rules[primitive]
val_out, axis_out = rule(name, vals, axes, **params)
val_out, axis_out = rule(name, size, vals, axes, **params)
return PapplyTracer(self, name, size, val_out, axis_out)
def process_call(self, call_primitive, f, tracers, params):
@@ -270,11 +270,11 @@ class PapplyTrace(Trace):
return PapplyTracer(self, name, size, vals, axis)
def vectorized_papply(prim, name, vals, axes, **params):
def vectorized_papply(prim, name, size, vals, axes, **params):
assert all(axes[0] == a for a in axes[1:])
return prim.bind(*vals, **params), axes[0]
def reducer_papply(prim, cprim, name, vals, papply_axes, axes, **kwargs):
def reducer_papply(prim, cprim, name, size, vals, papply_axes, axes, **kwargs):
operand, = vals
papply_axis, = papply_axes
@@ -296,26 +296,28 @@ def reducer_papply(prim, cprim, name, vals, papply_axes, axes, **kwargs):
return result, new_papply_axis
def broadcasting_papply(prim, name, vals, axes, **params):
def broadcasting_papply(prim, name, size, vals, axes, **params):
x, y = vals
xdim, ydim = axes
if xdim is None:
assert x.shape[ydim] == 1
x = x.reshape(onp.delete(x.shape, ydim))
if x.shape:
assert x.shape[ydim] == 1
x = x.reshape(onp.delete(x.shape, ydim))
return prim.bind(x, y, **params), ydim
elif ydim is None:
assert y.shape[xdim] == 1
y = y.reshape(onp.delete(y.shape, xdim))
if y.shape:
assert y.shape[xdim] == 1
y = y.reshape(onp.delete(y.shape, xdim))
return prim.bind(x, y, **params), xdim
elif xdim == ydim:
return prim.bind(x, y, **params), xdim
else:
x = psplit(x, name, ydim)
x = psplit(x, name, ydim, xdim)
return prim.bind(x, y, **params), ydim
def identity_papply(prim, argnum, name, vals, axes, **params):
def identity_papply(prim, argnum, name, size, vals, axes, **params):
return prim.bind(*vals, **params), axes[argnum]
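The common thread in this file is that every papply rule now takes the mapped axis size as a second positional argument, between `name` and `vals`. For illustration, a minimal sketch of the new convention using a hypothetical primitive `my_prim` (not part of this diff), registered the same way the rules below are:

from jax import core
from jax.interpreters import parallel

my_prim = core.Primitive('my_prim')  # hypothetical primitive, for illustration only

def _my_prim_papply_rule(name, size, vals, axes, **params):
  # `size` is the full extent of the hidden axis named `name`; rules that
  # slice, pad, or bounds-check along that axis need it, while elementwise
  # rules like this one can simply ignore it.
  operand, = vals
  axis, = axes
  return my_prim.bind(operand, **params), axis

parallel.papply_primitive_rules[my_prim] = _my_prim_papply_rule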


@@ -1767,6 +1767,12 @@ def _convert_element_type_translation_rule(c, operand, new_dtype, old_dtype):
new_etype = xla_bridge.dtype_to_etype_exact(new_dtype)
return c.ConvertElementType(operand, new_element_type=new_etype)
def _convert_element_type_papply_rule(name, size, vals, dims, new_dtype,
**kwargs):
operand, = vals
dim, = dims
return convert_element_type(operand, new_dtype), dim
convert_element_type_p = standard_primitive(
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
'convert_element_type', _convert_element_type_translation_rule)
@@ -1774,6 +1780,7 @@ ad.deflinear(
convert_element_type_p,
lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
batching.defvectorized(convert_element_type_p)
parallel.papply_primitive_rules[convert_element_type_p] = _convert_element_type_papply_rule
def _bitcast_convert_type_shape_rule(operand, new_dtype):
@@ -1943,6 +1950,22 @@ def _conv_general_dilated_batch_rule(
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
return out, out_spec[1]
def _conv_general_dilated_papply_rule(
name, size, vals, dims, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, **unused_kwargs):
lhs, rhs = vals
lhs_dim, rhs_dim = dims
lhs_spec_batch_dim = dimension_numbers.lhs_spec[0]
if rhs_dim is None and lhs_dim == lhs_spec_batch_dim:
lhs = reshape(lhs, tuple(onp.insert(lhs.shape, lhs_dim, 1)))
out = conv_general_dilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers)
return out, lhs_dim
else:
raise NotImplementedError(
"splitting a convolution along anything but input batch dimension")
conv_general_dilated_p = standard_primitive(
_conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
'conv_general_dilated', _conv_general_dilated_translation_rule)
@@ -1951,6 +1974,9 @@ ad.defbilinear(conv_general_dilated_p,
_conv_general_dilated_transpose_rhs)
batching.primitive_batchers[
conv_general_dilated_p] = _conv_general_dilated_batch_rule
parallel.papply_primitive_rules[
conv_general_dilated_p] = _conv_general_dilated_papply_rule
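The rule above only handles a hidden axis that is the lhs batch dimension: convolution is independent across batch elements, so each shard can convolve its local batch after re-inserting the size-1 batch axis that papply hid. A shape-only sketch with hypothetical NHWC-style dimensions (not from this diff):

import numpy as onp

lhs_shard_shape = (28, 28, 3)  # one example; batch axis (lhs_dim == 0) hidden by papply
# the rule's reshape re-inserts the hidden batch axis before convolving locally:
lhs_local_shape = tuple(onp.insert(lhs_shard_shape, 0, 1))  # -> (1, 28, 28, 3)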
def _reshape_axis_into(src, dst, x):
perm = [i for i in range(x.ndim) if i != src]
@@ -2218,11 +2244,25 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
new_broadcast_dimensions.insert(bdim, bdim)
return broadcast_in_dim(operand, new_shape, new_broadcast_dimensions), bdim
def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
broadcast_dimensions):
operand, = vals
dim, = dims
out_dim = broadcast_dimensions[dim]
if shape[out_dim] != shape[dim]:
raise ValueError(
"broadcast_in_dim changes hidden dimension size: {} to {}".format(
shape[dim], shape[out_dim]))
sub_bdims = tuple(onp.delete(broadcast_dimensions, dim))
sub_shape = tuple(onp.delete(shape, out_dim))
return broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
parallel.papply_primitive_rules[broadcast_in_dim_p] = _broadcast_in_dim_papply_rule
def _clamp_shape_rule(min, operand, max):
@@ -2350,10 +2390,27 @@ def _pad_batch_rule(batched_args, batch_dims, padding_config):
else:
raise NotImplementedError # loop and stack
def _pad_papply_rule(name, size, vals, dims, padding_config):
operand, padding_value = vals
operand_dim, padding_value_dim = dims
assert padding_value_dim is None
padding_config = list(padding_config)
if padding_config[operand_dim] == (0, 0, 0):
padded = pad(
operand,
padding_value,
padding_config[:operand_dim] + padding_config[operand_dim + 1:])
return padded, operand_dim
else:
raise NotImplementedError(
'pad changes size of hidden dimension {} with config {}'.format(
operand_dim, padding_config))
pad_p = standard_primitive(_pad_shape_rule, _input_dtype, 'pad')
ad.deflinear(pad_p, _pad_transpose)
ad.primitive_transposes[pad_p] = _pad_transpose
batching.primitive_batchers[pad_p] = _pad_batch_rule
parallel.papply_primitive_rules[pad_p] = _pad_papply_rule
def _reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs):
@@ -2593,10 +2650,30 @@ def _slice_batching_rule(batched_args, batch_dims, start_indices, limit_indices,
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
return out, bdim
def _slice_papply_rule(name, size, vals, dims, start_indices, limit_indices,
strides, **kwargs):
operand, = vals
dim, = dims
start_indices = list(start_indices)
limit_indices = list(limit_indices)
if (start_indices[dim] != 0 or
limit_indices[dim] != size or
strides is not None and strides[dim] != 1):
raise NotImplementedError('slice changes size of hidden dimension')
out = slice(
operand,
start_indices[:dim] + start_indices[dim + 1:],
limit_indices[:dim] + limit_indices[dim + 1:],
strides[:dim] + strides[dim + 1:] if strides is not None else None)
return out, dim
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
_slice_translation_rule)
ad.deflinear(slice_p, _slice_transpose_rule)
batching.primitive_batchers[slice_p] = _slice_batching_rule
parallel.papply_primitive_rules[slice_p] = _slice_papply_rule
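As a hedged illustration of the index bookkeeping above (hypothetical shapes, no pmap involved): slicing a (4, 10) array as [0:4, 2:8] with the first axis hidden passes the bounds check, drops that axis's entries, and reduces to slicing each local shard as [2:8]:

start_indices, limit_indices = [0, 2], [4, 8]
dim, size = 0, 4  # hidden axis and its full extent
assert start_indices[dim] == 0 and limit_indices[dim] == size
local_start = start_indices[:dim] + start_indices[dim + 1:]  # [2]
local_limit = limit_indices[:dim] + limit_indices[dim + 1:]  # [8]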
def _dynamic_slice_shape_rule(operand, start_indices, slice_sizes,
@@ -2859,9 +2936,27 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
return gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes), 0
def _gather_serial_pmap_rule(vals, axes):
val, = vals
return val, None
def _gather_papply_rule(
name, size, vals, dims, dimension_numbers, slice_sizes, operand_shape):
operand, start_indices = vals
operand_dim, start_indices_dim = dims
if (operand_dim is None and
start_indices_dim is not None and
start_indices_dim not in dimension_numbers.offset_dims and
dimension_numbers.collapsed_slice_dims == (0,)):
offset_dims = tuple(i - 1 if i > start_indices_dim else i
for i in dimension_numbers.offset_dims)
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
start_index_map=dimension_numbers.start_index_map)
out = gather(operand, start_indices, dimension_numbers=dnums,
slice_sizes=slice_sizes)
out_dim = start_indices_dim + onp.sum(
onp.less_equal(offset_dims, start_indices_dim))
return out, out_dim
else:
raise NotImplementedError
gather_p = standard_primitive(
_gather_shape_rule, _gather_dtype_rule, 'gather',
@@ -2869,8 +2964,7 @@ gather_p = standard_primitive(
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
parallel.serial_pmap_primitive_rules[gather_p] = _gather_serial_pmap_rule
parallel.papply_primitive_rules[gather_p] = _gather_papply_rule
class ScatterDimensionNumbers(collections.namedtuple(
"ScatterDimensionNumbers",


@@ -15,6 +15,9 @@
Parallelization primitives.
"""
import numpy as onp
from jax import ad_util
from jax.lax import lax
from jax.abstract_arrays import ShapedArray
from jax.interpreters import ad
@@ -131,7 +134,7 @@ def pswapaxes(x, axis_name, axis):
# raise ValueError(msg.format(axis_size(axis_name), x.shape[axis]))
return pswapaxes_p.bind(x, axis_name=axis_name, axis=axis)
def psplit(x, axis_name, axis):
def psplit(x, axis_name, split_axis, concat_axis):
"""Unmap the pmapped axis ``axis_name`` and map ``axis`` with the same name.
This function is similar to ``pswapaxes`` except the pmapped axis of the input
@@ -141,15 +144,18 @@ def psplit(x, axis_name, axis):
x: array with a mapped axis named ``axis_name``.
axis_name: hashable Python object used to name a pmapped axis (see the
``pmap`` docstring for more details).
axis: int indicating the unmapped axis of ``x`` to map with the name
split_axis: int indicating the unmapped axis of ``x`` to map with the name
``axis_name``.
concat_axis: int indicating the dimension at which to materialize the axis
of ``x`` mapped with ``axis_name``.
Returns:
An array with shape ``(axis_size,) + tuple(np.delete(x.shape, axis))`` where
``axis_size`` is the size of the mapped axis named ``axis_name`` in the
input ``x``.
An array with shape ``np.insert(np.delete(x.shape, split_axis), concat_axis,
axis_size)`` where ``axis_size`` is the size of the mapped axis named
``axis_name`` in the input ``x``.
"""
return psplit_p.bind(x, axis_name=axis_name, axis=axis)
return psplit_p.bind(x, axis_name=axis_name, concat_axis=concat_axis,
split_axis=split_axis)
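A usage sketch of the new calling convention, mirroring the updated SerialPmapTest at the bottom of this diff:

import numpy as onp
from jax import lax

f = lambda x: lax.psplit(x, 'i', 2, 0)  # split_axis=2, concat_axis=0
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
# _serial_pmap maps axis 0 of `arg` with name 'i'; psplit then maps axis 2
# with that name instead, materializing the previously mapped values at
# axis 0, so _serial_pmap(f, axis_name='i', out_axes=2)(arg) equals `arg`.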
def psplit_like(x, y, axis_name):
"""Ensure the named mapped axis of ``x`` aligns with that of ``y``."""
@@ -249,14 +255,17 @@ pxla.parallel_translation_rules[pswapaxes_p] = _pswapaxes_translation_rule
parallel.serial_pmap_primitive_rules[pswapaxes_p] = _pswapaxes_serial_pmap_rule
def _psplit_serial_pmap_rule(vals, axes, axis):
def _psplit_serial_pmap_rule(vals, axes, split_axis, concat_axis):
x, = vals
axis_in, = axes
if x.shape[axis_in] != x.shape[axis]:
if x.shape[axis_in] != x.shape[split_axis]:
raise ValueError(
"psplit between non-square dimensions {} and {} of {}".format(
axis_in, axis, x.shape))
return x, axis
axis_in, split_axis, x.shape))
perm = list(range(x.ndim))
perm[axis_in] = concat_axis
perm[concat_axis] = axis_in
return lax.transpose(x, perm), split_axis
psplit_p = standard_pmap_primitive('psplit')
parallel.serial_pmap_primitive_rules[psplit_p] = _psplit_serial_pmap_rule
@@ -290,7 +299,7 @@ parallel.serial_pmap_primitive_rules[pcollect_p] = _pcollect_serial_pmap_rule
# primitives, but that currently causes circular dependencies. More refactoring
# might fix this.
def _dot_papply_rule(name, vals, dims):
def _dot_papply_rule(name, size, vals, dims):
x, y = vals
xdim, ydim = dims
if xdim is None:
@@ -299,7 +308,7 @@ def _dot_papply_rule(name, vals, dims):
return lax.dot(x, y), xdim
elif ydim == 0:
if xdim != x.ndim:
x = psplit(x, name, x.ndim)
x = psplit(x, name, x.ndim, xdim)
x = x[..., None]
y = y[..., None, :]
return psum(x * y, name), None
@@ -308,49 +317,62 @@ def _dot_papply_rule(name, vals, dims):
return lax.dot(x, y), xdim
def _dot_general_papply_rule(name, vals, dims, dimension_numbers):
def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers):
x, y = vals
xdim, ydim = dims
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
if len(lhs_batch) > 0 or len(rhs_batch) > 0:
raise NotImplementedError
def adjust_dims(dims, thresh):
return tuple(i - 1 if i >= thresh else i for i in dims if i != thresh)
return tuple(i - 1 if i > thresh else i for i in dims if i != thresh)
sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
if xdim is not None:
sub_lhs_contract = adjust_dims(lhs_contract, xdim)
if ydim is not None:
sub_rhs_contract = adjust_dims(rhs_contract, ydim)
sub_lhs_batch, sub_rhs_batch = lhs_batch, rhs_batch
sub_dimension_numbers = (
(sub_lhs_contract, sub_rhs_contract), (lhs_batch, rhs_batch))
def sub_dims(xdim, ydim):
sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
sub_lhs_batch, sub_rhs_batch = lhs_batch, rhs_batch
if xdim is not None:
sub_lhs_batch = adjust_dims(lhs_batch, xdim)
sub_lhs_contract = adjust_dims(lhs_contract, xdim)
if ydim is not None:
sub_rhs_batch = adjust_dims(rhs_batch, ydim)
sub_rhs_contract = adjust_dims(rhs_contract, ydim)
return (
(sub_lhs_contract, sub_rhs_contract), (sub_lhs_batch, sub_rhs_batch))
if xdim in lhs_contract and ydim in rhs_contract:
z = lax.dot_general(x, y, sub_dimension_numbers)
return psum(z, name), None
elif xdim in lhs_contract:
if ydim is not None: # Cannot hide two dimensions, so collect one
y = pcollect(y, name)
return lax.dot_general(x, y, sub_dimension_numbers), xdim
elif ydim in rhs_contract:
if xdim is not None: # Cannot hide two dimensions, so collect one
x = pcollect(x, name)
return lax.dot_general(x, y, sub_dimension_numbers), ydim
elif xdim is not None:
if ydim is not None: # Cannot hide two dimensions, so collect one
y = pcollect(y, name)
return lax.dot_general(x, y, sub_dimension_numbers), xdim
elif ydim is not None:
return lax.dot_general(x, y, sub_dimension_numbers), ydim
def cases(x, y, xdim, ydim, xcontract, ycontract):
if xdim in xcontract:
if ydim in ycontract:
# case: both operands are split and contracting
z = lax.dot_general(x, y, sub_dims(xdim, ydim))
return True, (psum(z, name), None)
elif ydim is not None:
# case: x split and contracting, y split but not contracting
new_ydim = ycontract[xcontract.index(xdim)]
y = psplit(y, name, new_ydim, ydim)
z = lax.dot_general(x, y, sub_dims(xdim, new_ydim))
return True, (psum(z, name), None)
else:
# case: x split and contracting, y not split
return False, 'one operand split and contracting, other is not split'
else:
return False, 'unhandled case'
ok, out = cases(x, y, xdim, ydim, lhs_contract, rhs_contract)
if not ok:
ok, out = cases(y, x, ydim, xdim, rhs_contract, lhs_contract)
if not ok:
raise NotImplementedError(
('papply of dot_general, {}: '
'xdim={}, ydim={}, dimension_numbers={}').format(
out, xdim, ydim, dimension_numbers))
else:
return lax.dot_general(x, y, sub_dimension_numbers), None
return out
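The "both split and contracting" case rests on the identity that a contraction over a split dimension is the psum of per-shard partial products. A small NumPy check of that identity (illustrative only, no devices involved):

import numpy as onp

x = onp.arange(6.).reshape(2, 3)   # lhs, contracting along axis 1
y = onp.arange(12.).reshape(3, 4)  # rhs, contracting along axis 0
# shard k holds slice k of each contracting axis; the local contraction is
# an outer product, and psum corresponds to summing the partials:
partials = [onp.outer(x[:, k], y[k, :]) for k in range(3)]
assert onp.allclose(sum(partials), onp.dot(x, y))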
def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
old_sizes):
operand, = vals
axis, = axes
@@ -358,42 +380,33 @@ def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
return filter(lambda x: x != 1, xs)
def find_new_axis(old_axis, old_sizes, new_sizes):
if len(filter_ones(new_sizes)) != len(filter_ones(old_sizes)):
return None
num_before = len(filter_ones(old_sizes[:old_axis]))
sz = old_sizes[old_axis]
for i, new_sz in enumerate(new_sizes):
if num_before == 0:
if new_sz == sz:
return i
elif new_sz != 1:
return None
elif new_sz != 1:
num_before -= 1
left = onp.prod(old_sizes[:old_axis])
size = old_sizes[old_axis]
prod = 1
for i, cur_sz in enumerate(new_sizes):
if prod == left and cur_sz == size:
return i
prod = prod * cur_sz
return None
err = NotImplementedError(
'papply of reshape that would change hidden dimension size')
if dimensions is None:
new_axis = find_new_axis(axis, old_sizes, new_sizes)
if new_axis is not None:
if (lax.prod(old_sizes[:axis]) != lax.prod(new_sizes[:new_axis]) or
lax.prod(old_sizes[axis + 1:]) != lax.prod(new_sizes[new_axis + 1:])):
raise err
new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
else:
raise err
raise NotImplementedError(
'papply of reshape that would change hidden dimension size')
else:
raise NotImplementedError('papply of reshape with `dimensions`')
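A worked example of the new find_new_axis logic, with hypothetical sizes (not from this diff):

old_sizes, axis = (2, 3, 8), 1   # hidden axis of size 3
new_sizes = (2, 3, 2, 4)         # left products match (2 == 2); right too (8 == 2 * 4)
# find_new_axis(1, old_sizes, new_sizes) returns 1, and the follow-up
# product checks pass, so the rule reshapes each local shard from (2, 8)
# to (2, 2, 4) and reports the hidden axis at position 1.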
def _transpose_papply_rule(name, vals, dims, permutation):
def _transpose_papply_rule(name, size, vals, dims, permutation):
x, = vals
xdim, = dims
perm = list(permutation)
if perm[xdim] == xdim:
perm = [i - 1 if i > xdim else i for i in perm if i != xdim]
x = lax.transpose(x, perm)
out_dim = xdim
else:
@@ -408,7 +421,7 @@ def _transpose_papply_rule(name, vals, dims, permutation):
return x, xdim
def _select_papply_rule(name, vals, dims):
def _select_papply_rule(name, size, vals, dims):
dimset = set([d for d in dims if d is not None])
if len(dimset) != 1:
raise NotImplementedError(
@@ -422,8 +435,24 @@ def _select_papply_rule(name, vals, dims):
return lax.select_p.bind(*vals), like_dim
def _add_jaxvals_papply_rule(name, size, vals, dims):
x, y = vals
xdim, ydim = dims
if xdim == ydim:
out_dim = xdim
elif ydim is None:
y = lax.psplit_like(y, x, name)
out_dim = xdim
else:
x = lax.psplit_like(x, y, name)
out_dim = ydim
return ad_util.add_jaxvals_p.bind(x, y), out_dim
parallel.papply_primitive_rules[lax.dot_p] = _dot_papply_rule
parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule
parallel.papply_primitive_rules[lax.select_p] = _select_papply_rule
parallel.papply_primitive_rules[ad_util.add_jaxvals_p] = (
_add_jaxvals_papply_rule)


@@ -53,7 +53,7 @@ class SerialPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsplit(self):
f = lambda x: lax.psplit(x, 'i', 2)
f = lambda x: lax.psplit(x, 'i', 2, 0)
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
expected = arg