mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix flaky testcase LaxTest.testConvTransposePaddingList on GPU.
This commit is contained in:
parent
6bd9292b3d
commit
c54bb06434
@ -978,7 +978,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
a = jnp.ones((28,28))
|
||||
b = jnp.ones((3,3))
|
||||
c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1))
|
||||
self.assertArraysEqual(c, 9 * jnp.ones((1, 1, 26, 26)))
|
||||
self.assertAllClose(c, 9 * jnp.ones((1, 1, 26, 26)))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_lhs_shape={}_rhs_shape={}_precision={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user