Fix flaky testcase LaxTest.testConvTransposePaddingList on GPU.

This commit is contained in:
Peter Hawkins 2022-01-27 09:11:21 -05:00
parent 6bd9292b3d
commit c54bb06434

View File

@ -978,7 +978,7 @@ class LaxTest(jtu.JaxTestCase):
a = jnp.ones((28,28))
b = jnp.ones((3,3))
c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
self.assertArraysEqual(c, 9 * jnp.ones((1, 1, 26, 26)))
self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(