Enable testWhileLoopBatchedWithConstBody for GPU

The XLA:GPU issue causing the internal error has been fixed.
This commit is contained in:
Rahul Joshi 2021-11-15 16:48:52 -08:00
parent be751d1dd6
commit 0776c4e628

View File

@ -360,8 +360,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = np.array([0, 2, 2, 4])
self.assertAllClose(ans, expected, check_dtypes=False)
# TODO(b/202709967): Enable once fixed
@unittest.skipIf(jtu.device_under_test() == 'gpu', "Test triggers an internal XLA:GPU error")
def testWhileLoopBatchedWithConstBody(self):
def f(x):
def body_fn(_): return jnp.asarray(0., dtype=jnp.float32)