Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32

This commit is contained in:
Jane Liu 2023-09-12 14:30:40 -07:00
parent 801cbef011
commit 90da4f153a

View File

@ -222,6 +222,8 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
# Replace the default TF32 with float32 in order to make it pass on A100
@jax.default_matmul_precision("float32")
def testDotGeneral(self):
R = self.rng().randn