mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Skip "graph" mode primitive tests on TPUs.
PiperOrigin-RevId: 521468145
This commit is contained in:
parent
d743d23859
commit
ff313a37a2
@ -51,6 +51,7 @@ def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
|
||||
return tf.function( # GRAPH
|
||||
func_tf,
|
||||
autograph=False,
|
||||
# Note that jit_compile defaults to True on TPU and False elsewhere
|
||||
input_signature=_make_tf_input_signature(*tf_args))(*tf_args) # GRAPH
|
||||
elif mode == "compiled":
|
||||
# Adding an explicit input_signature prevents TF from constant-folding
|
||||
@ -215,6 +216,8 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
unexpected_successes: List[str] = []
|
||||
# Run the "compiled" mode first, it is most important
|
||||
for mode in ("compiled", "eager", "graph"):
|
||||
if mode == "graph" and jtu.device_under_test() == "tpu":
|
||||
continue # The "graph" mode on TPU is the same as "compiled"
|
||||
def log_message(extra):
|
||||
return f"[{self._testMethodName}] {mode=}: {extra}"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user