From 90da4f153ab7f4511ada1e4546b2eb24d1532115 Mon Sep 17 00:00:00 2001 From: Jane Liu Date: Tue, 12 Sep 2023 14:30:40 -0700 Subject: [PATCH] Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32 --- tests/batching_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/batching_test.py b/tests/batching_test.py index 525216043..cf136baae 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -222,6 +222,8 @@ class BatchingTest(jtu.JaxTestCase): self.assertAllClose(ans[i], expected_ans, check_dtypes=False) + # Replace the default TF32 with float32 in order to make it pass on A100 + @jax.default_matmul_precision("float32") def testDotGeneral(self): R = self.rng().randn