Use float32 for testNestedMap and testPdotBatching in XMapTest

Multiple test cases were failing on Ampere+ due to use of TF32.
This commit is contained in:
Andrey Portnoy 2023-09-19 14:35:23 -04:00
parent 3b66fbf841
commit 4eea05723b

View File

@ -553,6 +553,7 @@ class XMapTest(XMapTestCase):
"vmap_as_xmap": vmap_as_xmap}
@parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases))
@jax.default_matmul_precision("float32")
def testNestedMap(self,
xmap_in_axes, xmap_out_axes,
vmap_in_axes, vmap_out_axes, vmap_result_axis,
@ -1341,6 +1342,7 @@ class PDotTests(XMapTestCase):
self.assertAllClose(z, jnp.dot(x, y))
@jtu.with_mesh([('r1', 2)])
@jax.default_matmul_precision("float32")
def testPdotBatching(self):
def f(x, y):
return lax.pdot(x, y, 'i')