From 4eea05723b693d5283e4a92c0c14bdf97e2faf10 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Tue, 19 Sep 2023 14:35:23 -0400 Subject: [PATCH] Use float32 for testNestedMap and testPdotBatching in XMapTest Multiple test cases were failing on Ampere+ due to use of TF32. --- tests/xmap_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 61f5da11b..534de6757 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -553,6 +553,7 @@ class XMapTest(XMapTestCase): "vmap_as_xmap": vmap_as_xmap} @parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases)) + @jax.default_matmul_precision("float32") def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes, vmap_out_axes, vmap_result_axis, @@ -1341,6 +1342,7 @@ class PDotTests(XMapTestCase): self.assertAllClose(z, jnp.dot(x, y)) @jtu.with_mesh([('r1', 2)]) + @jax.default_matmul_precision("float32") def testPdotBatching(self): def f(x, y): return lax.pdot(x, y, 'i')