improve conv rhs batching, add systematic test

Matthew Johnson 2019-06-15 12:01:20 -07:00
parent 1dc4a4d05e
commit 1262ca9b30
4 changed files with 99 additions and 13 deletions


@@ -1843,15 +1843,13 @@ def _conv_general_dilated_batch_rule(
    return outputs, 0
  elif rhs_bdim is not None:
    # TODO(#212): use a map construct instead of unrolling.
    rhs = batching.move_dim_to_front(rhs, rhs_bdim)
    outputs = [
        conv_general_dilated(lhs, x, window_strides, padding,
                             lhs_dilation, rhs_dilation, dimension_numbers)
        for x in rhs]
    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
    outputs = concatenate(outputs, 0)
    return outputs, 0
    # move and reshape the bdim into the rhs output channels dimension
    _, rhs_spec, out_spec = dimension_numbers
    new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
    out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers)
    out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
    return out, out_spec[1]
conv_general_dilated_p = standard_primitive(
    _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
    'conv_general_dilated', _conv_general_dilated_translation_rule)
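Why the new rhs-batching rule is equivalent to the unrolled loop it replaces: each output channel of a convolution depends only on its own filter, so folding the rhs batch dimension into the output-channel dimension, convolving once, and splitting the output channels back out of the result reproduces the stacked per-kernel convolutions. A minimal standalone sketch, not part of the diff; the shapes, layouts, and padding below are arbitrary choices for illustration:

import numpy as onp
import jax.numpy as np
from jax import lax

lhs = onp.random.randn(2, 3, 8, 8).astype(onp.float32)     # NCHW inputs
rhs = onp.random.randn(5, 4, 3, 2, 2).astype(onp.float32)  # batch of 5 OIHW kernels

dnums = ("NCHW", "OIHW", "NCHW")
conv = lambda l, r: lax.conv_general_dilated(l, r, (1, 1), "VALID",
                                             (1, 1), (1, 1), dnums)

# reference: unroll over the kernel batch and stack, as the old rule did
ref = np.stack([conv(lhs, rhs[b]) for b in range(5)])    # shape (5, 2, 4, 7, 7)

# new rule: merge (batch, O) into batch*O output channels, convolve once, then
# split the output-channel axis of the result back into (batch, O)
out = conv(lhs, rhs.reshape(5 * 4, 3, 2, 2))              # shape (2, 20, 7, 7)
out = np.transpose(out.reshape(2, 5, 4, 7, 7), (1, 0, 2, 3, 4))

assert onp.allclose(ref, out, atol=1e-5)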
@@ -1861,6 +1859,21 @@ ad.defbilinear(conv_general_dilated_p,
batching.primitive_batchers[
    conv_general_dilated_p] = _conv_general_dilated_batch_rule

def _reshape_axis_into(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  new_shape = list(onp.delete(x.shape, src))
  new_shape[dst] *= x.shape[src]
  return reshape(transpose(x, perm), new_shape)  # TODO(mattjj): manually fuse

def _reshape_axis_out_of(src, size1, x):
  shape = list(x.shape)
  size2, ragged = divmod(shape[src], size1)
  assert not ragged
  shape[src:src+1] = [size1, size2]
  return reshape(x, shape)

def _dot_shape_rule(lhs, rhs):
  if lhs.ndim == 0 or rhs.ndim == 0:
    msg = "Dot only supports rank 1 or above, got shapes {} and {}."


@@ -507,6 +507,7 @@ class JaxTestCase(parameterized.TestCase):
      for k in x.keys():
        self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
    elif is_sequence(x) and not hasattr(x, '__array__'):
      self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
      self.assertEqual(len(x), len(y))
      for x_elt, y_elt in zip(x, y):


@@ -485,7 +485,6 @@ class BatchingTest(jtu.JaxTestCase):
                                                  (5, 21, 5, 1)))
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testMaxPool(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)


@@ -1656,12 +1656,12 @@ class LaxAutodiffTest(jtu.JaxTestCase):
      for lhs_dil in lhs_dils
      for dtype in [onp.float32]
      for padding in all_pads
      for rng in [jtu.rand_default()]
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))
      ]))
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng in [jtu.rand_default()]
  ))
  @jtu.skip_on_devices("tpu")
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
@@ -2153,5 +2153,78 @@ class LaxAutodiffTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=True)


def slicer(x, bdim):
  if bdim is None:
    return lambda _: x
  else:
    return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
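slicer gives the new test a uniform way to pull the i-th example out of an operand whether or not it carries a batch dimension; roughly, as in this illustrative sketch (not part of the diff):

import numpy as onp

xs = onp.ones((10, 3, 4))
ws = onp.ones((3, 4))

get_x = slicer(xs, 0)     # get_x(i) slices along axis 0, giving shape (3, 4)
get_w = slicer(ws, None)  # no batch dim: get_w(i) is always ws itself

example_args = (get_x(7), get_w(7))  # a per-example pair for the reference conv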
class LaxVmapTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_lhs_bdim={}_rhs_bdim={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               lhs_bdim, rhs_bdim),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng": rng, "dimension_numbers": dim_nums,
       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim}
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
          ((b, i, 6, 7),  # lhs_shape
           (j, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in [onp.float32]
      for padding in all_pads
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for lhs_bdim in itertools.chain([None], range(len(lhs_shape) + 1))
      for rhs_bdim in itertools.chain([None], range(len(rhs_shape) + 1))
      if (lhs_bdim, rhs_bdim) != (None, None)
      for rng in [jtu.rand_default()]
  ))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, lhs_bdim, rhs_bdim, rng):
    tol = 1e-1 if onp.finfo(dtype).bits == 32 else 1e-3
    bdim_size = 10

    # permute shapes to match dimension_numbers
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))

    # add batch dimension
    if lhs_bdim is not None:
      lhs_shape.insert(lhs_bdim, bdim_size)
    if rhs_bdim is not None:
      rhs_shape.insert(rhs_bdim, bdim_size)

    # create arg values and sliced versions
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    lhs_slice = slicer(lhs, lhs_bdim)
    rhs_slice = slicer(rhs, rhs_bdim)

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers)
    ans = api.vmap(conv, (lhs_bdim, rhs_bdim))(lhs, rhs)
    expected = onp.stack([conv(lhs_slice(i), rhs_slice(i))
                          for i in range(bdim_size)])
    self.assertAllClose(ans, expected, check_dtypes=True)
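In isolation, the property this systematic test checks is that vmapping the convolution over a batch dimension matches stacking unbatched convolutions. A minimal sketch with only the rhs batched; the shapes and dimension numbers are arbitrary examples, not taken from the diff:

from functools import partial
import numpy as onp
from jax import vmap, lax

conv = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="VALID",
               lhs_dilation=(1, 1), rhs_dilation=(1, 1),
               dimension_numbers=("NCHW", "OIHW", "NCHW"))

lhs = onp.random.randn(2, 3, 8, 8).astype(onp.float32)
rhs = onp.random.randn(10, 4, 3, 2, 2).astype(onp.float32)  # rhs batched along axis 0

ans = vmap(conv, (None, 0))(lhs, rhs)                        # lhs broadcast, rhs mapped
expected = onp.stack([conv(lhs, rhs[i]) for i in range(10)])
assert onp.allclose(ans, expected, atol=1e-5)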
if __name__ == '__main__':
  absltest.main()