Reenable GPU int matmul test since the XLA bug is fixed.

PiperOrigin-RevId: 495117439
This commit is contained in:
Peter Hawkins 2022-12-13 13:50:10 -08:00 committed by jax authors
parent c6eb632f57
commit b80af85298

View File

@ -219,8 +219,6 @@ class LaxVmapTest(jtu.JaxTestCase):
dtype=default_dtypes,
)
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
if jtu.device_under_test() == "gpu" and dtype == np.int64:
raise unittest.SkipTest("Wrong outputs for batched matmuls (b/258497059)")
rng = jtu.rand_default(self.rng())
op = partial(lax.dot, precision=lax.Precision.HIGHEST)
self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),