Mirror of https://github.com/ROCm/jax.git
Synced 2025-04-19 05:16:06 +00:00
grouped convolution support
This commit is contained in:
parent ff29d582e8
commit 077d56529f
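For orientation, here is a minimal sketch (not part of the commit) of how grouped convolution is expressed through the public jax.lax API; the shapes are illustrative, and batch_group_count is the analogous parameter this commit wires through the lax internals.

import jax.numpy as jnp
from jax import lax

lhs = jnp.zeros((8, 6, 32, 32))    # NCHW input: batch 8, 6 feature channels
rhs = jnp.zeros((12, 2, 3, 3))     # OIHW kernel: 12 outputs, 6 / 3 = 2 inputs per group
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    feature_group_count=3)         # 3 feature groups
print(out.shape)                   # (8, 12, 32, 32)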
@@ -438,6 +438,18 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
     lhs_dilation = (1,) * (lhs.ndim - 2)
   if rhs_dilation is None:
     rhs_dilation = (1,) * (rhs.ndim - 2)
+  if batch_group_count != 1 and batch_group_count != lhs.shape[dimension_numbers[0][0]]:
+    # XLA currently doesn't support 1 < batch_group_count < batch_size
+    # so we use the batch rule to rewrite into a convolution using feature_group_count
+    lhs = _reshape_axis_out_of(dimension_numbers[0][0], batch_group_count, lhs)
+    rhs = _reshape_axis_out_of(dimension_numbers[1][0], batch_group_count, rhs)
+    out, out_batch_dim = _conv_general_dilated_batch_rule((lhs, rhs),
+        (dimension_numbers[0][0], dimension_numbers[1][0]),
+        window_strides, padding,
+        lhs_dilation, rhs_dilation, dimension_numbers,
+        feature_group_count, batch_group_count=1)
+    out = _reshape_axis_into(out_batch_dim, dimension_numbers[2][1], out)
+    return out
   return conv_general_dilated_p.bind(
       lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
       lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
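At the shape level, the rewrite above splits the batch-group axis out of the batch dimension and later folds it back into the output features. A rough illustration with assumed helper semantics (this is not the diff's _reshape_axis_out_of / _reshape_axis_into):

import jax.numpy as jnp

batch_group_count = 2
lhs = jnp.zeros((6, 4, 8, 8))      # batch 6, 4 feature channels (NCHW)
# Split the group axis out of the batch dimension, as the rewrite does:
split = lhs.reshape((batch_group_count, 6 // batch_group_count) + lhs.shape[1:])
print(split.shape)                 # (2, 3, 4, 8, 8)
# After the batched convolution, the group axis is folded back into the output
# feature dimension, so the final output has batch 3 rather than 6.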
@@ -1769,10 +1781,21 @@ def _conv_general_dilated_shape_rule(
            "multiple of feature_group_count, but {} is not a multiple of {}.")
     raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                 feature_group_count))
+  if lhs.shape[dimension_numbers.lhs_spec[0]] % batch_group_count:
+    msg = ("conv_general_dilated lhs input batch dimension size must be a "
+           "multiple of batch_group_count, but {} is not a multiple of {}.")
+    raise ValueError(msg.format(lhs.shape[dimension_numbers.lhs_spec[0]],
+                                batch_group_count))
+  if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
+    msg = ("conv_general_dilated rhs output feature dimension size must be a "
+           "multiple of batch_group_count, but {} is not a multiple of {}.")
+    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
+                                batch_group_count))
   lhs_perm, rhs_perm, out_perm = dimension_numbers
   lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation)
   rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation)
-  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
+  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
+                               batch_group_count)
   return tuple(onp.take(out_trans, onp.argsort(out_perm)))

 def _conv_general_dilated_dtype_rule(
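A worked example of the two divisibility checks added above and the resulting output batch size, as plain arithmetic with hypothetical sizes:

lhs_batch, rhs_out_features, batch_group_count = 6, 12, 3
assert lhs_batch % batch_group_count == 0          # lhs batch check above
assert rhs_out_features % batch_group_count == 0   # rhs output-feature check above
print(lhs_batch // batch_group_count)              # 2: the output batch size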
@@ -1820,11 +1843,7 @@ def _conv_general_dilated_transpose_rhs(
     dimension_numbers, feature_group_count, batch_group_count,
     lhs_shape, rhs_shape):
   assert type(dimension_numbers) is ConvDimensionNumbers
-  if not feature_group_count == batch_group_count == 1:
-    msg = ("conv_general_dilated transpose rule is only implemented for "
-           "feature_group_count == batch_group_count == 1, but got {} and {}. "
-           "Open a feature request!")
-    raise NotImplementedError(msg.format(feature_group_count, batch_group_count))
+
   lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
   lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
   trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
@@ -1836,8 +1855,8 @@ def _conv_general_dilated_transpose_rhs(
       lhs, g, window_strides=rhs_dilation, padding=padding,
       lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
       dimension_numbers=trans_dimension_numbers,
-      feature_group_count=feature_group_count,
-      batch_group_count=batch_group_count)
+      feature_group_count=batch_group_count,
+      batch_group_count=feature_group_count)

 def _conv_general_dilated_translation_rule(
     c, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
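Swapping the two group counts reflects that the kernel gradient of a feature-grouped convolution is itself a convolution with the batch- and feature-grouping roles exchanged. A hedged end-to-end check (assumes a jax build where grouped-convolution gradients are supported; shapes are illustrative):

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((4, 6, 8, 8))       # batch 4, 6 feature channels
rhs = jnp.ones((9, 2, 3, 3))       # 9 outputs, 6 / 3 = 2 inputs per group
f = lambda k: lax.conv_general_dilated(lhs, k, (1, 1), 'SAME',
                                       feature_group_count=3)
out, vjp = jax.vjp(f, rhs)
(g_rhs,) = vjp(jnp.ones_like(out))
print(g_rhs.shape)                 # (9, 2, 3, 3): gradient matches the kernel shape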
@@ -1876,7 +1895,6 @@ def _conv_general_dilated_batch_rule(
     return out, out_spec[0]

   elif rhs_bdim is not None:
-    num_output_features = feature_group_count
     if feature_group_count == 1:
       new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
       out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
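The rhs branch of this batching rule is what jax.vmap over a stack of kernels lowers to; a small sketch with hypothetical shapes:

import jax
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 4, 8, 8))                  # shared NCHW input
kernels = jnp.ones((5, 4, 4, 3, 3))           # leading axis: 5 OIHW kernels
conv = lambda k: lax.conv_general_dilated(lhs, k, (1, 1), 'SAME')
outs = jax.vmap(conv)(kernels)
print(outs.shape)                             # (5, 2, 4, 8, 8)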
@@ -3934,7 +3952,7 @@ def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
     raise TypeError(msg.format(name, expected_length, len(window_strides)))


-def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
+def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
   """Compute the shape tuple of a conv given input shapes in canonical order."""
   if isinstance(pads, str):
     pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
@@ -3946,16 +3964,17 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
   out_space = onp.floor_divide(
       onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
   out_space = onp.maximum(0, out_space)
-  out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
+  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + tuple(out_space)
   return tuple(out_shape)


 def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
-                             dimension_numbers):
+                             dimension_numbers, batch_group_count=1):
   lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
   lhs_trans = onp.take(lhs_shape, lhs_perm)
   rhs_trans = onp.take(rhs_shape, rhs_perm)
-  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
+  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
+                               batch_group_count)
   return tuple(onp.take(out_trans, onp.argsort(out_perm)))
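The out_shape change above makes the output batch shrink by batch_group_count. A worked example as plain arithmetic (this mirrors, but is not, conv_shape_tuple; sizes are made up):

import numpy as onp

lhs_shape, rhs_shape = (6, 4, 8, 8), (12, 4, 3, 3)
strides, pads, batch_group_count = (1, 1), ((1, 1), (1, 1)), 2
lhs_padded = onp.add(lhs_shape[2:], onp.sum(pads, axis=1))
out_space = onp.maximum(0, onp.floor_divide(
    onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1)
out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + tuple(out_space)
print(out_shape)                              # (3, 12, 8, 8): batch 6 -> 6 // 2 = 3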