Added batching rules for convolutions + pooling.

Added batching rules:
conv_general_dilated_batch_rule
select_and_scatter_add_batch_rule
reduce_window_max_batch_rule
reduce_window_sum_batch_rule
sschoenholz 2019-01-28 14:33:57 -08:00 committed by GitHub
parent d968e1e572
commit a15bad401f
@@ -1128,13 +1128,61 @@ def conv_general_dilated_translation_rule(
  return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
                              rhs_dilation, dimension_numbers)

def conv_general_dilated_batch_rule(
    batched_args, batch_dims, window_strides, padding,
    lhs_dilation, rhs_dilation, dimension_numbers, **unused_kwargs):
  lhs, rhs = batched_args
  lhs_bdim, rhs_bdim = batch_dims
  lhs_dim, rhs_dim, out_dim = dimension_numbers

  if lhs_bdim is not None and rhs_bdim is not None:
    # Both operands are batched: unroll over the batch axis and stack.
    lhs = batching.move_dim_to_front(lhs, lhs_bdim)
    rhs = batching.move_dim_to_front(rhs, rhs_bdim)
    outputs = [
        conv_general_dilated(l, r, window_strides, padding,
                             lhs_dilation, rhs_dilation, dimension_numbers)
        for l, r in zip(lhs, rhs)]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
  elif lhs_bdim is not None:
    # Only the lhs is batched: fold the vmap axis into the convolution's
    # example ("N") dimension, convolve once, and split the result back out.
    # Currently we don't handle cases where the batch dimension of the
    # convolution isn't the first dimension.
    if lhs_dim[0] != 0 or out_dim[0] != 0:
      raise NotImplementedError
    lhs = batching.move_dim_to_front(lhs, lhs_dim[0])
    lhs = batching.move_dim_to_front(lhs, lhs_bdim)
    batched_size = lhs.shape[0]
    n_size = lhs.shape[1]
    lhs = reshape(lhs, (batched_size * n_size,) + lhs.shape[2:])
    outputs = conv_general_dilated(
        lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, dimension_numbers)
    outputs = reshape(outputs, (batched_size, n_size) + outputs.shape[1:])
    return outputs, 0
  elif rhs_bdim is not None:
    # Only the rhs (kernel) is batched: unroll over the kernel batch axis.
    # TODO(schsam): Consider a loop instead of unrolling.
    rhs = batching.move_dim_to_front(rhs, rhs_bdim)
    outputs = [
        conv_general_dilated(lhs, x, window_strides, padding,
                             lhs_dilation, rhs_dilation, dimension_numbers)
        for x in rhs]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
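
The lhs-only case above avoids unrolling by folding the vmap axis into the
convolution's example dimension. A shape-level sketch of that fold/unfold
trick (the sizes are illustrative assumptions, not part of this change):

import jax.numpy as jnp

B, N, C, H, W = 5, 2, 3, 8, 8    # vmap batch, conv batch, features, spatial
lhs = jnp.zeros((B, N, C, H, W))

merged = jnp.reshape(lhs, (B * N, C, H, W))                # one large conv batch
# ... a single conv_general_dilated call runs on `merged` here ...
unmerged = jnp.reshape(merged, (B, N) + merged.shape[1:])  # split it back out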

conv_general_dilated_p = standard_primitive(
    conv_general_dilated_shape_rule, conv_general_dilated_dtype_rule,
    'conv_general_dilated', conv_general_dilated_translation_rule)
ad.defbilinear(conv_general_dilated_p,
               conv_general_dilated_transpose_lhs,
               conv_general_dilated_transpose_rhs)
batching.primitive_batchers[
    conv_general_dilated_p] = conv_general_dilated_batch_rule
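
With the rule registered, jax.vmap can map lax.conv_general_dilated over a
leading axis of either operand. A minimal usage sketch through the current
public API; the shapes and dimension numbers below are illustrative
assumptions:

import jax
import jax.numpy as jnp
from jax import lax

imgs = jnp.ones((5, 2, 3, 8, 8))   # vmap axis 0, then an NCHW batch
kernel = jnp.ones((4, 3, 3, 3))    # shared OIHW kernel

def conv(x):
  return lax.conv_general_dilated(
      x, kernel, window_strides=(1, 1), padding='SAME',
      lhs_dilation=(1, 1), rhs_dilation=(1, 1),
      dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

out = jax.vmap(conv)(imgs)         # lhs-batched case: a single fused conv
print(out.shape)                   # (5, 2, 4, 8, 8)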

def dot_shape_rule(lhs, rhs):
  if lhs.ndim == 0 or rhs.ndim == 0:

@@ -2291,12 +2339,26 @@ def reduce_window_sum_transpose_rule(cotangent, window_dimensions,
                          xla_bridge.get_xla_client().PaddingType.VALID)
  assert result.shape == input_shape
  return [result]

def reduce_window_sum_batch_rule(
    batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    # Insert a size-1 window with stride 1 at the batch axis so the windowed
    # reduction leaves that axis untouched.
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

  operand = _reduce_window_sum(
      operand, window_dimensions, window_strides, padding)

  # The batch axis is left in place, so it remains at bdim in the output.
  return operand, bdim

reduce_window_sum_p = standard_primitive(
    reduce_window_sum_shape_rule, _input_dtype, 'reduce_window_sum',
    reduce_window_sum_translation_rule)
ad.deflinear(reduce_window_sum_p, reduce_window_sum_transpose_rule)
batching.primitive_batchers[reduce_window_sum_p] = reduce_window_sum_batch_rule
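
The rule above lets jax.vmap handle windowed sums by widening the window and
stride tuples with a size-1 entry at the batch axis. A brief sketch through
the public lax.reduce_window API (the pooling shapes are assumptions):

import jax
import jax.numpy as jnp
from jax import lax

def sum_pool(img):                  # img: (1, 6, 6), laid out (C, H, W)
  return lax.reduce_window(img, 0., lax.add,
                           window_dimensions=(1, 2, 2),
                           window_strides=(1, 2, 2),
                           padding='VALID')

x = jnp.ones((4, 1, 6, 6))
pooled = jax.vmap(sum_pool)(x)      # the batch rule adds the size-1 window
print(pooled.shape)                 # (4, 1, 3, 3)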

def reduce_window_chooser_translation_rule(
    prim, identity, c, operand, window_dimensions, window_strides, padding):

@@ -2338,6 +2400,20 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
      onp.subtract(operand_padded, window_dimensions), window_strides) + 1
  return tuple(t)

def reduce_window_max_batch_rule(
    batched_args, bdims, window_dimensions, window_strides, padding, **kwargs):
  operand, = batched_args
  bdim, = bdims

  if bdim is not None:
    # Same trick as the sum rule: a size-1 window at the batch axis.
    window_dimensions = \
        window_dimensions[:bdim] + (1,) + window_dimensions[bdim:]
    window_strides = window_strides[:bdim] + (1,) + window_strides[bdim:]

  operand = _reduce_window_max(
      operand, window_dimensions, window_strides, padding)

  return operand, bdim

reduce_window_max_translation_rule = partial(
    reduce_window_chooser_translation_rule, max_p, _get_max_identity)

@@ -2345,7 +2421,7 @@ reduce_window_max_p = standard_primitive(
    common_reduce_window_shape_rule, _input_dtype, 'reduce_window_max',
    reduce_window_max_translation_rule)
ad.defjvp(reduce_window_max_p, partial(reduce_window_chooser_jvp_rule, max_p))
batching.primitive_batchers[reduce_window_max_p] = reduce_window_max_batch_rule
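
The max rule mirrors the sum rule, just with a max reduction (and -inf as
the identity at the user level). Another short sketch with assumed shapes:

import jax
import jax.numpy as jnp
from jax import lax

def max_pool(img):                  # img: (6, 6)
  return lax.reduce_window(img, -jnp.inf, lax.max,
                           window_dimensions=(2, 2),
                           window_strides=(2, 2),
                           padding='VALID')

imgs = jnp.arange(4 * 36, dtype=jnp.float32).reshape(4, 6, 6)
print(jax.vmap(max_pool)(imgs).shape)   # (4, 3, 3)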

reduce_window_min_translation_rule = partial(
    reduce_window_chooser_translation_rule, min_p, _get_min_identity)

@@ -2402,12 +2478,40 @@ def select_and_scatter_add_transpose(
      window_strides, padding)
  return [result, None]

def select_and_scatter_add_batch_rule(batched_args, batch_dims, **kwargs):
  source, operand = batched_args
  s_bdim, o_bdim = batch_dims

  if s_bdim is not None and o_bdim is not None:
    # Both inputs are batched: unroll over the batch axis and stack.
    source = batching.move_dim_to_front(source, s_bdim)
    operand = batching.move_dim_to_front(operand, o_bdim)
    outputs = [
        _select_and_scatter_add(s, o, **kwargs)
        for s, o in zip(source, operand)]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
  elif s_bdim is not None:
    # Only the source is batched; the operand is shared across the batch.
    source = batching.move_dim_to_front(source, s_bdim)
    outputs = [
        _select_and_scatter_add(s, operand, **kwargs) for s in source]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
  elif o_bdim is not None:
    # Only the operand is batched; the source is shared across the batch.
    operand = batching.move_dim_to_front(operand, o_bdim)
    outputs = [
        _select_and_scatter_add(source, o, **kwargs) for o in operand]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0

select_and_scatter_add_p = standard_primitive(
    select_and_scatter_add_shape_rule, _input_dtype, 'select_and_scatter_add',
    select_and_scatter_add_translation)
ad.primitive_transposes[select_and_scatter_add_p] = \
    select_and_scatter_add_transpose
batching.primitive_batchers[select_and_scatter_add_p] = \
    select_and_scatter_add_batch_rule
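
select_and_scatter_add arises when autodiff transposes the windowed-max JVP,
so one way this batching rule gets exercised is a vmapped gradient of max
pooling. An end-to-end sketch with assumed shapes:

import jax
import jax.numpy as jnp
from jax import lax

def pooled_sum(img):                # img: (6, 6)
  pooled = lax.reduce_window(img, -jnp.inf, lax.max,
                             window_dimensions=(2, 2),
                             window_strides=(2, 2),
                             padding='VALID')
  return jnp.sum(pooled)

imgs = jnp.arange(3 * 36, dtype=jnp.float32).reshape(3, 6, 6)
grads = jax.vmap(jax.grad(pooled_sum))(imgs)   # hits select_and_scatter_add
print(grads.shape)                             # (3, 6, 6)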

def _select_and_gather_add_shape_rule(
    tangents, operand, select_prim, pair_select_jaxpr, pair_select_consts,