mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #916 from google/parallelize
parallelization work-in-progress
This commit is contained in:
commit
33b01733a9
@ -253,7 +253,7 @@ class PapplyTrace(Trace):
|
||||
name = next(n for n in names if n is not None)
|
||||
size = next(t.axis_size for t in tracers if t.axis_size is not None)
|
||||
rule = papply_primitive_rules[primitive]
|
||||
val_out, axis_out = rule(name, vals, axes, **params)
|
||||
val_out, axis_out = rule(name, size, vals, axes, **params)
|
||||
return PapplyTracer(self, name, size, val_out, axis_out)
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
@ -270,11 +270,11 @@ class PapplyTrace(Trace):
|
||||
return PapplyTracer(self, name, size, vals, axis)
|
||||
|
||||
|
||||
def vectorized_papply(prim, name, vals, axes, **params):
|
||||
def vectorized_papply(prim, name, size, vals, axes, **params):
|
||||
assert all(axes[0] == a for a in axes[1:])
|
||||
return prim.bind(*vals, **params), axes[0]
|
||||
|
||||
def reducer_papply(prim, cprim, name, vals, papply_axes, axes, **kwargs):
|
||||
def reducer_papply(prim, cprim, name, size, vals, papply_axes, axes, **kwargs):
|
||||
operand, = vals
|
||||
papply_axis, = papply_axes
|
||||
|
||||
@ -296,26 +296,28 @@ def reducer_papply(prim, cprim, name, vals, papply_axes, axes, **kwargs):
|
||||
return result, new_papply_axis
|
||||
|
||||
|
||||
def broadcasting_papply(prim, name, size, vals, axes, **params):
  """papply rule for a binary primitive applied to two operands.

  Args:
    prim: primitive to bind on the (possibly adjusted) operands.
    name: axis name forwarded to ``psplit`` when the operands' axes differ.
    size: axis size; accepted for the common rule signature, unused here.
    vals: pair ``(x, y)`` of operand values.
    axes: pair ``(xdim, ydim)``; an entry is ``None`` when the corresponding
      operand carries no axis.
    **params: extra parameters forwarded to ``prim.bind``.

  Returns:
    A pair ``(result, axis)`` where ``axis`` is the axis of the result.
  """
  x, y = vals
  xdim, ydim = axes

  if xdim is None:
    # Scalars have an empty shape and need no adjustment. Otherwise x must
    # have a size-1 dimension at y's axis position, which is dropped.
    if x.shape:
      assert x.shape[ydim] == 1
      x = x.reshape(onp.delete(x.shape, ydim))
    return prim.bind(x, y, **params), ydim
  elif ydim is None:
    # Mirror image of the case above, with the roles of x and y exchanged.
    if y.shape:
      assert y.shape[xdim] == 1
      y = y.reshape(onp.delete(y.shape, xdim))
    return prim.bind(x, y, **params), xdim
  elif xdim == ydim:
    return prim.bind(x, y, **params), xdim
  else:
    # The axes disagree: re-split x so that its axis matches y's. The axis
    # name is the ``name`` parameter (``axis_name`` is not defined here).
    x = psplit(x, name, ydim, xdim)
    return prim.bind(x, y, **params), ydim
|
||||
|
||||
|
||||
def identity_papply(prim, argnum, name, vals, axes, **params):
|
||||
def identity_papply(prim, argnum, name, size, vals, axes, **params):
|
||||
return prim.bind(*vals, **params), axes[argnum]
|
||||
|
||||
|
||||
|
104
jax/lax/lax.py
104
jax/lax/lax.py
@ -1767,6 +1767,12 @@ def _convert_element_type_translation_rule(c, operand, new_dtype, old_dtype):
|
||||
new_etype = xla_bridge.dtype_to_etype_exact(new_dtype)
|
||||
return c.ConvertElementType(operand, new_element_type=new_etype)
|
||||
|
||||
def _convert_element_type_papply_rule(name, size, vals, dims, new_dtype,
                                      **kwargs):
  """papply rule for ``convert_element_type``.

  Converts the single operand to ``new_dtype`` and returns it together with
  the operand's dimension entry, unchanged.
  """
  [operand] = vals
  [dim] = dims
  converted = convert_element_type(operand, new_dtype)
  return converted, dim
|
||||
|
||||
convert_element_type_p = standard_primitive(
|
||||
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
||||
'convert_element_type', _convert_element_type_translation_rule)
|
||||
@ -1774,6 +1780,7 @@ ad.deflinear(
|
||||
convert_element_type_p,
|
||||
lambda t, new_dtype, old_dtype: [convert_element_type(t, old_dtype)])
|
||||
batching.defvectorized(convert_element_type_p)
|
||||
parallel.papply_primitive_rules[convert_element_type_p] = _convert_element_type_papply_rule
|
||||
|
||||
|
||||
def _bitcast_convert_type_shape_rule(operand, new_dtype):
|
||||
@ -1943,6 +1950,22 @@ def _conv_general_dilated_batch_rule(
|
||||
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
|
||||
return out, out_spec[1]
|
||||
|
||||
def _conv_general_dilated_papply_rule(
    name, size, vals, dims, window_strides, padding, lhs_dilation, rhs_dilation,
    dimension_numbers, **unused_kwargs):
  """papply rule for ``conv_general_dilated``.

  Only one configuration is handled: the rhs has no dimension entry
  (``None``) and the lhs entry equals the first element of
  ``dimension_numbers.lhs_spec``. Anything else raises NotImplementedError.

  Returns:
    A pair ``(out, lhs_dim)``.
  """
  lhs, rhs = vals
  lhs_dim, rhs_dim = dims
  batch_dim = dimension_numbers.lhs_spec[0]

  if rhs_dim is not None or lhs_dim != batch_dim:
    raise NotImplementedError(
        "splitting a convolution along anything but input batch dimension")

  # Re-insert a size-1 dimension at lhs_dim before running the convolution.
  expanded_shape = tuple(onp.insert(lhs.shape, lhs_dim, 1))
  lhs = reshape(lhs, expanded_shape)
  result = conv_general_dilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers)
  return result, lhs_dim
|
||||
|
||||
conv_general_dilated_p = standard_primitive(
|
||||
_conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
|
||||
'conv_general_dilated', _conv_general_dilated_translation_rule)
|
||||
@ -1951,6 +1974,9 @@ ad.defbilinear(conv_general_dilated_p,
|
||||
_conv_general_dilated_transpose_rhs)
|
||||
batching.primitive_batchers[
|
||||
conv_general_dilated_p] = _conv_general_dilated_batch_rule
|
||||
parallel.papply_primitive_rules[
|
||||
conv_general_dilated_p] = _conv_general_dilated_papply_rule
|
||||
|
||||
|
||||
def _reshape_axis_into(src, dst, x):
|
||||
perm = [i for i in range(x.ndim) if i != src]
|
||||
@ -2218,11 +2244,25 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
|
||||
new_broadcast_dimensions.insert(bdim, bdim)
|
||||
return broadcast_in_dim(operand, new_shape, new_broadcast_dimensions), bdim
|
||||
|
||||
def _broadcast_in_dim_papply_rule(name, size, vals, dims, shape,
|
||||
broadcast_dimensions):
|
||||
operand, = vals
|
||||
dim, = dims
|
||||
out_dim = broadcast_dimensions[dim]
|
||||
if shape[out_dim] != shape[dim]:
|
||||
raise ValueError(
|
||||
"broadcast_in_dim changes hidden dimension size: {} to {}".format(
|
||||
shape[dim], shape[out_dim]))
|
||||
sub_bdims = tuple(onp.delete(broadcast_dimensions, dim))
|
||||
sub_shape = tuple(onp.delete(shape, out_dim))
|
||||
return broadcast_in_dim(operand, sub_shape, sub_bdims), out_dim
|
||||
|
||||
|
||||
broadcast_in_dim_p = standard_primitive(
|
||||
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
|
||||
ad.deflinear(broadcast_in_dim_p, _broadcast_in_dim_transpose_rule)
|
||||
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
|
||||
parallel.papply_primitive_rules[broadcast_in_dim_p] = _broadcast_in_dim_papply_rule
|
||||
|
||||
|
||||
def _clamp_shape_rule(min, operand, max):
|
||||
@ -2350,10 +2390,27 @@ def _pad_batch_rule(batched_args, batch_dims, padding_config):
|
||||
else:
|
||||
raise NotImplementedError # loop and stack
|
||||
|
||||
def _pad_papply_rule(name, size, vals, dims, padding_config):
|
||||
operand, padding_value = vals
|
||||
operand_dim, padding_value_dim = dims
|
||||
assert padding_value_dim is None
|
||||
padding_config = list(padding_config)
|
||||
if padding_config[operand_dim] == (0, 0, 0):
|
||||
padded = pad(
|
||||
operand,
|
||||
padding_value,
|
||||
padding_config[:operand_dim] + padding_config[operand_dim + 1:])
|
||||
return padded, operand_dim
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'pad changes size of hidden dimension {} with config {}'.format(
|
||||
operand_dim, padding_config))
|
||||
|
||||
pad_p = standard_primitive(_pad_shape_rule, _input_dtype, 'pad')
|
||||
ad.deflinear(pad_p, _pad_transpose)
|
||||
ad.primitive_transposes[pad_p] = _pad_transpose
|
||||
batching.primitive_batchers[pad_p] = _pad_batch_rule
|
||||
parallel.papply_primitive_rules[pad_p] = _pad_papply_rule
|
||||
|
||||
|
||||
def _reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs):
|
||||
@ -2593,10 +2650,30 @@ def _slice_batching_rule(batched_args, batch_dims, start_indices, limit_indices,
|
||||
out = slice(operand, new_start_indices, new_limit_indices, new_strides)
|
||||
return out, bdim
|
||||
|
||||
def _slice_papply_rule(name, size, vals, dims, start_indices, limit_indices,
|
||||
strides, **kwargs):
|
||||
operand, = vals
|
||||
dim, = dims
|
||||
start_indices = list(start_indices)
|
||||
limit_indices = list(limit_indices)
|
||||
|
||||
if (start_indices[dim] != 0 or
|
||||
limit_indices[dim] != size or
|
||||
strides is not None and strides[dim] != 1):
|
||||
raise NotImplementedError('slice changes side of hidden dimension')
|
||||
|
||||
out = slice(
|
||||
operand,
|
||||
start_indices[:dim] + start_indices[dim + 1:],
|
||||
limit_indices[:dim] + limit_indices[dim + 1:],
|
||||
strides[:dim] + strides[dim + 1:] if strides is not None else None)
|
||||
return out, dim
|
||||
|
||||
slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice',
|
||||
_slice_translation_rule)
|
||||
ad.deflinear(slice_p, _slice_transpose_rule)
|
||||
batching.primitive_batchers[slice_p] = _slice_batching_rule
|
||||
parallel.papply_primitive_rules[slice_p] = _slice_papply_rule
|
||||
|
||||
|
||||
def _dynamic_slice_shape_rule(operand, start_indices, slice_sizes,
|
||||
@ -2859,9 +2936,27 @@ def _gather_batching_rule(batched_args, batch_dims, dimension_numbers,
|
||||
return gather(operand, start_indices, dimension_numbers=dnums,
|
||||
slice_sizes=slice_sizes), 0
|
||||
|
||||
def _gather_serial_pmap_rule(vals, axes):
|
||||
val, = vals
|
||||
return val, None
|
||||
def _gather_papply_rule(
    name, size, vals, dims, dimension_numbers, slice_sizes, operand_shape):
  """papply rule for ``gather``.

  Only one configuration is handled: the operand's dimension entry is
  ``None``, the start indices' entry is not ``None`` and is not one of
  ``dimension_numbers.offset_dims``, and ``collapsed_slice_dims`` equals
  ``(0,)``. Anything else raises NotImplementedError.

  Returns:
    A pair ``(out, out_dim)``.
  """
  operand, start_indices = vals
  operand_dim, start_indices_dim = dims

  supported = (
      operand_dim is None and
      start_indices_dim is not None and
      start_indices_dim not in dimension_numbers.offset_dims and
      dimension_numbers.collapsed_slice_dims == (0,))
  if not supported:
    raise NotImplementedError

  # Offset dims above the indices' dimension move down by one.
  shifted = []
  for d in dimension_numbers.offset_dims:
    shifted.append(d - 1 if d > start_indices_dim else d)
  offset_dims = tuple(shifted)

  dnums = GatherDimensionNumbers(
      offset_dims=offset_dims,
      collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
      start_index_map=dimension_numbers.start_index_map)
  out = gather(operand, start_indices, dimension_numbers=dnums,
               slice_sizes=slice_sizes)

  # Output position: the indices' dimension plus the number of (shifted)
  # offset dims that are at or before it.
  num_preceding = onp.sum(onp.less_equal(offset_dims, start_indices_dim))
  return out, start_indices_dim + num_preceding
|
||||
|
||||
gather_p = standard_primitive(
|
||||
_gather_shape_rule, _gather_dtype_rule, 'gather',
|
||||
@ -2869,8 +2964,7 @@ gather_p = standard_primitive(
|
||||
ad.defjvp(gather_p, _gather_jvp_rule, None)
|
||||
ad.primitive_transposes[gather_p] = _gather_transpose_rule
|
||||
batching.primitive_batchers[gather_p] = _gather_batching_rule
|
||||
parallel.serial_pmap_primitive_rules[gather_p] = _gather_serial_pmap_rule
|
||||
|
||||
parallel.papply_primitive_rules[gather_p] = _gather_papply_rule
|
||||
|
||||
class ScatterDimensionNumbers(collections.namedtuple(
|
||||
"ScatterDimensionNumbers",
|
||||
|
@ -15,6 +15,9 @@
|
||||
Parallelization primitives.
|
||||
"""
|
||||
|
||||
import numpy as onp
|
||||
|
||||
from jax import ad_util
|
||||
from jax.lax import lax
|
||||
from jax.abstract_arrays import ShapedArray
|
||||
from jax.interpreters import ad
|
||||
@ -131,7 +134,7 @@ def pswapaxes(x, axis_name, axis):
|
||||
# raise ValueError(msg.format(axis_size(axis_name), x.shape[axis]))
|
||||
return pswapaxes_p.bind(x, axis_name=axis_name, axis=axis)
|
||||
|
||||
def psplit(x, axis_name, axis):
|
||||
def psplit(x, axis_name, split_axis, concat_axis):
|
||||
"""Unmap the pmapped axis ``axis_name`` and map ``axis`` with the same name.
|
||||
|
||||
This function is similar to ``pswapaxes`` except the pmapped axis of the input
|
||||
@ -141,15 +144,18 @@ def psplit(x, axis_name, axis):
|
||||
x: array with a mapped axis named ``axis_name``.
|
||||
axis_name: hashable Python object used to name a pmapped axis (see the
|
||||
``pmap`` docstring for more details).
|
||||
axis: int indicating the unmapped axis of ``x`` to map with the name
|
||||
split_axis: int indicating the unmapped axis of ``x`` to map with the name
|
||||
``axis_name``.
|
||||
concat_axis: int indicating the dimension at which to materialize the axis
|
||||
of ``x`` mapped with ``axis_name``.
|
||||
|
||||
Returns:
|
||||
An array with shape ``(axis_size,) + tuple(np.delete(x.shape, axis))`` where
|
||||
``axis_size`` is the size of the mapped axis named ``axis_name`` in the
|
||||
input ``x``.
|
||||
An array with shape ``np.insert(np.delete(x.shape, split_axis), concat_axis,
|
||||
axis_size)`` where ``axis_size`` is the size of the mapped axis named
|
||||
``axis_name`` in the input ``x``.
|
||||
"""
|
||||
return psplit_p.bind(x, axis_name=axis_name, axis=axis)
|
||||
return psplit_p.bind(x, axis_name=axis_name, concat_axis=concat_axis,
|
||||
split_axis=split_axis)
|
||||
|
||||
def psplit_like(x, y, axis_name):
|
||||
"""Ensure the named mapped axis of ``x`` aligns with that of ``y``."""
|
||||
@ -249,14 +255,17 @@ pxla.parallel_translation_rules[pswapaxes_p] = _pswapaxes_translation_rule
|
||||
parallel.serial_pmap_primitive_rules[pswapaxes_p] = _pswapaxes_serial_pmap_rule
|
||||
|
||||
|
||||
def _psplit_serial_pmap_rule(vals, axes, axis):
|
||||
def _psplit_serial_pmap_rule(vals, axes, split_axis, concat_axis):
|
||||
x, = vals
|
||||
axis_in, = axes
|
||||
if x.shape[axis_in] != x.shape[axis]:
|
||||
if x.shape[axis_in] != x.shape[split_axis]:
|
||||
raise ValueError(
|
||||
"psplit between non-square dimensions {} and {} of {}".format(
|
||||
axis_in, axis, x.shape))
|
||||
return x, axis
|
||||
axis_in, split_axis, x.shape))
|
||||
perm = list(range(x.ndim))
|
||||
perm[axis_in] = concat_axis
|
||||
perm[concat_axis] = axis_in
|
||||
return lax.transpose(x, perm), split_axis
|
||||
|
||||
psplit_p = standard_pmap_primitive('psplit')
|
||||
parallel.serial_pmap_primitive_rules[psplit_p] = _psplit_serial_pmap_rule
|
||||
@ -290,7 +299,7 @@ parallel.serial_pmap_primitive_rules[pcollect_p] = _pcollect_serial_pmap_rule
|
||||
# primitives, but that currently causes circular dependencies. More refactoring
|
||||
# might fix this.
|
||||
|
||||
def _dot_papply_rule(name, vals, dims):
|
||||
def _dot_papply_rule(name, size, vals, dims):
|
||||
x, y = vals
|
||||
xdim, ydim = dims
|
||||
if xdim is None:
|
||||
@ -299,7 +308,7 @@ def _dot_papply_rule(name, vals, dims):
|
||||
return lax.dot(x, y), xdim
|
||||
elif ydim == 0:
|
||||
if xdim != x.ndim:
|
||||
x = psplit(x, name, x.ndim)
|
||||
x = psplit(x, name, x.ndim, xdim)
|
||||
x = x[..., None]
|
||||
y = y[..., None, :]
|
||||
return psum(x * y, name), None
|
||||
@ -308,49 +317,62 @@ def _dot_papply_rule(name, vals, dims):
|
||||
return lax.dot(x, y), xdim
|
||||
|
||||
|
||||
def _dot_general_papply_rule(name, vals, dims, dimension_numbers):
|
||||
def _dot_general_papply_rule(name, size, vals, dims, dimension_numbers):
|
||||
x, y = vals
|
||||
xdim, ydim = dims
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
|
||||
if len(lhs_batch) > 0 or len(rhs_batch) > 0:
|
||||
raise NotImplementedError
|
||||
|
||||
def adjust_dims(dims, thresh):
|
||||
return tuple(i - 1 if i >= thresh else i for i in dims if i != thresh)
|
||||
return tuple(i - 1 if i > thresh else i for i in dims if i != thresh)
|
||||
|
||||
sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
|
||||
if xdim is not None:
|
||||
sub_lhs_contract = adjust_dims(lhs_contract, xdim)
|
||||
if ydim is not None:
|
||||
sub_rhs_contract = adjust_dims(rhs_contract, ydim)
|
||||
sub_lhs_batch, sub_rhs_batch = lhs_batch, rhs_batch
|
||||
|
||||
sub_dimension_numbers = (
|
||||
(sub_lhs_contract, sub_rhs_contract), (lhs_batch, rhs_batch))
|
||||
def sub_dims(xdim, ydim):
|
||||
sub_lhs_contract, sub_rhs_contract = lhs_contract, rhs_contract
|
||||
sub_lhs_batch, sub_rhs_batch = lhs_batch, rhs_batch
|
||||
if xdim is not None:
|
||||
sub_lhs_batch = adjust_dims(lhs_batch, xdim)
|
||||
sub_lhs_contract = adjust_dims(lhs_contract, xdim)
|
||||
if ydim is not None:
|
||||
sub_rhs_batch = adjust_dims(rhs_batch, ydim)
|
||||
sub_rhs_contract = adjust_dims(rhs_contract, ydim)
|
||||
return (
|
||||
(sub_lhs_contract, sub_rhs_contract), (sub_lhs_batch, sub_rhs_batch))
|
||||
|
||||
if xdim in lhs_contract and ydim in rhs_contract:
|
||||
z = lax.dot_general(x, y, sub_dimension_numbers)
|
||||
return psum(z, name), None
|
||||
elif xdim in lhs_contract:
|
||||
if ydim is not None: # Cannot hide two dimensions, so collect one
|
||||
y = pcollect(y, name)
|
||||
return lax.dot_general(x, y, sub_dimension_numbers), xdim
|
||||
elif ydim in rhs_contract:
|
||||
if xdim is not None: # Cannot hide two dimensions, so collect one
|
||||
x = pcollect(x, name)
|
||||
return lax.dot_general(x, y, sub_dimension_numbers), ydim
|
||||
elif xdim is not None:
|
||||
if ydim is not None: # Cannot hide two dimensions, so collect one
|
||||
y = pcollect(y, name)
|
||||
return lax.dot_general(x, y, sub_dimension_numbers), xdim
|
||||
elif ydim is not None:
|
||||
return lax.dot_general(x, y, sub_dimension_numbers), ydim
|
||||
def cases(x, y, xdim, ydim, xcontract, ycontract):
|
||||
if xdim in xcontract:
|
||||
if ydim in ycontract:
|
||||
# case: both operands are split and contracting
|
||||
z = lax.dot_general(x, y, sub_dims(xdim, ydim))
|
||||
return True, (psum(z, name), None)
|
||||
elif ydim is not None:
|
||||
# case: x split and contracting, y split but not contracting
|
||||
new_ydim = ycontract[xcontract.index(xdim)]
|
||||
y = psplit(y, name, new_ydim, ydim)
|
||||
z = lax.dot_general(x, y, sub_dims(xdim, new_ydim))
|
||||
return True, (psum(z, name), None)
|
||||
else:
|
||||
# case: x split and contracting, y not split
|
||||
return False, 'one operand split and contracting, other is not split'
|
||||
else:
|
||||
return False, 'unhandled case'
|
||||
|
||||
ok, out = cases(x, y, xdim, ydim, lhs_contract, rhs_contract)
|
||||
if not ok:
|
||||
ok, out = cases(y, x, ydim, xdim, rhs_contract, lhs_contract)
|
||||
if not ok:
|
||||
raise NotImplementedError(
|
||||
('papply of dot_general, {}: '
|
||||
'xdim={}, ydim={}, dimension_numbers={}').format(
|
||||
out, xdim, ydim, dimension_numbers))
|
||||
else:
|
||||
return lax.dot_general(x, y, sub_dimension_numbers), None
|
||||
return out
|
||||
|
||||
|
||||
def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
|
||||
def _reshape_papply_rule(name, size, vals, axes, new_sizes, dimensions,
|
||||
old_sizes):
|
||||
operand, = vals
|
||||
axis, = axes
|
||||
|
||||
@ -358,42 +380,33 @@ def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
|
||||
return filter(lambda x: x != 1, xs)
|
||||
|
||||
def find_new_axis(old_axis, old_sizes, new_sizes):
  """Locate ``old_axis`` of ``old_sizes`` among the dimensions of ``new_sizes``.

  Returns the first index ``i`` such that ``new_sizes[i]`` equals
  ``old_sizes[old_axis]`` and the product of ``new_sizes[:i]`` equals the
  product of ``old_sizes[:old_axis]``; returns ``None`` if there is no such
  index.
  """
  left = onp.prod(old_sizes[:old_axis])
  # Named axis_size so it does not shadow the enclosing rule's ``size``.
  axis_size = old_sizes[old_axis]
  prod = 1
  for i, cur_sz in enumerate(new_sizes):
    if prod == left and cur_sz == axis_size:
      return i
    # Accumulate the size of the dimension just visited.
    prod = prod * cur_sz
  return None
|
||||
|
||||
err = NotImplementedError(
|
||||
'papply of reshape that would change hidden dimension size')
|
||||
|
||||
if dimensions is None:
|
||||
new_axis = find_new_axis(axis, old_sizes, new_sizes)
|
||||
if new_axis is not None:
|
||||
if (lax.prod(old_sizes[:axis]) != lax.prod(new_sizes[:new_axis]) or
|
||||
lax.prod(old_sizes[axis + 1:]) != lax.prod(new_sizes[new_axis + 1:])):
|
||||
raise err
|
||||
new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
|
||||
return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis
|
||||
else:
|
||||
raise err
|
||||
raise NotImplementedError(
|
||||
'papply of reshape that would change hidden dimension size')
|
||||
else:
|
||||
raise NotImplementedError('papply of reshape with `dimensions`')
|
||||
|
||||
|
||||
def _transpose_papply_rule(name, vals, dims, permutation):
|
||||
def _transpose_papply_rule(name, size, vals, dims, permutation):
|
||||
x, = vals
|
||||
xdim, = dims
|
||||
perm = list(permutation)
|
||||
if perm[xdim] == xdim:
|
||||
perm = [i - 1 if i > xdim else i for i in perm if i != xdim]
|
||||
x = lax.transpose(x, perm)
|
||||
out_dim = xdim
|
||||
else:
|
||||
@ -408,7 +421,7 @@ def _transpose_papply_rule(name, vals, dims, permutation):
|
||||
return x, xdim
|
||||
|
||||
|
||||
def _select_papply_rule(name, vals, dims):
|
||||
def _select_papply_rule(name, size, vals, dims):
|
||||
dimset = set([d for d in dims if d is not None])
|
||||
if len(dimset) != 1:
|
||||
raise NotImplementedError(
|
||||
@ -422,8 +435,24 @@ def _select_papply_rule(name, vals, dims):
|
||||
return lax.select_p.bind(*vals), like_dim
|
||||
|
||||
|
||||
def _add_jaxvals_papply_rule(name, size, vals, dims):
  """papply rule for ``add_jaxvals``.

  Args:
    name: axis name forwarded to ``psplit_like``.
    size: axis size; unused here.
    vals: pair ``(x, y)`` of operands.
    dims: pair ``(xdim, ydim)`` of the operands' dimension entries.

  Returns:
    A pair ``(result, out_dim)``.
  """
  x, y = vals
  xdim, ydim = dims
  if xdim == ydim:
    out_dim = xdim
  elif ydim is None:
    # ``psplit_like`` is defined in this module, not on ``lax``; call it
    # directly, as the sibling rules do with ``psplit``/``psum``/``pcollect``.
    y = psplit_like(y, x, name)
    out_dim = xdim
  else:
    x = psplit_like(x, y, name)
    out_dim = ydim
  return ad_util.add_jaxvals_p.bind(x, y), out_dim
|
||||
|
||||
|
||||
parallel.papply_primitive_rules[lax.dot_p] = _dot_papply_rule
|
||||
parallel.papply_primitive_rules[lax.dot_general_p] = _dot_general_papply_rule
|
||||
parallel.papply_primitive_rules[lax.reshape_p] = _reshape_papply_rule
|
||||
parallel.papply_primitive_rules[lax.transpose_p] = _transpose_papply_rule
|
||||
parallel.papply_primitive_rules[lax.select_p] = _select_papply_rule
|
||||
parallel.papply_primitive_rules[ad_util.add_jaxvals_p] = (
|
||||
_add_jaxvals_papply_rule)
|
||||
|
@ -53,7 +53,7 @@ class SerialPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testPsplit(self):
|
||||
f = lambda x: lax.psplit(x, 'i', 2)
|
||||
f = lambda x: lax.psplit(x, 'i', 2, 0)
|
||||
arg = onp.arange(3 * 2 * 3 * 5).reshape(3, 2, 3, 5)
|
||||
ans = _serial_pmap(f, axis_name='i', out_axes=2)(arg)
|
||||
expected = arg
|
||||
|
Loading…
x
Reference in New Issue
Block a user