mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Re-enable oss paged attn kernel
PiperOrigin-RevId: 725411244
This commit is contained in:
parent
ffd3faad72
commit
d3ed6ca0cc
@ -463,7 +463,6 @@ jax_multiplatform_test(
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
"nomsan", # Times out.
|
||||
"notap", # this code has data race issues that XLA improvements unhide. b/392946030
|
||||
"notsan", # Times out.
|
||||
],
|
||||
deps = [
|
||||
|
@ -133,7 +133,6 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
|
||||
attn_logits_soft_cap,
|
||||
are_kv_quantized,
|
||||
):
|
||||
self.skipTest("This kernel has data races that need to be fixed.")
|
||||
if not jtu.is_device_tpu_at_least(4):
|
||||
self.skipTest("Only supports TPU generation 4 or above")
|
||||
if jtu.is_device_tpu(version=4) and are_kv_quantized:
|
||||
|
Loading…
x
Reference in New Issue
Block a user