grouped convolution support

Jonathan Heek 2019-06-17 12:18:58 +02:00 committed by Jonathan Heek
parent ff29d582e8
commit 077d56529f


@@ -438,6 +438,18 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
lhs_dilation = (1,) * (lhs.ndim - 2)
if rhs_dilation is None:
rhs_dilation = (1,) * (rhs.ndim - 2)
if batch_group_count != 1 and batch_group_count != lhs.shape[dimension_numbers[0][0]]:
# XLA currently doesn't support 1 < batch_group_count < batch_size,
# so we use the batch rule to rewrite this into a convolution using feature_group_count.
lhs = _reshape_axis_out_of(dimension_numbers[0][0], batch_group_count, lhs)
rhs = _reshape_axis_out_of(dimension_numbers[1][0], batch_group_count, rhs)
out, out_batch_dim = _conv_general_dilated_batch_rule((lhs, rhs),
(dimension_numbers[0][0], dimension_numbers[1][0]),
window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count=1)
out = _reshape_axis_into(out_batch_dim, dimension_numbers[2][1], out)
return out
return conv_general_dilated_p.bind(
lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
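The fallback above rewrites an otherwise unsupported batch_group_count into a feature-grouped convolution via the batching rule. A minimal usage sketch of the batch_group_count argument (the shapes, layouts, and values below are illustrative, not taken from the diff):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 32, 32))     # NCHW: batch N=8, features C=3
rhs = jnp.ones((16, 3, 3, 3))      # OIHW: output features O=16, divisible by batch_group_count
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    batch_group_count=4)           # 1 < 4 < 8, so the rewrite above is taken
print(out.shape)                   # (2, 16, 32, 32): output batch is 8 // 4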
@@ -1769,10 +1781,21 @@ def _conv_general_dilated_shape_rule(
"multiple of feature_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
feature_group_count))
if lhs.shape[dimension_numbers.lhs_spec[0]] % batch_group_count:
msg = ("conv_general_dilated lhs input batch dimension size must be a "
"multiple of batch_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(lhs.shape[dimension_numbers.lhs_spec[0]],
batch_group_count))
if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
msg = ("conv_general_dilated rhs output feature dimension size must be a "
"multiple of batch_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
batch_group_count))
lhs_perm, rhs_perm, out_perm = dimension_numbers
lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation)
rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(onp.take(out_trans, onp.argsort(out_perm)))
def _conv_general_dilated_dtype_rule(
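Worked numbers for the two divisibility checks added above (illustrative sizes, not from the diff):

batch_group_count = 2
lhs_batch, rhs_out_features = 8, 16
assert lhs_batch % batch_group_count == 0          # first check: lhs batch size is divisible
assert rhs_out_features % batch_group_count == 0   # second check: rhs output features are divisible
# a batch of 7 with batch_group_count=2 would trip the first ValueError above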
@@ -1820,11 +1843,7 @@ def _conv_general_dilated_transpose_rhs(
dimension_numbers, feature_group_count, batch_group_count,
lhs_shape, rhs_shape):
assert type(dimension_numbers) is ConvDimensionNumbers
if not feature_group_count == batch_group_count == 1:
msg = ("conv_general_dilated transpose rule is only implemented for "
"feature_group_count == batch_group_count == 1, but got {} and {}. "
"Open a feature request!")
raise NotImplementedError(msg.format(feature_group_count, batch_group_count))
lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
@@ -1836,8 +1855,8 @@ def _conv_general_dilated_transpose_rhs(
lhs, g, window_strides=rhs_dilation, padding=padding,
lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
dimension_numbers=trans_dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
feature_group_count=batch_group_count,
batch_group_count=feature_group_count)
def _conv_general_dilated_translation_rule(
c, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
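The hunk above drops the NotImplementedError and swaps the two counts: the rhs transpose of a convolution with batch_group_count=g is itself a convolution with feature_group_count=g, and vice versa. A minimal sketch that exercises this rule through jax.grad (shapes are illustrative, not part of the diff):

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 3, 8, 8))    # NCHW: batch 4, divisible by batch_group_count=2
rhs = jnp.ones((6, 3, 3, 3))    # OIHW: output features 6, divisible by batch_group_count=2

def loss(kernel):
    out = lax.conv_general_dilated(
        lhs, kernel, (1, 1), 'SAME',
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
        batch_group_count=2)
    return jnp.sum(out)

print(jax.grad(loss)(rhs).shape)   # (6, 3, 3, 3), same shape as the kernel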
@@ -1876,7 +1895,6 @@ def _conv_general_dilated_batch_rule(
return out, out_spec[0]
elif rhs_bdim is not None:
num_output_features = feature_group_count
if feature_group_count == 1:
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
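The batching-rule hunk above adjusts the case where the kernel (rhs) carries the batch dimension so that grouped convolutions are handled too. A sketch of the kind of program this enables, assuming vmap over a stack of kernels of a feature-grouped convolution (illustrative shapes, not from the diff):

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 4, 8, 8))           # NCHW: features C=4
kernels = jnp.ones((3, 8, 2, 3, 3))    # 3 kernels, each OIHW with O=8, I=C//2

conv = lambda k: lax.conv_general_dilated(
    lhs, k, (1, 1), 'SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    feature_group_count=2)
print(jax.vmap(conv)(kernels).shape)   # (3, 2, 8, 8, 8)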
@@ -3934,7 +3952,7 @@ def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
raise TypeError(msg.format(name, expected_length, len(window_strides)))
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
"""Compute the shape tuple of a conv given input shapes in canonical order."""
if isinstance(pads, str):
pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
@@ -3946,16 +3964,17 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
out_space = onp.floor_divide(
onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
out_space = onp.maximum(0, out_space)
out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + tuple(out_space)
return tuple(out_shape)
def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
dimension_numbers):
dimension_numbers, batch_group_count=1):
lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
lhs_trans = onp.take(lhs_shape, lhs_perm)
rhs_trans = onp.take(rhs_shape, rhs_perm)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(onp.take(out_trans, onp.argsort(out_perm)))
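Putting the conv_shape_tuple change together, a minimal sketch of the canonical-order shape arithmetic with the new batch_group_count argument (the concrete shapes below are illustrative, not taken from the diff):

import numpy as onp

lhs_shape, rhs_shape = (8, 3, 32, 32), (16, 3, 3, 3)      # canonical NCHW / OIHW
strides, pads, batch_group_count = (1, 1), ((1, 1), (1, 1)), 4
lhs_padded = onp.add(lhs_shape[2:], onp.sum(pads, axis=1))  # spatial dims after padding
out_space = onp.maximum(0, onp.floor_divide(
    onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1)
out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + tuple(out_space)
print(out_shape)   # (2, 16, 32, 32): the batch dimension is divided by batch_group_count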