[XLA:GPU] Consider Triton for all non-pure GEMM fusions

This is a big step toward enabling xla_gpu_triton_gemm_any by default.

It shows roughly a 1.05x geomean speedup on internal benchmarks, comparable to running with xla_gpu_triton_gemm_any=true.

PiperOrigin-RevId: 579524573
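
For context, here is a minimal sketch (not part of this change) of how the flag named above can be enabled explicitly from JAX, assuming the usual XLA_FLAGS environment-variable mechanism; the fused_gemm function, the gelu epilogue, and the shapes are made up for illustration:

import os

# Assumption for illustration: XLA debug options can be passed to JAX through
# the XLA_FLAGS environment variable, set before the backend is initialized.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_triton_gemm_any=true")

import jax
import jax.numpy as jnp

@jax.jit
def fused_gemm(x, w, b):
  # A GEMM with a non-trivial epilogue, i.e. a "non-pure GEMM" fusion of the
  # kind this change makes eligible for the Triton emitter.
  return jax.nn.gelu(x @ w + b)

x = jnp.ones((128, 256), jnp.bfloat16)
w = jnp.ones((256, 512), jnp.bfloat16)
b = jnp.ones((512,), jnp.bfloat16)
print(fused_gemm(x, w, b).shape)  # (128, 512)

The tolerance loosening in the test diff below is consistent with this kind of change: routing more fusions through Triton can alter floating-point accumulation order, so the affected tests compare with rtol=0.005 rather than the exact defaults.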
commit bfbf9e1c33
parent dda76733e8
Author: Tamás Danyluk (committed by jax authors)
Date:   2023-11-04 16:04:49 -07:00


@@ -1905,7 +1905,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                in_axes=({}, {0: 'batch'}, {0: 'batch'}),
                out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
     expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)

     # reduced
     out = xmap(vjp_f_reduced,
@@ -1913,7 +1913,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                out_axes=({}, {0: 'batch'}))(w, x, gy)
     # the reduced VJP is also the VJP when using a positional batch axis
     expected = vjp_f(w, x, gy)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)

   def testVjpReduceAxesCollective(self):
@@ -1952,7 +1952,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
                in_axes=({}, {0: 'batch'}),
                out_axes=({}, {0: 'batch'}))(w, x)
     expected = jax.grad(f_positional, (0, 1))(w, x)
-    self.assertAllClose(out, expected, check_dtypes=True)
+    self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)

 if __name__ == '__main__':