Fix sparse dot metadata loader

Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.

PiperOrigin-RevId: 633513110
This commit is contained in:
Sergey Kozub 2024-05-14 03:07:12 -07:00 committed by jax authors
parent e735a00cdc
commit 84774b39e6

View File

@ -999,8 +999,7 @@ jax_test(
],
enable_configs = [
"gpu_a100",
# TODO(b/337303303): re-enable the test
# "gpu_h100",
"gpu_h100",
],
deps = [
"//jax:experimental_sparse",