[jax2tf] Skip "graph" mode primitive tests on TPUs.

PiperOrigin-RevId: 521468145
This commit is contained in:
George Necula 2023-04-03 08:38:52 -07:00 committed by jax authors
parent d743d23859
commit ff313a37a2

View File

@ -51,6 +51,7 @@ def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
return tf.function( # GRAPH
func_tf,
autograph=False,
# Note that jit_compile defaults to True on TPU and False elsewhere
input_signature=_make_tf_input_signature(*tf_args))(*tf_args) # GRAPH
elif mode == "compiled":
# Adding an explicit input_signature prevents TF from constant-folding
@ -215,6 +216,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
unexpected_successes: List[str] = []
# Run the "compiled" mode first, it is most important
for mode in ("compiled", "eager", "graph"):
if mode == "graph" and jtu.device_under_test() == "tpu":
continue # The "graph" mode on TPU is the same as "compiled"
def log_message(extra):
return f"[{self._testMethodName}] {mode=}: {extra}"