[Pallas] Fully skip GPU attention tests on win32.

PiperOrigin-RevId: 672588009
This commit is contained in:
Justin Fu 2024-09-09 10:20:47 -07:00 committed by jax authors
parent c1bac25a66
commit 4bdfe09241

View File

@ -51,7 +51,7 @@ class PallasBaseTest(jtu.JaxTestCase):
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32" and not self.INTERPRET:
if sys.platform == "win32":
self.skipTest("Only works on non-Windows platforms")
super().setUp()