mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32
This commit is contained in:
parent
801cbef011
commit
90da4f153a
@ -222,6 +222,8 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
|
||||
|
||||
# Replace the default TF32 with float32 in order to make it pass on A100
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testDotGeneral(self):
|
||||
R = self.rng().randn
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user