mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU] Unbreak tests
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially seemed right because both expressions evalute to 2 for 1 :) PiperOrigin-RevId: 670527107
This commit is contained in:
parent
eed273c106
commit
4c3111bf26
@ -757,7 +757,7 @@ class FragmentedArray:
|
||||
case _:
|
||||
raise AssertionError(swizzle)
|
||||
stagger_amount = swizzle // 64
|
||||
if (cols_per_tile // 8) % (stagger_amount + 1):
|
||||
if (cols_per_tile // 8) % (stagger_amount * 2):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
# We rely on canonicalization to clean up the selects.
|
||||
|
@ -49,7 +49,6 @@ jax_test(
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_configs = ["gpu_h100_2gpu"],
|
||||
shard_count = 4,
|
||||
tags = ["notap"], # Broken at head.
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
@ -61,7 +60,6 @@ jax_test(
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
shard_count = 5,
|
||||
tags = ["notap"], # Broken at head.
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
"//jax/experimental/mosaic/gpu/examples:matmul",
|
||||
|
Loading…
x
Reference in New Issue
Block a user