mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add input validation for the padding argument to lax.conv_general_dilated.
Fixes #10729
This commit is contained in:
parent
744f6b4ee8
commit
44f1e05a76
@ -14,6 +14,7 @@
|
||||
|
||||
import builtins
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@ -141,6 +142,15 @@ def conv_general_dilated(
|
||||
padding = lax.padtype_to_pads(
|
||||
np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape, # type: ignore[index]
|
||||
window_strides, padding)
|
||||
else:
|
||||
try:
|
||||
padding = tuple((operator.index(lo), operator.index(hi))
|
||||
for lo, hi in padding)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(
|
||||
"padding argument to conv_general_dilated should be a string or a "
|
||||
f"sequence of (low, high) pairs, got {padding}") from e
|
||||
|
||||
preferred_element_type = (
|
||||
None if preferred_element_type is None else
|
||||
dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
|
||||
|
@ -1060,6 +1060,13 @@ class LaxTest(jtu.JaxTestCase):
|
||||
c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
|
||||
self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26)))
|
||||
|
||||
def testConvInvalidPadding(self):
|
||||
x = jnp.ones((1, 10, 10, 5), dtype=jnp.bfloat16)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"padding argument.*, got \(3, 3\)"):
|
||||
jax.lax.conv_general_dilated_patches(x, (5, 5), window_strides=(1, 1),
|
||||
padding=(3, 3))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
|
||||
jtu.format_shape_dtype_string(lhs_shape, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user