Add input validation for the padding argument to lax.conv_general_dilated.

Fixes #10729
This commit is contained in:
Peter Hawkins 2022-05-16 16:06:52 -04:00
parent 744f6b4ee8
commit 44f1e05a76
2 changed files with 17 additions and 0 deletions

View File

@ -14,6 +14,7 @@
import builtins
from functools import partial
import operator
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@ -141,6 +142,15 @@ def conv_general_dilated(
padding = lax.padtype_to_pads(
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
window_strides, padding)
else:
try:
padding = tuple((operator.index(lo), operator.index(hi))
for lo, hi in padding)
except (ValueError, TypeError) as e:
raise ValueError(
"padding argument to conv_general_dilated should be a string or a "
f"sequence of (low, high) pairs, got {padding}") from e
preferred_element_type = (
None if preferred_element_type is None else
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))

View File

@ -1060,6 +1060,13 @@ class LaxTest(jtu.JaxTestCase):
c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26)))
def testConvInvalidPadding(self):
x = jnp.ones((1, 10, 10, 5), dtype=jnp.bfloat16)
with self.assertRaisesRegex(ValueError,
r"padding argument.*, got \(3, 3\)"):
jax.lax.conv_general_dilated_patches(x, (5, 5), window_strides=(1, 1),
padding=(3, 3))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),