improve conv rhs batching, add systematic test
This commit is contained in:
parent 1dc4a4d05e
commit 1262ca9b30
@@ -1843,15 +1843,13 @@ def _conv_general_dilated_batch_rule(
     return outputs, 0
   elif rhs_bdim is not None:
-    #TODO(#212): use a map construct instead of unrolling.
-    rhs = batching.move_dim_to_front(rhs, rhs_bdim)
-    outputs = [
-        conv_general_dilated(lhs, x, window_strides, padding,
-                             lhs_dilation, rhs_dilation, dimension_numbers)
-        for x in rhs]
-    outputs = [reshape(out, (1,) + out.shape) for out in outputs]
-    outputs = concatenate(outputs, 0)
-    return outputs, 0
+    # move and reshape the bdim into the rhs output channels dimension
+    _, rhs_spec, out_spec = dimension_numbers
+    new_rhs = _reshape_axis_into(rhs_bdim, rhs_spec[0], rhs)
+    out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
+                               lhs_dilation, rhs_dilation, dimension_numbers)
+    out = _reshape_axis_out_of(out_spec[1], rhs.shape[rhs_bdim], out)
+    return out, out_spec[1]
 
 conv_general_dilated_p = standard_primitive(
     _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
     'conv_general_dilated', _conv_general_dilated_translation_rule)
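The new rule replaces the unrolled Python loop with a single convolution: it folds the kernel's batch dimension into the output-feature position rhs_spec[0], convolves once, and splits the batch back out of the output-feature dimension out_spec[1]. This works because each output channel of a convolution depends on exactly one kernel slice. A standalone sketch of that equivalence (arbitrary shapes, using the lax.conv NCHW/OIHW convenience wrapper; not part of the commit):

import numpy as onp
from jax import lax

lhs = onp.random.randn(2, 3, 8, 8).astype(onp.float32)     # NCHW input
rhs = onp.random.randn(5, 4, 3, 2, 2).astype(onp.float32)  # batch of 5 OIHW kernels

# unrolled reference: one convolution per kernel in the batch
ref = onp.stack([lax.conv(lhs, k, (1, 1), 'VALID') for k in rhs])

# fused: fold the batch dim into the O dim, convolve once, split it back out
out = lax.conv(lhs, rhs.reshape(5 * 4, 3, 2, 2), (1, 1), 'VALID')
out = onp.asarray(out).reshape(2, 5, 4, 7, 7)  # split channels into (batch, O)
assert onp.allclose(onp.moveaxis(out, 1, 0), ref, atol=1e-5)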
@@ -1861,6 +1859,21 @@ ad.defbilinear(conv_general_dilated_p,
 batching.primitive_batchers[
     conv_general_dilated_p] = _conv_general_dilated_batch_rule
 
+def _reshape_axis_into(src, dst, x):
+  perm = [i for i in range(x.ndim) if i != src]
+  perm.insert(dst, src)
+  new_shape = list(onp.delete(x.shape, src))
+  new_shape[dst] *= x.shape[src]
+  return reshape(transpose(x, perm), new_shape)  # TODO(mattjj): manually fuse
+
+def _reshape_axis_out_of(src, size1, x):
+  shape = list(x.shape)
+  size2, ragged = divmod(shape[src], size1)
+  assert not ragged
+  shape[src:src+1] = [size1, size2]
+  return reshape(x, shape)
+
+
 def _dot_shape_rule(lhs, rhs):
   if lhs.ndim == 0 or rhs.ndim == 0:
     msg = "Dot only supports rank 1 or above, got shapes {} and {}."
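The two helpers are inverses in the following sense: _reshape_axis_into(src, dst, x) moves axis src to position dst and merges it, as the major factor, with the axis previously at dst, while _reshape_axis_out_of(src, size1, x) splits an axis back into (size1, size2) with size1 major. A pure-NumPy re-implementation with a round-trip check (standalone names chosen for illustration; not part of the commit):

import numpy as onp

def reshape_axis_into(src, dst, x):
  # move axis `src` to position `dst`, then merge it (as the major factor)
  # with the axis that previously lived at `dst`
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  new_shape = list(onp.delete(x.shape, src))
  new_shape[dst] *= x.shape[src]
  return onp.transpose(x, perm).reshape(new_shape)

def reshape_axis_out_of(src, size1, x):
  # split axis `src` into (size1, size2); requires an even division
  shape = list(x.shape)
  size2, ragged = divmod(shape[src], size1)
  assert not ragged
  shape[src:src + 1] = [size1, size2]
  return x.reshape(shape)

x = onp.arange(2 * 3 * 5).reshape(2, 3, 5)
y = reshape_axis_into(0, 1, x)             # (2, 3, 5) -> (3, 10)
z = reshape_axis_out_of(1, 2, y)           # (3, 10)   -> (3, 2, 5)
assert (onp.moveaxis(z, 1, 0) == x).all()  # original array recovered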
@@ -507,6 +507,7 @@ class JaxTestCase(parameterized.TestCase):
       for k in x.keys():
         self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
     elif is_sequence(x) and not hasattr(x, '__array__'):
+      import ipdb; ipdb.set_trace()
       self.assertTrue(is_sequence(y) and not hasattr(y, '__array__'))
       self.assertEqual(len(x), len(y))
       for x_elt, y_elt in zip(x, y):
@@ -485,7 +485,6 @@ class BatchingTest(jtu.JaxTestCase):
                                                  (5, 21, 5, 1)))
     self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
 
-
   def testMaxPool(self):
     W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
     X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)
@@ -1656,12 +1656,12 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for lhs_dil in lhs_dils
       for dtype in [onp.float32]
       for padding in all_pads
-      for rng in [jtu.rand_default()]
       for dim_nums, perms in [
           (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
           (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
-          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))
-      ]))
+          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
+      for rng in [jtu.rand_default()]
+      ))
   @jtu.skip_on_devices("tpu")
   def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                  padding, lhs_dil, rhs_dil, dimension_numbers,
@@ -2153,5 +2153,78 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected, check_dtypes=True)
 
 
+def slicer(x, bdim):
+  if bdim is None:
+    return lambda _: x
+  else:
+    return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
+
+class LaxVmapTest(jtu.JaxTestCase):
+
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
+       "rhs_dilation={}_dims={}_lhs_bdim={}_rhs_bdim={}"
+       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
+               jtu.format_shape_dtype_string(rhs_shape, dtype),
+               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
+               lhs_bdim, rhs_bdim),
+       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
+       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
+       "rhs_dil": rhs_dil, "rng": rng, "dimension_numbers": dim_nums,
+       "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim}
+      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
+          ((b, i, 6, 7),  # lhs_shape
+           (j, i, 1, 2),  # rhs_shape
+           [(1, 1), (1, 2), (2, 1)],  # strides
+           [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
+           [(1, 1), (2, 1)],  # lhs_dils
+           [(1, 1), (2, 2)])  # rhs_dils
+          for b, i, j in itertools.product([1, 2], repeat=3)]
+      for strides in all_strides
+      for rhs_dil in rhs_dils
+      for lhs_dil in lhs_dils
+      for dtype in [onp.float32]
+      for padding in all_pads
+      for dim_nums, perms in [
+          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
+          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
+          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
+      for lhs_bdim in itertools.chain([None], range(len(lhs_shape) + 1))
+      for rhs_bdim in itertools.chain([None], range(len(rhs_shape) + 1))
+      if (lhs_bdim, rhs_bdim) != (None, None)
+      for rng in [jtu.rand_default()]
+      ))
+  def testConvGeneralDilatedBatching(
+      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
+      dimension_numbers, perms, lhs_bdim, rhs_bdim, rng):
+    tol = 1e-1 if onp.finfo(dtype).bits == 32 else 1e-3
+    bdim_size = 10
+
+    # permute shapes to match dimension_numbers
+    lhs_perm, rhs_perm = perms
+    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
+    rhs_shape = list(onp.take(rhs_shape, rhs_perm))
+
+    # add batch dimension
+    if lhs_bdim is not None:
+      lhs_shape.insert(lhs_bdim, bdim_size)
+    if rhs_bdim is not None:
+      rhs_shape.insert(rhs_bdim, bdim_size)
+
+    # create arg values and sliced versions
+    lhs = rng(lhs_shape, dtype)
+    rhs = rng(rhs_shape, dtype)
+    lhs_slice = slicer(lhs, lhs_bdim)
+    rhs_slice = slicer(rhs, rhs_bdim)
+
+    conv = partial(lax.conv_general_dilated, window_strides=strides,
+                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
+                   dimension_numbers=dimension_numbers)
+    ans = api.vmap(conv, (lhs_bdim, rhs_bdim))(lhs, rhs)
+    expected = onp.stack([conv(lhs_slice(i), rhs_slice(i)) for i in range(bdim_size)])
+    self.assertAllClose(ans, expected, check_dtypes=True)
+
+
 if __name__ == '__main__':
   absltest.main()
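The systematic check the new test automates generalizes beyond conv: vmap over any argument must agree with slicing that argument, applying the function per slice, and stacking the results. A hand-run instance of the rhs-batched conv case (arbitrary shapes, and jax.vmap standing in for the internal api.vmap used above; not part of the commit):

from functools import partial
import numpy as onp
import jax
from jax import lax

conv = partial(lax.conv_general_dilated, window_strides=(1, 1),
               padding=[(0, 0), (0, 0)], lhs_dilation=(1, 1),
               rhs_dilation=(1, 1),
               dimension_numbers=("NCHW", "OIHW", "NCHW"))

lhs = onp.random.randn(2, 3, 6, 7).astype(onp.float32)      # unbatched input
rhs = onp.random.randn(10, 4, 3, 1, 2).astype(onp.float32)  # kernels batched on axis 0

# batched application via vmap vs. explicit slice-and-stack
ans = jax.vmap(conv, in_axes=(None, 0))(lhs, rhs)
expected = onp.stack([conv(lhs, rhs[i]) for i in range(10)])
assert onp.allclose(ans, expected, atol=1e-5)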