Add batch_group_count to conv_general_dilated. (#2635)
* Add batch_group_count to conv_general_dilated.
* Use batch_group_count for RHS grouped convolution transpose rule.
* Implement lhs/rhs transpose and batching rules for batch_group_count convolution.
This commit is contained in:
parent 1694a56fa3
commit 1bb67637ca
jax/lax/lax.py (179 changed lines)
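For orientation before the diff, a minimal usage sketch of the new parameter (not part of this commit; the shapes, values, and printed result are illustrative only and assume a jax/jaxlib build new enough to support batch_group_count):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((6, 4, 8, 8))   # NCHW: batch 6, 4 input features
rhs = jnp.ones((9, 4, 3, 3))   # OIHW: 9 output features, 4 input features

# batch_group_count=3 splits the batch into 3 groups, each paired with its own
# third of the output features (see the XLA HLO docs for the exact semantics).
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding='SAME',
    batch_group_count=3)
print(out.shape)  # (2, 9, 8, 8): output batch is 6 // 3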
@@ -34,6 +34,7 @@ from .. import api
 from .. import linear_util as lu
 from .. import dtypes
 from .. import lazy
+from .. import lib
 from ..config import flags
 from ..core import _canonicalize_dimension, Primitive
 from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
@@ -462,7 +463,8 @@ def conv_general_dilated(
     lhs_dilation: Optional[Sequence[int]] = None,
     rhs_dilation: Optional[Sequence[int]] = None,
     dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
-    feature_group_count: int = 1, precision: Optional[PrecisionType] = None) -> Array:
+    feature_group_count: int = 1, batch_group_count: int = 1,
+    precision: Optional[PrecisionType] = None) -> Array:
   """General n-dimensional convolution operator, with optional dilation.
 
   Wraps XLA's `Conv
@@ -487,6 +489,7 @@ def conv_general_dilated(
       a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
       of length `n+2`.
     feature_group_count: integer, default 1. See XLA HLO docs.
+    batch_group_count: integer, default 1. See XLA HLO docs.
     precision: Optional. Either `None`, which means the default precision for
       the backend, or a `Precision` enum value.
 
@@ -543,6 +546,7 @@ def conv_general_dilated(
       lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
       dimension_numbers=dnums,
       feature_group_count=feature_group_count,
+      batch_group_count=batch_group_count,
       lhs_shape=lhs.shape, rhs_shape=rhs.shape,
       precision=_canonicalize_precision(precision))
 
@@ -2145,7 +2149,8 @@ masking.defvectorized(bitcast_convert_type_p)
 
 def _conv_general_dilated_shape_rule(
     lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
-    dimension_numbers, feature_group_count, **unused_kwargs):
+    dimension_numbers, feature_group_count, batch_group_count,
+    **unused_kwargs):
   assert type(dimension_numbers) is ConvDimensionNumbers
   if not feature_group_count > 0:
     msg = ("conv_general_dilated feature_group_count "
@@ -2168,10 +2173,32 @@ def _conv_general_dilated_shape_rule(
            "multiple of feature_group_count, but {} is not a multiple of {}.")
     raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                                 feature_group_count))
 
+  if not batch_group_count > 0:
+    msg = ("conv_general_dilated batch_group_count "
+           "must be a positive integer, got {}.")
+    raise ValueError(msg.format(batch_group_count))
+  lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
+  if lhs_batch_count % batch_group_count != 0:
+    msg = ("conv_general_dilated batch_group_count must divide lhs batch "
+           "dimension size, but {} does not divide {}.")
+    raise ValueError(msg.format(batch_group_count, lhs_batch_count))
+  if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
+    msg = ("conv_general_dilated rhs output feature dimension size must be a "
+           "multiple of batch_group_count, but {} is not a multiple of {}.")
+    raise ValueError(msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
+                                batch_group_count))
+
+  if batch_group_count > 1 and feature_group_count > 1:
+    msg = ("At most one of batch_group_count and feature_group_count may be > "
+           "1, got batch_group_count={} and feature_group_count={}")
+    raise ValueError(msg.format(batch_group_count, feature_group_count))
+
   lhs_perm, rhs_perm, out_perm = dimension_numbers
   lhs_trans = _dilate_shape(onp.take(lhs.shape, lhs_perm), lhs_dilation)
   rhs_trans = _dilate_shape(onp.take(rhs.shape, rhs_perm), rhs_dilation)
-  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding)
+  out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
+                               batch_group_count)
   return tuple(onp.take(out_trans, onp.argsort(out_perm)))
 
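A small sketch, not from the commit, of the constraints the new checks enforce; the shapes and the expected failure below are the editor's illustration:

import jax.numpy as jnp
from jax import lax

lhs = jnp.zeros((6, 4, 10, 10))   # batch 6, 4 input features (NCHW)
rhs = jnp.zeros((8, 4, 3, 3))     # 8 output features (OIHW)

# 6 is not divisible by 4, so the lhs-batch divisibility check above rejects
# this call before anything is lowered to XLA.
try:
    lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID', batch_group_count=4)
except ValueError as err:
    print(err)

# 6 % 2 == 0 and 8 % 2 == 0, so the checks pass; the output batch is 6 // 2.
out = lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID', batch_group_count=2)
print(out.shape)  # (3, 8, 8, 8)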
@@ -2183,11 +2210,39 @@ def _conv_general_dilated_dtype_rule(
 _conv_spec_transpose = lambda spec: (spec[1], spec[0]) + spec[2:]
 _conv_sdims = lambda spec: spec[2:]
 
+# Understanding the convolution transpose rules:
+# Ignoring the spatial dimensions, let m = batch, j = input feature,
+# k = output feature.
+#
+# Convolution computes the following contraction:
+# Forward: [m, j] [j, k] -> [m, k]
+#
+# The transposes are similar to the rules for transposing a matmul:
+# LHS transpose: [m, k] [k, j] -> [m, j]
+# RHS transpose: [j, m] [m, k] -> [j, k]
+#
+# With feature grouping, we have the following signatures:
+# Forward: [m, gj] [j, gk] -> [m, gk]
+# LHS transpose: [m, gk] [k, gj] -> [m, gj]
+# --> implemented as feature grouping after transposing the group from the
+#     kernel input features to the kernel output features.
+# RHS transpose: [gj, m] [m, gk] -> [j, gk]
+# --> which is batch grouping.
+#
+# With batch grouping, we have the following signatures:
+# Forward: [gm, j] [j, gk] -> [m, gk]
+# LHS transpose: [m, gk] [gk, j] -> [gm, j]
+# --> implemented as feature grouping with transposing the group on the kernel
+#     and the output.
+# RHS transpose: [j, gm] [m, gk] -> [j, gk]
+# --> which is feature grouping.
+
 def _conv_general_dilated_transpose_lhs(
     g, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
-    dimension_numbers, feature_group_count,
+    dimension_numbers, feature_group_count, batch_group_count,
     lhs_shape, rhs_shape, precision):
   assert type(dimension_numbers) is ConvDimensionNumbers
+  assert batch_group_count == 1 or feature_group_count == 1
   lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
   lhs_spec, rhs_spec, out_spec = dimension_numbers
   t_rhs_spec = _conv_spec_transpose(rhs_spec)
@@ -2196,31 +2251,52 @@ def _conv_general_dilated_transpose_lhs(
     # group axis into the transposed rhs's output feature dim
     rhs = _reshape_axis_out_of(rhs_spec[0], feature_group_count, rhs)
     rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
+  elif batch_group_count > 1:
+    rhs = _reshape_axis_out_of(rhs_spec[0], batch_group_count, rhs)
+    rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[1], rhs)
+    feature_group_count = batch_group_count
   trans_dimension_numbers = ConvDimensionNumbers(out_spec, t_rhs_spec, lhs_spec)
   padding = _conv_general_vjp_lhs_padding(
       onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
       window_strides, onp.take(g.shape, out_sdims), padding, lhs_dilation,
       rhs_dilation)
   revd_weights = rev(rhs, rhs_sdims)
-  return conv_general_dilated(
+  out = conv_general_dilated(
       g, revd_weights, window_strides=lhs_dilation, padding=padding,
       lhs_dilation=window_strides, rhs_dilation=rhs_dilation,
       dimension_numbers=trans_dimension_numbers,
-      feature_group_count=feature_group_count, precision=precision)
+      feature_group_count=feature_group_count,
+      batch_group_count=1, precision=precision)
+  if batch_group_count > 1:
+    out = _reshape_axis_out_of(lhs_spec[1], batch_group_count, out)
+    out = _reshape_axis_into(lhs_spec[1], lhs_spec[0], out)
+  return out
 
+# TODO(phawkins): remove when the minimum jaxlib version is incremented past
+# 0.1.43.
+_jaxlib_has_working_batch_group_count = lib.version > (0, 1, 43)
+
 def _conv_general_dilated_transpose_rhs(
     g, lhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
-    dimension_numbers, feature_group_count,
-    lhs_shape, rhs_shape, precision):
+    dimension_numbers: ConvDimensionNumbers, feature_group_count: int,
+    batch_group_count: int, lhs_shape, rhs_shape, precision):
   assert type(dimension_numbers) is ConvDimensionNumbers
+  if onp.size(g) == 0:
+    # Avoids forming degenerate convolutions where the RHS has spatial size 0.
+    return ad_util.zero
   lhs_sdims, rhs_sdims, out_sdims = map(_conv_sdims, dimension_numbers)
   lhs_trans, rhs_trans, out_trans = map(_conv_spec_transpose, dimension_numbers)
-  if feature_group_count > 1:
-    lhs = _reshape_axis_out_of(lhs_trans[0], feature_group_count, lhs)
-    lhs = _reshape_axis_into(lhs_trans[0], lhs_trans[1], lhs)
+  assert batch_group_count == 1 or feature_group_count == 1
+  if batch_group_count > 1:
+    feature_group_count = batch_group_count
+    batch_group_count = 1
+  elif feature_group_count > 1:
+    if _jaxlib_has_working_batch_group_count:
+      batch_group_count = feature_group_count
+      feature_group_count = 1
+    else:
+      lhs = _reshape_axis_out_of(lhs_trans[0], feature_group_count, lhs)
+      lhs = _reshape_axis_into(lhs_trans[0], lhs_trans[1], lhs)
   trans_dimension_numbers = ConvDimensionNumbers(lhs_trans, out_trans, rhs_trans)
   padding = _conv_general_vjp_rhs_padding(
       onp.take(lhs_shape, lhs_sdims), onp.take(rhs_shape, rhs_sdims),
@@ -2230,70 +2306,96 @@ def _conv_general_dilated_transpose_rhs(
       lhs, g, window_strides=rhs_dilation, padding=padding,
       lhs_dilation=lhs_dilation, rhs_dilation=window_strides,
       dimension_numbers=trans_dimension_numbers,
-      feature_group_count=feature_group_count, precision=precision)
+      feature_group_count=feature_group_count,
+      batch_group_count=batch_group_count, precision=precision)
 
 def _conv_general_dilated_translation_rule(
     c, lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation,
-    dimension_numbers, feature_group_count, precision, **unused_kwargs):
+    dimension_numbers, feature_group_count, batch_group_count, precision,
+    **unused_kwargs):
   assert type(dimension_numbers) is ConvDimensionNumbers
   dimension_numbers = _conv_general_proto(dimension_numbers)
   return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
                               rhs_dilation, dimension_numbers,
-                              feature_group_count,
+                              feature_group_count, batch_group_count,
                               precision_config=_precision_config(precision))
 
 def _conv_general_dilated_batch_rule(
     batched_args, batch_dims, *, window_strides, padding,
     lhs_dilation, rhs_dilation, dimension_numbers,
-    feature_group_count, precision, **unused_kwargs):
+    feature_group_count, batch_group_count, precision, **unused_kwargs):
+  assert batch_group_count == 1 or feature_group_count == 1
   lhs, rhs = batched_args
   lhs_bdim, rhs_bdim = batch_dims
   lhs_spec, rhs_spec, out_spec = dimension_numbers
 
   if lhs_bdim is not None and rhs_bdim is not None:
     assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
-    new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
+    if batch_group_count > 1:
+      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
+      batch_group_count *= lhs.shape[lhs_bdim]
+    else:
+      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[1], lhs)
+      feature_group_count *= lhs.shape[lhs_bdim]
     new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
     out = conv_general_dilated(
       new_lhs, new_rhs, window_strides, padding, lhs_dilation, rhs_dilation,
-      dimension_numbers,
-      feature_group_count=lhs.shape[lhs_bdim] * feature_group_count,
+      dimension_numbers, feature_group_count=feature_group_count,
+      batch_group_count=batch_group_count,
      precision=precision)
     out = _reshape_axis_out_of(out_spec[1], lhs.shape[lhs_bdim], out)
     return out, out_spec[1]
 
   elif lhs_bdim is not None:
-    new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
-    out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
-                               lhs_dilation, rhs_dilation, dimension_numbers,
-                               feature_group_count, precision=precision)
-    out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
-    return out, out_spec[0]
+    if batch_group_count == 1:
+      new_lhs = _reshape_axis_into(lhs_bdim, lhs_spec[0], lhs)
+      out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
+                                 lhs_dilation, rhs_dilation, dimension_numbers,
+                                 feature_group_count, precision=precision)
+      out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
+      return out, out_spec[0]
+    else:
+      new_lhs = _reshape_axis_out_of(lhs_spec[0] + int(lhs_bdim <= lhs_spec[0]),
+                                     batch_group_count, lhs)
+      new_lhs = _reshape_axis_into(lhs_bdim + int(lhs_spec[0] < lhs_bdim),
+                                   lhs_spec[0] + 1,
+                                   new_lhs)
+      new_lhs = _reshape_axis_into(lhs_spec[0], lhs_spec[0], new_lhs)
+      out = conv_general_dilated(new_lhs, rhs, window_strides, padding,
+                                 lhs_dilation, rhs_dilation, dimension_numbers,
+                                 feature_group_count, batch_group_count,
+                                 precision=precision)
+      out = _reshape_axis_out_of(out_spec[0], lhs.shape[lhs_bdim], out)
+      return out, out_spec[0]
 
   elif rhs_bdim is not None:
-    if feature_group_count == 1:
+    if feature_group_count == 1 and batch_group_count == 1:
       new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
       out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
-                                 lhs_dilation, rhs_dilation, dimension_numbers,
-                                 feature_group_count, precision=precision)
+                                 lhs_dilation, rhs_dilation, dimension_numbers,
+                                 feature_group_count, batch_group_count,
+                                 precision=precision)
       out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
       return out, out_spec[1]
     else:
-      # feature_group needs to be outermost, so we need to factor it out of the
+      # groups need to be outermost, so we need to factor them out of the
       # rhs output feature dim, then factor the batch dim into the remaining rhs
-      # output feature dim, then put feature_group back in. we do something
-      # similar on the output. an alternative which would require more FLOPs but
+      # output feature dim, then put groups back in. We do something
+      # similar on the output. An alternative which would require more FLOPs but
       # fewer reshapes would be to broadcast lhs.
+      group_count = (feature_group_count if feature_group_count > 1
+                     else batch_group_count)
       new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
-                                     feature_group_count, rhs)
+                                     group_count, rhs)
       new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
                                    rhs_spec[0] + 1,
                                    new_rhs)
       new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
       out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
-                                 lhs_dilation, rhs_dilation, dimension_numbers,
-                                 feature_group_count, precision=precision)
-      out = _reshape_axis_out_of(out_spec[1], feature_group_count, out)
+                                 lhs_dilation, rhs_dilation, dimension_numbers,
+                                 feature_group_count, batch_group_count,
+                                 precision=precision)
+      out = _reshape_axis_out_of(out_spec[1], group_count, out)
       out = _reshape_axis_out_of(out_spec[1] + 1, rhs.shape[rhs_bdim], out)
       out = _reshape_axis_into(out_spec[1], out_spec[1] + 1, out)
       return out, out_spec[1]
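With the transpose rules above in place, reverse-mode differentiation through a batch-grouped convolution works end to end. A minimal sketch (not part of the commit, shapes arbitrary, assuming a jaxlib with working batch_group_count):

import jax
import jax.numpy as jnp
from jax import lax

def loss(lhs, rhs):
    # NCHW / OIHW layouts; batch_group_count=2 maps batch 4 -> output batch 2.
    out = lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID',
                                   batch_group_count=2)
    return jnp.sum(out ** 2)

lhs = jnp.ones((4, 3, 8, 8))   # batch 4, 3 input features
rhs = jnp.ones((6, 3, 3, 3))   # 6 output features, divisible by 2

# The lhs/rhs transpose rules defined above supply these cotangents.
d_lhs, d_rhs = jax.grad(loss, argnums=(0, 1))(lhs, rhs)
print(d_lhs.shape, d_rhs.shape)  # (4, 3, 8, 8) (6, 3, 3, 3)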
@@ -4612,7 +4714,7 @@ def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
     raise TypeError(msg.format(name, expected_length, len(window_strides)))
 
 
-def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
+def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
   """Compute the shape tuple of a conv given input shapes in canonical order."""
   if isinstance(pads, str):
     pads = padtype_to_pads(lhs_shape[2:], rhs_shape[2:], strides, pads)
@@ -4625,8 +4727,9 @@ def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads):
   out_space = onp.floor_divide(
       onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
   out_space = onp.maximum(0, out_space)
-  out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
-  return tuple(out_shape)
+  assert lhs_shape[0] % batch_group_count == 0
+  out_shape = (lhs_shape[0] // batch_group_count, rhs_shape[0])
+  return tuple(out_shape + tuple(out_space))
 
 
 def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
@@ -4639,7 +4742,7 @@ def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
 
 
 def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
-                               dimension_numbers):
+                               dimension_numbers):
   lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
   lhs_trans = onp.take(lhs_shape, lhs_perm)
   rhs_trans = onp.take(rhs_shape, rhs_perm)
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.43"
+__version__ = "0.1.44"
@@ -469,9 +469,13 @@ class LaxTest(jtu.JaxTestCase):
        "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
        "strides": strides, "padding": padding, "lhs_dilation": lhs_dilation,
        "rhs_dilation": rhs_dilation, "dimension_numbers": dim_nums,
+       "feature_group_count": feature_group_count,
+       "batch_group_count": batch_group_count,
        "perms": perms, "rng_factory": rng_factory}
+      for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
       for lhs_shape, rhs_shape in [
-          ((b, i, 9, w), (j, i, 4, 5))
+          ((b * batch_group_count, i * feature_group_count, 9, w),
+           (j * feature_group_count * batch_group_count, i, 4, 5))
           for w in [0, 10]
           for b, i, j in itertools.product([2, 3], repeat=3)]
       for dtype in float_dtypes for strides in [(1, 1), (2, 1)]
@@ -486,6 +490,7 @@ class LaxTest(jtu.JaxTestCase):
       ]))
   def testConvGeneralDilated(self, lhs_shape, rhs_shape, dtype, strides,
                              padding, lhs_dilation, rhs_dilation,
+                             feature_group_count, batch_group_count,
                              dimension_numbers, perms, rng_factory):
     rng = rng_factory()
     lhs_perm, rhs_perm = perms  # permute to compatible shapes
@@ -497,7 +502,8 @@ class LaxTest(jtu.JaxTestCase):
     def fun(lhs, rhs):
       return lax.conv_general_dilated(
           lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
-          dimension_numbers)
+          dimension_numbers, feature_group_count=feature_group_count,
+          batch_group_count=batch_group_count)
 
     self._CompileAndCheck(fun, args_maker, check_dtypes=True)
 
@@ -2004,24 +2010,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
        "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
-       "rhs_dilation={}_dims={}_feature_group_count={}"
+       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
        .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
-               feature_group_count),
+               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
-      "perms": perms, "feature_group_count": feature_group_count}
+      "perms": perms, "feature_group_count": feature_group_count,
+      "batch_group_count": batch_group_count}
+      # TODO(phawkins): make batch_group_count tests unconditional after
+      # minimum jaxlib version is 0.1.44 or greater.
+      for batch_group_count, feature_group_count in (
+          [(1, 1), (2, 1), (1, 2)] if jax.lib.version > (0, 1, 43)
+          else [(1, 1), (1, 2)])
       for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
-          ([(b, i, 6, 7), (b, i, 0, 4)],  # lhs_shape
-           (j, i, 1, 2),  # rhs_shape
+          ([(b * batch_group_count, i * feature_group_count, 6, 7),
+            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
+           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
            [(1, 1), (1, 2), (2, 1)],  # strides
            [(1, 1), (2, 1)],  # lhs_dils
            [(1, 1), (2, 2)])  # rhs_dils
           for b, i, j in itertools.product([1, 2], repeat=3)]
       for lhs_shape in lhs_shapes
-      for feature_group_count in [1, 2]
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
@@ -2036,7 +2048,8 @@ class LaxAutodiffTest(jtu.JaxTestCase):
   ))
   def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                  padding, lhs_dil, rhs_dil, dimension_numbers,
-                                 perms, feature_group_count, rng_factory):
+                                 perms, feature_group_count, batch_group_count,
+                                 rng_factory):
     rng = rng_factory()
     tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 1e-4}
 
@@ -2044,9 +2057,6 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     lhs_perm, rhs_perm = perms
     lhs_shape = list(onp.take(lhs_shape, lhs_perm))
     rhs_shape = list(onp.take(rhs_shape, rhs_perm))
-    dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
-    lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
-    rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count
 
     lhs = rng(lhs_shape, dtype)
     rhs = rng(rhs_shape, dtype)
@@ -2054,6 +2064,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
                    padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                    dimension_numbers=dimension_numbers,
                    feature_group_count=feature_group_count,
+                   batch_group_count=batch_group_count,
                    precision=lax.Precision.HIGHEST)
     check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                          atol=tol, rtol=tol)
@@ -2684,25 +2695,32 @@ class LaxVmapTest(jtu.JaxTestCase):
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
        "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
-       "rhs_dilation={}_dims={}_feature_group_count={}_lhs_bdim={}_rhs_bdim={}"
+       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
+       "_lhs_bdim={}_rhs_bdim={}"
        .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
-               feature_group_count, lhs_bdim, rhs_bdim),
+               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
-      "feature_group_count": feature_group_count}
+      "feature_group_count": feature_group_count,
+      "batch_group_count": batch_group_count,
+      }
+      # TODO(phawkins): make batch_group_count tests unconditional after
+      # minimum jaxlib version is 0.1.44 or greater.
+      for batch_group_count, feature_group_count in (
+          [(1, 1), (2, 1), (1, 2)] if jax.lib.version > (0, 1, 43)
+          else [(1, 1), (1, 2)])
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
-          ((b, i, 6, 7),  # lhs_shape
-           (j, i, 1, 2),  # rhs_shape
+          ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
+           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
            [(1, 1), (1, 2), (2, 1)],  # strides
            [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
            [(1, 1), (2, 1)],  # lhs_dils
            [(1, 1), (2, 2)])  # rhs_dils
           for b, i, j in itertools.product([1, 2], repeat=3)]
-      for feature_group_count in [1, 2]
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
@@ -2719,11 +2737,10 @@ class LaxVmapTest(jtu.JaxTestCase):
       if (lhs_bdim, rhs_bdim) != (None, None)
       for rng_factory in [jtu.rand_default]
   ))
-  # TODO(mattjj): some cases fail on TPU just due to numerical tolerances
-  @jtu.skip_on_devices("tpu")
   def testConvGeneralDilatedBatching(
       self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
-      dimension_numbers, perms, feature_group_count, lhs_bdim, rhs_bdim, rng_factory):
+      dimension_numbers, perms, feature_group_count, batch_group_count,
+      lhs_bdim, rhs_bdim, rng_factory):
     rng = rng_factory()
     tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3
 
@@ -2731,14 +2748,12 @@ class LaxVmapTest(jtu.JaxTestCase):
     lhs_perm, rhs_perm = perms
     lhs_shape = list(onp.take(lhs_shape, lhs_perm))
     rhs_shape = list(onp.take(rhs_shape, rhs_perm))
-    dim_spec = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
-    lhs_shape[dim_spec.lhs_spec[1]] *= feature_group_count
-    rhs_shape[dim_spec.rhs_spec[0]] *= feature_group_count
 
     conv = partial(lax.conv_general_dilated, window_strides=strides,
                    padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                    dimension_numbers=dimension_numbers,
                    feature_group_count=feature_group_count,
+                   batch_group_count=batch_group_count,
                    precision=lax.Precision.HIGHEST)
     self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                         (dtype, dtype), rng, rtol=tol, atol=tol)
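The batching rule added in jax/lax/lax.py, and exercised by the vmap test above, is what lets vmap map over convolution operands when batch_group_count (or feature_group_count) is greater than one. A rough sketch, not from the commit, with made-up shapes:

import jax
import jax.numpy as jnp
from jax import lax

def conv(lhs, rhs):
    # Per-example problem: lhs (4, 3, 8, 8), rhs (6, 3, 3, 3), two batch groups.
    return lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID',
                                    batch_group_count=2)

lhs = jnp.ones((5, 4, 3, 8, 8))   # leading axis of size 5 to map over
rhs = jnp.ones((5, 6, 3, 3, 3))

# vmap over both operands; each slice runs a batch-grouped convolution.
out = jax.vmap(conv)(lhs, rhs)
print(out.shape)  # (5, 2, 6, 6, 6)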
|
Loading…
x
Reference in New Issue
Block a user