[SDY] fix JAX layouts tests for Shardy.

PiperOrigin-RevId: 697715276
This commit is contained in:
Bill Varcho 2024-11-18 12:13:55 -08:00 committed by jax authors
parent 70b05f6cde
commit 0ed6eaeb4a

View File

@ -267,6 +267,9 @@ jax_multiplatform_test(
backend_tags = {
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
},
enable_configs = [
"tpu_v3_2x2_shardy",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",