mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update lax_test.py
This commit is contained in:
parent
33c3cd0ea1
commit
f51db5cd75
@ -2127,7 +2127,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
def test_window_strides_dimension_shape_rule(self):
|
||||
# https://github.com/google/jax/issues/5087
|
||||
msg = ("conv_general_dilated window and window_strides must have "
|
||||
"the same number of dimension")
|
||||
"the same number of dimensions")
|
||||
lhs = jax.numpy.zeros((1, 1, 3, 3))
|
||||
rhs = np.zeros((1, 1, 1, 1))
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
|
Loading…
x
Reference in New Issue
Block a user