mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Enable testWhileLoopBatchedWithConstBody for GPU
The XLA:GPU issue causing the internal error has been fixed.
This commit is contained in:
parent
be751d1dd6
commit
0776c4e628
@ -360,8 +360,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = np.array([0, 2, 2, 4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
# TODO(b/202709967): Enable once fixed
|
||||
@unittest.skipIf(jtu.device_under_test() == 'gpu', "Test triggers an internal XLA:GPU error")
|
||||
def testWhileLoopBatchedWithConstBody(self):
|
||||
def f(x):
|
||||
def body_fn(_): return jnp.asarray(0., dtype=jnp.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user