[Mosaic GPU] Unbreak tests

I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evalute to 2 for 1 :)

PiperOrigin-RevId: 670527107
This commit is contained in:
Adam Paszke 2024-09-03 06:06:54 -07:00 committed by jax authors
parent eed273c106
commit 4c3111bf26
2 changed files with 1 additions and 3 deletions

View File

@ -757,7 +757,7 @@ class FragmentedArray:
case _:
raise AssertionError(swizzle)
stagger_amount = swizzle // 64
if (cols_per_tile // 8) % (stagger_amount + 1):
if (cols_per_tile // 8) % (stagger_amount * 2):
raise NotImplementedError
else:
# We rely on canonicalization to clean up the selects.

View File

@ -49,7 +49,6 @@ jax_test(
disable_configs = DISABLED_CONFIGS,
enable_configs = ["gpu_h100_2gpu"],
shard_count = 4,
tags = ["notap"], # Broken at head.
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
@ -61,7 +60,6 @@ jax_test(
disable_backends = DISABLED_BACKENDS,
disable_configs = DISABLED_CONFIGS,
shard_count = 5,
tags = ["notap"], # Broken at head.
deps = [
"//jax:mosaic_gpu",
"//jax/experimental/mosaic/gpu/examples:matmul",