From bfbf9e1c3313cae2dfd71e39e64c50c8f42017ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?=
Date: Sat, 4 Nov 2023 16:04:49 -0700
Subject: [PATCH] [XLA:GPU] Consider Triton for all non-pure GEMM fusions

This is a big step toward enabling xla_gpu_triton_gemm_any by default.
It shows about 1.05x geomean speedup on internal benchmarks (comparable to
xla_gpu_triton_gemm_any=true).

PiperOrigin-RevId: 579524573
---
 tests/xmap_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/xmap_test.py b/tests/xmap_test.py
index 62d384268..30074a189 100644
--- a/tests/xmap_test.py
+++ b/tests/xmap_test.py
@@ -1905,7 +1905,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                in_axes=({}, {0: 'batch'}, {0: 'batch'}),
                out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
     expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
 
     # reduced
     out = xmap(vjp_f_reduced,
@@ -1913,7 +1913,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                out_axes=({}, {0: 'batch'}))(w, x, gy)
     # the reduced VJP is also the VJP when using a positional batch axis
     expected = vjp_f(w, x, gy)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
 
   def testVjpReduceAxesCollective(self):
 
@@ -1952,7 +1952,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                in_axes=({}, {0: 'batch'}),
                out_axes=({}, {0: 'batch'}))(w, x)
     expected = jax.grad(f_positional, (0, 1))(w, x)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
 
 
 if __name__ == '__main__':
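
Note (not part of the patch): the new rtol=0.005 in these assertAllClose calls
accommodates small numeric differences once XLA:GPU lowers more fused GEMMs
through Triton. Below is a minimal sketch of how one might reproduce such a
comparison with the flag named in the commit message enabled; the helper
function, shapes, and data are illustrative assumptions, not taken from the
patch.

# Sketch only (not part of the patch): exercise the flag named in the commit
# message and compare a fused GEMM against a NumPy reference using the same
# loosened tolerance the tests now use. The function, shapes, and data below
# are illustrative assumptions.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=true"  # set before JAX initializes its backend

import jax
import jax.numpy as jnp
import numpy as np

def gemm_with_epilogue(w, x):
  # A "non-pure" GEMM fusion: a dot followed by a pointwise epilogue.
  return jnp.tanh(x @ w)

kw, kx = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (256, 256), dtype=jnp.float32)
x = jax.random.normal(kx, (128, 256), dtype=jnp.float32)

out = jax.jit(gemm_with_epilogue)(w, x)
expected = np.tanh(np.asarray(x) @ np.asarray(w))

# Triton-generated kernels can differ slightly from the reference result,
# hence the relaxed rtol=0.005 in the updated tests.
np.testing.assert_allclose(np.asarray(out), expected, rtol=0.005)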