mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:GPU] Consider Triton for all non-pure GEMM fusions
This is a big step toward enabling xla_gpu_triton_gemm_any by default. It shows about a 1.05x geomean speedup on internal benchmarks (comparable to xla_gpu_triton_gemm_any=true).

PiperOrigin-RevId: 579524573
This commit is contained in:
parent
dda76733e8
commit
bfbf9e1c33
@@ -1905,7 +1905,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
|
||||
in_axes=({}, {0: 'batch'}, {0: 'batch'}),
|
||||
out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
|
||||
expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
|
||||
-self.assertAllClose(out, expected, check_dtypes=True)
|
||||
+self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
|
||||
|
||||
# reduced
|
||||
out = xmap(vjp_f_reduced,
|
||||
@@ -1913,7 +1913,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
|
||||
out_axes=({}, {0: 'batch'}))(w, x, gy)
|
||||
# the reduced VJP is also the VJP when using a positional batch axis
|
||||
expected = vjp_f(w, x, gy)
|
||||
-self.assertAllClose(out, expected, check_dtypes=True)
|
||||
+self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
|
||||
|
||||
def testVjpReduceAxesCollective(self):
|
||||
|
||||
@@ -1952,7 +1952,7 @@ class NamedAutodiffTests(jtu.JaxTestCase):
|
||||
in_axes=({}, {0: 'batch'}),
|
||||
out_axes=({}, {0: 'batch'}))(w, x)
|
||||
expected = jax.grad(f_positional, (0, 1))(w, x)
|
||||
-self.assertAllClose(out, expected, check_dtypes=True)
|
||||
+self.assertAllClose(out, expected, check_dtypes=True, rtol=0.005)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user