Update lax_test.py

This commit is contained in:
George Necula 2020-12-06 15:39:30 +02:00 committed by GitHub
parent 33c3cd0ea1
commit f51db5cd75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2127,7 +2127,7 @@ class LaxTest(jtu.JaxTestCase):
def test_window_strides_dimension_shape_rule(self):
# https://github.com/google/jax/issues/5087
msg = ("conv_general_dilated window and window_strides must have "
"the same number of dimension")
"the same number of dimensions")
lhs = jax.numpy.zeros((1, 1, 3, 3))
rhs = np.zeros((1, 1, 1, 1))
with self.assertRaisesRegex(ValueError, msg):