Add batch_group_count to conv_general_dilated. (#2635)

* Add batch_group_count to conv_general_dilated.

* Use batch_group_count for RHS grouped convolution transpose rule.

* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
Peter Hawkins 2020-04-09 16:21:30 -04:00 committed by GitHub
parent 1694a56fa3
commit 1bb67637ca
3 changed files with 181 additions and 63 deletions

jax/lax/lax.py

@@ -34,6 +34,7 @@ from .. import api
from .. import linear_util as lu
from .. import dtypes
from .. import lazy
from .. import lib
from ..config import flags
from ..core import _canonicalize_dimension, Primitive
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
@@ -462,7 +463,8 @@ def conv_general_dilated(
lhs_dilation: Optional[Sequence[int]] = None,
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
feature_group_count: int = 1, precision: Optional[PrecisionType] = None) -> Array:
feature_group_count: int = 1, batch_group_count: int = 1,
precision: Optional[PrecisionType] = None) -> Array:
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
@@ -487,6 +489,7 @@ def conv_general_dilated(
a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
of length `n+2`.
feature_group_count: integer, default 1. See XLA HLO docs.
batch_group_count: integer, default 1. See XLA HLO docs.
precision: Optional. Either `None`, which means the default precision for
the backend, or a `Precision` enum value.
@@ -543,6 +546,7 @@ def conv_general_dilated(
lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
dimension_numbers=dnums,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
lhs_shape=lhs.shape, rhs_shape=rhs.shape,
precision=_canonicalize_precision(precision))
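
As a usage sketch of the new parameter (not part of this diff; shapes and dimension numbers invented for illustration): with batch_group_count=g, the lhs batch dimension is split into g groups, each convolved against its own slice of the rhs output features, so the output batch shrinks by a factor of g.

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 32, 32))  # NCHW: batch 8 must be divisible by batch_group_count
rhs = jnp.ones((6, 3, 3, 3))    # OIHW: output features 6 must be a multiple of it
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
    batch_group_count=2)
print(out.shape)  # (4, 6, 32, 32): output batch is 8 // 2
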
@@ -2145,7 +2149,8 @@ masking.defvectorized(bitcast_convert_type_p)
def _conv_general_dilated_shape_rule(
lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, **unused_kwargs):
dimension_numbers, feature_group_count, batch_group_count,
**unused_kwargs):
assert type(dimension_numbers) is ConvDimensionNumbers
if not feature_group_count > 0:
msg = ("conv_general_dilated feature_group_count "
@@ -2168,10 +2173,32 @@ def _conv_general_dilated_shape_rule(
"multiple of feature_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
feature_group_count))
if not batch_group_count > 0:
msg = ("conv_general_dilated batch_group_count "
"must be a positive integer, got {}.")
raise ValueError(msg.format(batch_group_count))
lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
if lhs_batch_count % batch_group_count != 0:
msg = ("conv_general_dilated batch_group_count must divide lhs batch "
"dimension size, but {} does not divide {}.")
raise ValueError(msg.format(batch_group_count, lhs_batch_count))
if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
msg = ("conv_general_dilated rhs output feature dimension size must be a "
"multiple of batch_group_count, but {} is not a multiple of {}.")
raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
batch_group_count))
if batch_group_count > 1 and feature_group_count > 1:
msg = ("At most one of batch_group_count and feature_group_count may be > "
"1, got batch_group_count={} and feature_group_count={}")
raise ValueError(msg.format(batch_group_count, feature_group_count))
lhs_perm, rhs_perm, out_perm = dimension_numbers
lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation)
rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
batch_group_count)
return tuple(onp.take(out_trans, onp.argsort(out_perm)))
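
A small sketch of the new checks in action (invented shapes; not from the commit) — the batch divisibility test fires before any convolution is built:

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 9, 9))
rhs = jnp.ones((6, 3, 3, 3))
try:
  # 3 does not divide the lhs batch size 8, so the shape rule raises.
  lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID',
                           dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
                           batch_group_count=3)
except ValueError as e:
  print(e)
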
def _conv_general_dilated_dtype_rule(
@@ -2183,11 +2210,39 @@ def _conv_general_dilated_dtype_rule(
_conv_spec_transpose = lambda spec: (spec[1], spec[0]) + spec[2:]
_conv_sdims = lambda spec: spec[2:]
# Understanding the convolution transpose rules:
# Ignoring the spatial dimensions, let m = batch, j = input feature,
# k = output feature.
#
# Convolution computes the following contraction:
# Forward: [m, j] [j, k] -> [m, k]
#
# The transposes are similar to the rules for transposing a matmul:
# LHS transpose: [m, k] [k, j] -> [m, j]
# RHS transpose: [j, m] [m, k] -> [j, k]
#
# With feature grouping, we have the following signatures:
# Forward: [m, gj] [j, gk] -> [m, gk]
# LHS transpose: [m, gk] [k, gj] -> [m, gj]
# --> implemented as feature grouping after transposing the group from the
# kernel input features to the kernel output features.
# RHS transpose: [gj, m] [m, gk] -> [j, gk]
# --> which is batch grouping.
#
# With batch grouping, we have the following signatures:
# Forward: [gm, j] [j, gk] -> [m, gk]
# LHS transpose: [m, gk] [gk, j] -> [gm, j]
# --> implemented as feature grouping with transposing the group on the kernel
# and the output.
# RHS transpose: [j, gm] [m, gk] -> [j, gk]
# --> which is feature grouping.
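
The matmul analogy in this comment can be sanity-checked numerically; the sketch below (illustrative, not from the commit) verifies that g @ rhs.T and lhs.T @ g are the cotangents implied by the forward contraction [m, j] [j, k] -> [m, k]:

import numpy as onp

m, j, k = 4, 3, 5
rng = onp.random.RandomState(0)
lhs, rhs, g = rng.randn(m, j), rng.randn(j, k), rng.randn(m, k)

d_lhs = g @ rhs.T    # LHS transpose: [m, k] [k, j] -> [m, j]
d_rhs = lhs.T @ g    # RHS transpose: [j, m] [m, k] -> [j, k]

# Finite-difference check of d_lhs against the scalar objective <g, lhs @ rhs>.
eps = 1e-6
for idx in onp.ndindex(m, j):
  bumped = lhs.copy()
  bumped[idx] += eps
  fd = ((bumped @ rhs - lhs @ rhs) * g).sum() / eps
  onp.testing.assert_allclose(fd, d_lhs[idx], rtol=1e-3, atol=1e-4)
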
def _conv_general_dilated_transpose_lhs(
g, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count,
dimension_numbers, feature_group_count, batch_group_count,
lhs_shape, rhs_shape, precision):
assert type(dimension_numbers) is ConvDimensionNumbers
assert batch_group_count == 1 or feature_group_count == 1
lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
t_rhs_spec = _conv_spec_transpose(rhs_spec)
@@ -2196,31 +2251,52 @@ def _conv_general_dilated_transpose_lhs(
# group axis into the transposed rhs's output feature dim
rhs = _reshape_axis_out_of(rhs_spec[0], feature_group_count, rhs)
rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
elif batch_group_count > 1:
rhs = _reshape_axis_out_of(rhs_spec[0], batch_group_count, rhs)
rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
feature_group_count = batch_group_count
trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
padding = _conv_general_vjp_lhs_padding(
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
rhs_dilation)
revd_weights = rev(rhs, rhs_sdims)
return conv_general_dilated(
out = conv_general_dilated(
g, revd_weights, window_strides=lhs_dilation, padding=padding,
lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
dimension_numbers=trans_dimension_numbers,
feature_group_count=feature_group_count, precision=precision)
feature_group_count=feature_group_count,
batch_group_count=1, precision=precision)
if batch_group_count > 1:
out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
return out
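
_reshape_axis_out_of and _reshape_axis_into are used throughout these rules but are not shown in this diff; judging from their call sites they behave roughly like the NumPy reconstruction below (an assumption, not the actual implementation):

import numpy as onp

def reshape_axis_out_of(src, size1, x):
  # Split axis `src` into two axes (size1, size // size1), preserving order.
  shape = list(x.shape)
  size2, ragged = divmod(shape[src], size1)
  assert not ragged
  shape[src:src + 1] = [size1, size2]
  return x.reshape(shape)

def reshape_axis_into(src, dst, x):
  # Move axis `src` onto axis `dst` and merge the two into a single axis.
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  new_shape = [d for i, d in enumerate(x.shape) if i != src]
  new_shape[dst] *= x.shape[src]
  return onp.transpose(x, perm).reshape(new_shape)

x = onp.arange(24).reshape(2, 3, 4)
assert reshape_axis_out_of(2, 2, x).shape == (2, 3, 2, 2)
assert reshape_axis_into(0, 1, x).shape == (3, 8)
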
# TODO(phawkins): remove when the minimum jaxlib version is incremented past
# 0.1.43.
_jaxlib_has_working_batch_group_count = lib.version > (0, 1, 43)
def _conv_general_dilated_transpose_rhs(
g, lhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count,
lhs_shape, rhs_shape, precision):
dimension_numbers: ConvDimensionNumbers, feature_group_count: int,
batch_group_count: int, lhs_shape, rhs_shape, precision):
assert type(dimension_numbers) is ConvDimensionNumbers
if onp.size(g) == 0:
# Avoids forming degenerate convolutions where the RHS has spatial size 0.
return ad_util.zero
lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
if feature_group_count > 1:
lhs = _reshape_axis_out_of(lhs_trans[0], feature_group_count, lhs)
lhs = _reshape_axis_into(lhs_trans[0], lhs_trans[1], lhs)
assert batch_group_count == 1 or feature_group_count == 1
if batch_group_count > 1:
feature_group_count = batch_group_count
batch_group_count = 1
elif feature_group_count > 1:
if _jaxlib_has_working_batch_group_count:
batch_group_count = feature_group_count
feature_group_count = 1
else:
lhs = _reshape_axis_out_of(lhs_trans[0], feature_group_count, lhs)
lhs = _reshape_axis_into(lhs_trans[0], lhs_trans[1], lhs)
trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
padding = _conv_general_vjp_rhs_padding(
onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
@@ -2230,70 +2306,96 @@ def _conv_general_dilated_transpose_rhs(
lhs, g, window_strides=rhs_dilation, padding=padding,
lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
dimension_numbers=trans_dimension_numbers,
feature_group_count=feature_group_count, precision=precision)
feature_group_count=feature_group_count,
batch_group_count=batch_group_count, precision=precision)
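
Both transpose rules can be exercised end to end by differentiating a batch-grouped convolution; a hypothetical smoke test in the spirit of the autodiff tests added below:

import jax
import jax.numpy as jnp
from jax import lax

def loss(lhs, rhs):
  out = lax.conv_general_dilated(
      lhs, rhs, (1, 1), 'VALID',
      dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
      batch_group_count=2)
  return jnp.sum(out ** 2)

lhs = jnp.ones((4, 3, 5, 5))
rhs = jnp.ones((2, 3, 2, 2))
d_lhs, d_rhs = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
assert d_lhs.shape == lhs.shape and d_rhs.shape == rhs.shape
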
def _conv_general_dilated_translation_rule(
c, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, precision, **unused_kwargs):
dimension_numbers, feature_group_count, batch_group_count, precision,
**unused_kwargs):
assert type(dimension_numbers) is ConvDimensionNumbers
dimension_numbers = _conv_general_proto(dimension_numbers)
return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count,
feature_group_count, batch_group_count,
precision_config=_precision_config(precision))
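
Per the XLA HLO semantics the translation rule relies on, a batch-grouped convolution amounts to slicing the lhs batch and the rhs output features into g aligned groups, convolving pairwise, and concatenating along the output feature axis. The sketch below is an illustrative reconstruction of that equivalence, not code from the commit:

import jax.numpy as jnp
from jax import lax

def batch_grouped_reference(lhs, rhs, g):
  # lhs: NCHW with batch n * g; rhs: OIHW with output features o * g.
  n, o = lhs.shape[0] // g, rhs.shape[0] // g
  outs = [lax.conv_general_dilated(
              lhs[i * n:(i + 1) * n], rhs[i * o:(i + 1) * o],
              (1, 1), 'VALID', dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
          for i in range(g)]
  return jnp.concatenate(outs, axis=1)  # stack the per-group output features

print(batch_grouped_reference(jnp.ones((8, 2, 4, 4)), jnp.ones((6, 2, 2, 2)), 2).shape)
# (4, 6, 3, 3), the same output shape as conv_general_dilated(..., batch_group_count=2)
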
def _conv_general_dilated_batch_rule(
batched_args, batch_dims, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision, **unused_kwargs):
feature_group_count, batch_group_count, precision, **unused_kwargs):
assert batch_group_count == 1 or feature_group_count == 1
lhs, rhs = batched_args
lhs_bdim, rhs_bdim = batch_dims
lhs_spec, rhs_spec, out_spec = dimension_numbers
if lhs_bdim is not None and rhs_bdim is not None:
assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
if batch_group_count > 1:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
batch_group_count *= lhs.shape[lhs_bdim]
else:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
feature_group_count *= lhs.shape[lhs_bdim]
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
out = conv_general_dilated(
new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers,
feature_group_count=lhs.shape[lhs_bdim] * feature_group_count,
dimension_numbers, feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=precision)
out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
return out, out_spec[1]
elif lhs_bdim is not None:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision=precision)
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
return out, out_spec[0]
if batch_group_count == 1:
new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision=precision)
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
return out, out_spec[0]
else:
new_lhs = _reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]),
batch_group_count, lhs)
new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim),
lhs_spec[0] + 1,
new_lhs)
new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs)
out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision)
out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
return out, out_spec[0]
elif rhs_bdim is not None:
if feature_group_count == 1:
if feature_group_count == 1 and batch_group_count == 1:
new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision=precision)
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision)
out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
return out, out_spec[1]
else:
# feature_group needs to be outermost, so we need to factor it out of the
# groups need to be outermost, so we need to factor them out of the
# rhs output feature dim, then factor the batch dim into the remaining rhs
# output feature dim, then put feature_group back in. we do something
# similar on the output. an alternative which would require more FLOPs but
# output feature dim, then put groups back in. We do something
# similar on the output. An alternative which would require more FLOPs but
# fewer reshapes would be to broadcast lhs.
group_count = (feature_group_count if feature_group_count > 1
else batch_group_count)
new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
feature_group_count, rhs)
group_count, rhs)
new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
rhs_spec[0] + 1,
new_rhs)
new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, precision=precision)
out = _reshape_axis_out_of(out_spec[1], feature_group_count, out)
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision=precision)
out = _reshape_axis_out_of(out_spec[1], group_count, out)
out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out)
out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
return out, out_spec[1]
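
The batching rule above is what lets vmap map over grouped convolutions without falling back to a loop; an invented example:

import jax
import jax.numpy as jnp
from jax import lax

def conv(lhs, rhs):
  return lax.conv_general_dilated(
      lhs, rhs, (1, 1), 'SAME',
      dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
      feature_group_count=2)

lhs = jnp.ones((5, 2, 4, 8, 8))  # a stack of 5 independent NCHW inputs
rhs = jnp.ones((5, 6, 2, 3, 3))  # a stack of 5 OIHW kernels, 2 feature groups
out = jax.vmap(conv)(lhs, rhs)
print(out.shape)  # (5, 2, 6, 8, 8)
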
@@ -4612,7 +4714,7 @@ def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
raise TypeError(msg.format(name, expected_length, len(window_strides)))
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
"""Compute the shape tuple of a conv given input shapes in canonical order."""
if isinstance(pads, str):
pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
@@ -4625,8 +4727,9 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
out_space = onp.floor_divide(
onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
out_space = onp.maximum(0, out_space)
out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
return tuple(out_shape)
assert lhs_shape[0] % batch_group_count == 0
out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
return tuple(out_shape + tuple(out_space))
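
Tracing the updated formula by hand (zero padding assumed, shapes invented for this note):

import numpy as onp

lhs_shape, rhs_shape, strides = (8, 3, 9, 9), (6, 3, 3, 3), (1, 1)
batch_group_count = 2
out_space = onp.floor_divide(
    onp.subtract(lhs_shape[2:], rhs_shape[2:]), strides) + 1  # (7, 7)
out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0]) + tuple(out_space)
print(out_shape)  # (4, 6, 7, 7)
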
def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
@@ -4639,7 +4742,7 @@ def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
dimension_numbers):
lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
lhs_trans = onp.take(lhs_shape, lhs_perm)
rhs_trans = onp.take(rhs_shape, rhs_perm)

jaxlib/version.py

@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.1.43"
__version__ = "0.1.44"

tests/lax_test.py

@@ -469,9 +469,13 @@ class LaxTest(jtu.JaxTestCase):
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dilation": lhs_dilation,
"rhs_dilation": rhs_dilation, "dimension_numbers": dim_nums,
"feature_group_count": feature_group_count,
"batch_group_count": batch_group_count,
"perms": perms, "rng_factory": rng_factory}
for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
for lhs_shape, rhs_shape in [
((b, i, 9, w), (j, i, 4, 5))
((b * batch_group_count, i * feature_group_count, 9, w),
(j * feature_group_count * batch_group_count, i, 4, 5))
for w in [0, 10]
for b, i, j in itertools.product([2, 3], repeat=3)]
for dtype in float_dtypes for strides in [(1, 1), (2, 1)]
@@ -486,6 +490,7 @@ class LaxTest(jtu.JaxTestCase):
]))
def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dilation, rhs_dilation,
feature_group_count, batch_group_count,
dimension_numbers, perms, rng_factory):
rng = rng_factory()
lhs_perm, rhs_perm = perms # permute to compatible shapes
@@ -497,7 +502,8 @@ class LaxTest(jtu.JaxTestCase):
def fun(lhs, rhs):
return lax.conv_general_dilated(
lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers)
dimension_numbers, feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
self._CompileAndCheck(fun, args_maker, check_dtypes=True)
@@ -2004,24 +2010,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}_dims={}_feature_group_count={}"
"rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
feature_group_count),
feature_group_count, batch_group_count),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
"perms": perms, "feature_group_count": feature_group_count}
"perms": perms, "feature_group_count": feature_group_count,
"batch_group_count": batch_group_count}
# TODO(phawkins): make batch_group_count tests unconditional after
# minimum jaxlib version is 0.1.44 or greater.
for batch_group_count, feature_group_count in (
[(1, 1), (2, 1), (1, 2)] if jax.lib.version > (0, 1, 43)
else [(1, 1), (1, 2)])
for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
([(b, i, 6, 7), (b, i, 0, 4)], # lhs_shape
(j, i, 1, 2), # rhs_shape
([(b * batch_group_count, i * feature_group_count, 6, 7),
(b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape
(j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape
[(1, 1), (1, 2), (2, 1)], # strides
[(1, 1), (2, 1)], # lhs_dils
[(1, 1), (2, 2)]) # rhs_dils
for b, i, j in itertools.product([1, 2], repeat=3)]
for lhs_shape in lhs_shapes
for feature_group_count in [1, 2]
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
@@ -2036,7 +2048,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
))
def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
padding, lhs_dil, rhs_dil, dimension_numbers,
perms, feature_group_count, rng_factory):
perms, feature_group_count, batch_group_count,
rng_factory):
rng = rng_factory()
tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 1e-4}
@@ -2044,9 +2057,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
lhs_perm, rhs_perm = perms
lhs_shape = list(onp.take(lhs_shape, lhs_perm))
rhs_shape = list(onp.take(rhs_shape, rhs_perm))
dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count
lhs = rng(lhs_shape, dtype)
rhs = rng(rhs_shape, dtype)
@@ -2054,6 +2064,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=lax.Precision.HIGHEST)
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
atol=tol, rtol=tol)
@@ -2684,25 +2695,32 @@ class LaxVmapTest(jtu.JaxTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}_dims={}_feature_group_count={}_lhs_bdim={}_rhs_bdim={}"
"rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
"_lhs_bdim={}_rhs_bdim={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
feature_group_count, lhs_bdim, rhs_bdim),
feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
"perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
"feature_group_count": feature_group_count}
"feature_group_count": feature_group_count,
"batch_group_count": batch_group_count,
}
# TODO(phawkins): make batch_group_count tests unconditional after
# minimum jaxlib version is 0.1.44 or greater.
for batch_group_count, feature_group_count in (
[(1, 1), (2, 1), (1, 2)] if jax.lib.version > (0, 1, 43)
else [(1, 1), (1, 2)])
for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
((b, i, 6, 7), # lhs_shape
(j, i, 1, 2), # rhs_shape
((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape
(j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape
[(1, 1), (1, 2), (2, 1)], # strides
[((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads
[(1, 1), (2, 1)], # lhs_dils
[(1, 1), (2, 2)]) # rhs_dils
for b, i, j in itertools.product([1, 2], repeat=3)]
for feature_group_count in [1, 2]
for strides in all_strides
for rhs_dil in rhs_dils
for lhs_dil in lhs_dils
@@ -2719,11 +2737,10 @@ class LaxVmapTest(jtu.JaxTestCase):
if (lhs_bdim, rhs_bdim) != (None, None)
for rng_factory in [jtu.rand_default]
))
# TODO(mattjj): some cases fail on TPU just due to numerical tolerances
@jtu.skip_on_devices("tpu")
def testConvGeneralDilatedBatching(
self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
dimension_numbers, perms, feature_group_count, lhs_bdim, rhs_bdim, rng_factory):
dimension_numbers, perms, feature_group_count, batch_group_count,
lhs_bdim, rhs_bdim, rng_factory):
rng = rng_factory()
tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3
@@ -2731,14 +2748,12 @@ class LaxVmapTest(jtu.JaxTestCase):
lhs_perm, rhs_perm = perms
lhs_shape = list(onp.take(lhs_shape, lhs_perm))
rhs_shape = list(onp.take(rhs_shape, rhs_perm))
dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count
conv = partial(lax.conv_general_dilated, window_strides=strides,
padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=lax.Precision.HIGHEST)
self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
(dtype, dtype), rng, rtol=tol, atol=tol)