Fix batching rule for convolution for batch dimensions of size 0.

Peter Hawkins 2022-06-01 13:07:23 -04:00
parent e225317ff8
commit ece9b999fb
2 changed files with 42 additions and 12 deletions

View File

@ -534,6 +534,29 @@ def _conv_general_dilated_batch_rule(
  lhs_bdim, rhs_bdim = batch_dims
  lhs_spec, rhs_spec, out_spec = dimension_numbers

  # Some of the cases that reshape into batch or feature dimensions do not work
  # with size 0 batch dimensions. The best fix would be to extend HLO to support
  # multiple batch dimensions.
  if ((lhs_bdim is not None and lhs.shape[lhs_bdim] == 0) or
      (rhs_bdim is not None and rhs.shape[rhs_bdim] == 0)):
    lhs_shape_unbatched, rhs_shape_unbatched = list(lhs.shape), list(rhs.shape)
    if lhs_bdim is not None:
      lhs_shape_unbatched.pop(lhs_bdim)
    if rhs_bdim is not None:
      rhs_shape_unbatched.pop(rhs_bdim)
    shape = _conv_general_dilated_shape_rule(
      core.ShapedArray(lhs_shape_unbatched, lhs.dtype),
      core.ShapedArray(rhs_shape_unbatched, rhs.dtype),
      window_strides=window_strides, padding=padding, lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count)
    return lax.full(
      (0,) + shape, 0,
      dtype=lhs.dtype if preferred_element_type is None
      else preferred_element_type), 0

  if lhs_bdim is not None and rhs_bdim is not None:
    assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
    if batch_group_count > 1:
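
For context: this early return is what the new zero-size test cases exercise. Before this change, vmapping a convolution over a mapped axis of size 0 fell through to the reshape-based cases further down in this rule and failed; now it returns a correctly shaped empty result computed from the shape rule. A minimal sketch of the case being fixed, with made-up shapes and dimension numbers that are not taken from this commit's tests:

import jax
import jax.numpy as jnp
from jax import lax

def conv(lhs, rhs):
  # Plain NCHW/OIHW convolution, no dilation and no grouping.
  return lax.conv_general_dilated(lhs, rhs, window_strides=(1, 1),
                                  padding='VALID',
                                  dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

lhs = jnp.zeros((0, 2, 1, 6, 7))   # mapped axis of size 0 in front of NCHW
rhs = jnp.zeros((1, 1, 1, 2))      # OIHW
out = jax.vmap(conv, in_axes=(0, None))(lhs, rhs)
print(out.shape)                   # (0, 2, 1, 6, 6): empty but correctly shaped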
@ -596,8 +619,7 @@ def _conv_general_dilated_batch_rule(
new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
group_count, rhs)
new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
rhs_spec[0] + 1,
new_rhs)
rhs_spec[0] + 1, new_rhs)
new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
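
The reshapes in this hunk implement the usual trick for a kernel-only batch dimension: fold the mapped axis of rhs into the output-feature axis, run a single convolution, and split that axis back out of the result. A rough sketch of the idea, ignoring the feature/batch grouping the real rule also has to handle (shapes are illustrative only):

import jax
import jax.numpy as jnp
from jax import lax

dn = ('NCHW', 'OIHW', 'NCHW')
lhs = jnp.ones((2, 3, 8, 8))        # NCHW
rhs = jnp.ones((4, 5, 3, 3, 3))     # a batch of 4 OIHW kernels

# Reference: vmap over the kernel batch, with the batch axis placed after N.
ref = jax.vmap(
    lambda r: lax.conv_general_dilated(lhs, r, (1, 1), 'VALID',
                                       dimension_numbers=dn),
    in_axes=0, out_axes=1)(rhs)     # shape (2, 4, 5, 6, 6)

# One conv: fold the kernel batch into the output-feature axis, then split it.
out = lax.conv_general_dilated(lhs, rhs.reshape(4 * 5, 3, 3, 3), (1, 1),
                               'VALID', dimension_numbers=dn)
out = out.reshape(2, 4, 5, 6, 6)
print(jnp.allclose(ref, out))       # True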
@ -737,6 +759,10 @@ mlir.register_lowering(
def _reshape_axis_into(src, dst, x):
  # NB: `dst` is the number of the dimension that we should reshape into
  # *after* `src` is removed from `x`'s list of dimensions. For example, if
  # `src` is an added batch dimension, `dst` might name a target dimension in
  # the unbatched list of dimensions.
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  new_shape = list(np.delete(x.shape, src))
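
A standalone sketch of the convention the new comment documents: `dst` indexes into the shape of `x` with `src` already removed. This re-derives the helper purely for illustration; the function name below is local to the example.

import numpy as np
import jax.numpy as jnp
from jax import lax

def reshape_axis_into(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)                      # move src next to the target axis
  new_shape = list(np.delete(x.shape, src))  # shape with src removed
  new_shape[dst] *= x.shape[src]             # fold src into that axis
  return lax.reshape(x, new_shape, perm)

x = jnp.zeros((2, 3, 5))
# src=0 is an added batch axis; dst=1 names axis 1 of the *unbatched* shape
# (3, 5), i.e. the axis of size 5, so the result has shape (3, 10).
print(reshape_axis_into(0, 1, x).shape)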

View File

@ -114,17 +114,18 @@ class LaxVmapTest(jtu.JaxTestCase):
"testcase_name":
"_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
"rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
"_lhs_bdim={}_rhs_bdim={}"
"_lhs_bdim={}_rhs_bdim={}_bdim_size={}"
.format(jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
feature_group_count, batch_group_count, lhs_bdim, rhs_bdim,
bdim_size),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"strides": strides, "padding": padding, "lhs_dil": lhs_dil,
"rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
"perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
"feature_group_count": feature_group_count,
"batch_group_count": batch_group_count,
"batch_group_count": batch_group_count, "bdim_size": bdim_size,
} for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)])
for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([
((b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape
@ -142,7 +143,10 @@ class LaxVmapTest(jtu.JaxTestCase):
for dim_nums, perms in s([
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
(("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))])
(("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3])),
(("HWCN", "HWIO", "HWCN"), ([2, 3, 1, 0], [2, 3, 1, 0])),
])
for bdim_size in s([0, 5])
for lhs_bdim in s(itertools.chain([cast(Optional[int], None)],
range(len(lhs_shape) + 1)))
for rhs_bdim in s(itertools.chain([cast(Optional[int], None)],
@ -152,7 +156,7 @@ class LaxVmapTest(jtu.JaxTestCase):
def testConvGeneralDilatedBatching(
self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
dimension_numbers, perms, feature_group_count, batch_group_count,
lhs_bdim, rhs_bdim):
lhs_bdim, rhs_bdim, bdim_size):
rng = jtu.rand_default(self.rng())
tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3
@ -167,8 +171,9 @@ class LaxVmapTest(jtu.JaxTestCase):
feature_group_count=feature_group_count,
batch_group_count=batch_group_count,
precision=lax.Precision.HIGHEST)
self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
(dtype, dtype), rng, rtol=tol, atol=tol)
self._CheckBatching(conv, bdim_size, (lhs_bdim, rhs_bdim),
(lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol,
atol=tol)
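
Roughly, the property these parameters exercise looks like the sketch below; this is not the actual _CheckBatching helper, just the shape of the check. For a nonzero batch the vmapped op must agree with applying it slice by slice, and for the new bdim_size=0 cases it must still return an output with the right (empty) shape rather than raise.

import jax
import jax.numpy as jnp

def check_batching(f, xs):
  out = jax.vmap(f)(xs)
  if xs.shape[0] > 0:
    ref = jnp.stack([f(x) for x in xs])
    assert jnp.allclose(out, ref, rtol=1e-5, atol=1e-5)
  else:
    # Empty batch: only the output shape can be checked against the unbatched op.
    elt = jax.ShapeDtypeStruct(xs.shape[1:], xs.dtype)
    assert out.shape == (0,) + jax.eval_shape(f, elt).shape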
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
@ -467,12 +472,11 @@ class LaxVmapTest(jtu.JaxTestCase):
for init_val, op, dtypes in [
(0, lax.add, default_dtypes),
(1, lax.mul, default_dtypes),
(0, lax.max, all_dtypes), # non-monoidal
# non-monoidal for everything except unsigned integers
(0, lax.max, all_dtypes),
(-np.inf, lax.max, float_dtypes),
(dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
(dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
(dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
(dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
(np.inf, lax.min, float_dtypes),
(dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
(dtypes.iinfo(np.int64).max, lax.min, [np.int64]),