mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Reenable GPU int matmul test since the XLA bug is fixed.
PiperOrigin-RevId: 495117439
This commit is contained in:
parent
c6eb632f57
commit
b80af85298
@ -219,8 +219,6 @@ class LaxVmapTest(jtu.JaxTestCase):
|
||||
dtype=default_dtypes,
|
||||
)
|
||||
def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
|
||||
if jtu.device_under_test() == "gpu" and dtype == np.int64:
|
||||
raise unittest.SkipTest("Wrong outputs for batched matmuls (b/258497059)")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
op = partial(lax.dot, precision=lax.Precision.HIGHEST)
|
||||
self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user