Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 04:46:06 +00:00
Fix batching rule for convolution for batch dimensions of size 0.
This commit is contained in:
parent e225317ff8
commit ece9b999fb
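For context: the case this commit addresses is vmapping a convolution over a batch axis of length zero, which some of the reshape-based batching paths below do not support (per the new comment in the diff). A minimal illustrative sketch of the call pattern that now short-circuits to an empty result; the shapes, dimension numbers, and the wrapper lambda are my own, not taken from the commit:

import jax
import jax.numpy as jnp
from jax import lax

# An empty batch of NCHW images: the vmapped leading axis has size 0.
lhs = jnp.zeros((0, 1, 3, 10, 10))   # (batch=0, N, C, H, W)
rhs = jnp.ones((3, 3, 3, 3))         # (O, I, KH, KW), shared across the batch

conv = lambda x: lax.conv_general_dilated(
    x, rhs, window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

# Under this change the batching rule short-circuits and returns zeros
# with a leading dimension of 0.
out = jax.vmap(conv)(lhs)
print(out.shape)   # (0, 1, 3, 10, 10)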
@@ -534,6 +534,29 @@ def _conv_general_dilated_batch_rule(
   lhs_bdim, rhs_bdim = batch_dims
   lhs_spec, rhs_spec, out_spec = dimension_numbers
 
+  # Some of the cases that reshape into batch or feature dimensions do not work
+  # with size 0 batch dimensions. The best fix would be to extend HLO to support
+  # multiple batch dimensions.
+  if ((lhs_bdim is not None and lhs.shape[lhs_bdim] == 0) or
+      (rhs_bdim is not None and rhs.shape[rhs_bdim] == 0)):
+    lhs_shape_unbatched, rhs_shape_unbatched = list(lhs.shape), list(rhs.shape)
+    if lhs_bdim is not None:
+      lhs_shape_unbatched.pop(lhs_bdim)
+    if rhs_bdim is not None:
+      rhs_shape_unbatched.pop(rhs_bdim)
+    shape = _conv_general_dilated_shape_rule(
+      core.ShapedArray(lhs_shape_unbatched, lhs.dtype),
+      core.ShapedArray(rhs_shape_unbatched, rhs.dtype),
+      window_strides=window_strides, padding=padding, lhs_dilation=lhs_dilation,
+      rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
+      feature_group_count=feature_group_count,
+      batch_group_count=batch_group_count)
+    return lax.full(
+      (0,) + shape, 0,
+      dtype=lhs.dtype if preferred_element_type is None
+      else preferred_element_type), 0
+
+
   if lhs_bdim is not None and rhs_bdim is not None:
     assert lhs.shape[lhs_bdim] == rhs.shape[rhs_bdim]
     if batch_group_count > 1:
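The early return above drops the mapped axes, runs the unbatched shape rule, and builds an all-zeros output with a fresh length-0 leading axis, which is then reported as mapped at position 0 (the trailing ", 0"). A rough standalone paraphrase of that shape bookkeeping, with made-up shapes and the shape rule's answer written out by hand:

# Hypothetical operands: lhs carries a size-0 batch axis in front, rhs is unmapped.
lhs_shape = (0, 1, 3, 10, 10)        # (bdim, N, C, H, W)
rhs_shape = (3, 3, 3, 3)             # (O, I, KH, KW)
lhs_bdim, rhs_bdim = 0, None

# Drop the mapped axes to recover the per-example shapes.
lhs_unbatched = [d for i, d in enumerate(lhs_shape) if i != lhs_bdim]  # [1, 3, 10, 10]
rhs_unbatched = list(rhs_shape)

# _conv_general_dilated_shape_rule would compute the unbatched output shape;
# with stride 1 and SAME padding that would be (N, O, H, W) = (1, 3, 10, 10).
out_unbatched = (1, 3, 10, 10)

# lax.full((0,) + shape, 0, dtype=...) then produces the empty batched result.
print((0,) + out_unbatched)          # (0, 1, 3, 10, 10)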
@@ -596,8 +619,7 @@ def _conv_general_dilated_batch_rule(
       new_rhs = _reshape_axis_out_of(rhs_spec[0] + int(rhs_bdim <= rhs_spec[0]),
                                      group_count, rhs)
       new_rhs = _reshape_axis_into(rhs_bdim + int(rhs_spec[0] < rhs_bdim),
-                                   rhs_spec[0] + 1,
-                                   new_rhs)
+                                   rhs_spec[0] + 1, new_rhs)
       new_rhs = _reshape_axis_into(rhs_spec[0], rhs_spec[0], new_rhs)
       out = conv_general_dilated(lhs, new_rhs, window_strides, padding,
                                  lhs_dilation, rhs_dilation, dimension_numbers,
@@ -737,6 +759,10 @@ mlir.register_lowering(
 
 
 def _reshape_axis_into(src, dst, x):
+  # NB: `dst` is the number of the dimension that we should reshape into
+  # *after* `src` is removed from `x`'s list of dimensions. For example, if
+  # `src` is an added batch dimension, `dst` might name a target dimension in
+  # the unbatched list of dimensions.
   perm = [i for i in range(x.ndim) if i != src]
   perm.insert(dst, src)
   new_shape = list(np.delete(x.shape, src))
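The new comment pins down the `dst` convention: it indexes the dimensions of `x` *after* `src` has been removed. A small NumPy paraphrase with a made-up array; the lines beyond what the diff shows (the size multiplication and the final reshape) are my assumption about how the helper folds `src` into `dst`:

import numpy as np

def reshape_axis_into(src, dst, x):
  # `dst` counts positions in x's dims with `src` already removed.
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  new_shape = list(np.delete(x.shape, src))
  new_shape[dst] *= x.shape[src]            # assumed: fold src's extent into dst
  return np.transpose(x, perm).reshape(new_shape)

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
# Fold axis 0 (size 2) into the axis that is index 1 once axis 0 is gone (size 4).
print(reshape_axis_into(0, 1, x).shape)     # (3, 8)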
@@ -114,17 +114,18 @@ class LaxVmapTest(jtu.JaxTestCase):
        "testcase_name":
        "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
        "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
-       "_lhs_bdim={}_rhs_bdim={}"
+       "_lhs_bdim={}_rhs_bdim={}_bdim_size={}"
        .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
-               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
+               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim,
+               bdim_size),
        "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
        "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
        "rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
        "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
        "feature_group_count": feature_group_count,
-       "batch_group_count": batch_group_count,
+       "batch_group_count": batch_group_count, "bdim_size": bdim_size,
       } for batch_group_count, feature_group_count in s([(1, 1), (2, 1), (1, 2)])
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in s([
        ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
@@ -142,7 +143,10 @@ class LaxVmapTest(jtu.JaxTestCase):
      for dim_nums, perms in s([
        (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
        (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
-       (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))])
+       (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3])),
+       (("HWCN", "HWIO", "HWCN"), ([2, 3, 1, 0], [2, 3, 1, 0])),
+     ])
+     for bdim_size in s([0, 5])
      for lhs_bdim in s(itertools.chain([cast(Optional[int], None)],
                                        range(len(lhs_shape) + 1)))
      for rhs_bdim in s(itertools.chain([cast(Optional[int], None)],
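In the parameterization above, bdim_size now ranges over {0, 5} and an HWCN layout is added; lhs_bdim or rhs_bdim being None means the corresponding operand is not mapped. A tiny illustration, separate from the test itself, of how a None axis versus an integer axis behaves under vmap, including a zero-length mapped axis (shapes are made up):

import jax
import jax.numpy as jnp

f = lambda x, y: x * jnp.sum(y)

xs = jnp.ones((0, 3))   # mapped over axis 0, which here has size 0
y = jnp.ones((4,))      # in_axes=None: shared, not mapped

out = jax.vmap(f, in_axes=(0, None))(xs, y)
print(out.shape)        # (0, 3): an empty batch of results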
@@ -152,7 +156,7 @@ class LaxVmapTest(jtu.JaxTestCase):
   def testConvGeneralDilatedBatching(
       self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
       dimension_numbers, perms, feature_group_count, batch_group_count,
-      lhs_bdim, rhs_bdim):
+      lhs_bdim, rhs_bdim, bdim_size):
     rng = jtu.rand_default(self.rng())
     tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3
 
@@ -167,8 +171,9 @@ class LaxVmapTest(jtu.JaxTestCase):
                    feature_group_count=feature_group_count,
                    batch_group_count=batch_group_count,
                    precision=lax.Precision.HIGHEST)
-    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
-                        (dtype, dtype), rng, rtol=tol, atol=tol)
+    self._CheckBatching(conv, bdim_size, (lhs_bdim, rhs_bdim),
+                        (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol,
+                        atol=tol)
 
   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
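The check now passes the parameterized bdim_size instead of a hard-coded 5, so a zero-length batch dimension is exercised as well. Conceptually the batched result is compared against stacking per-example outputs, which for bdim_size == 0 is just an empty array; the helper below is my own sketch of that idea, not jtu's _CheckBatching:

import jax
import jax.numpy as jnp
import numpy as np

def check_batching(fn, bdim_size, example_shape, dtype=np.float32):
  x = jnp.zeros((bdim_size,) + example_shape, dtype)
  # Reference: apply fn example by example and stack.  For a zero-length
  # batch this is an empty array with the per-example output shape appended.
  out = jax.eval_shape(fn, jax.ShapeDtypeStruct(example_shape, dtype))
  expected = np.zeros((bdim_size,) + out.shape, out.dtype)
  for i in range(bdim_size):
    expected[i] = fn(x[i])
  actual = jax.vmap(fn)(x)
  np.testing.assert_allclose(actual, expected)

check_batching(lambda v: v * 2.0, 0, (3,))   # empty batch: shapes (0, 3)
check_batching(lambda v: v * 2.0, 5, (3,))   # ordinary batch: shapes (5, 3)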
@@ -467,12 +472,11 @@ class LaxVmapTest(jtu.JaxTestCase):
       for init_val, op, dtypes in [
           (0, lax.add, default_dtypes),
           (1, lax.mul, default_dtypes),
-          (0, lax.max, all_dtypes),  # non-monoidal
+          # non-monoidal for everything except unsigned integers
+          (0, lax.max, all_dtypes),
           (-np.inf, lax.max, float_dtypes),
           (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
           (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
-          (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
-          (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
           (np.inf, lax.min, float_dtypes),
           (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
           (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
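The reworded comment reflects which init_vals are true identities: for unsigned integer dtypes the minimum value is 0, so 0 is a valid identity for lax.max (which is presumably why the explicit uint32/uint64 max entries can be dropped as redundant), while for signed and floating-point dtypes max(0, x) clamps negatives. A quick check of that fact, as my own snippet rather than part of the test:

import numpy as np
import jax.numpy as jnp
from jax import lax

print(np.iinfo(np.uint32).min)                # 0: for unsigned ints, 0 is the minimum
print(lax.max(jnp.uint32(0), jnp.uint32(7)))  # 7: so 0 acts as an identity for max here

print(lax.max(jnp.int32(0), jnp.int32(-5)))   # 0, not -5: not an identity for signed ints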