Added debug_callback to the list of exclusions in jax2tf/tests/primitives_test.py

PiperOrigin-RevId: 641149152
This commit is contained in:
Sergei Lebedev 2024-06-07 00:00:36 -07:00 committed by jax authors
parent c01c98400d
commit 5d6413cecc

View File

@ -179,7 +179,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if p.name == "optimization_barrier":
continue
if p.name == "debug_callback":
if p.name == "debug_callback" or p.name == "debug_print":
# TODO(sharadmv,necula): enable debug callbacks in TF
continue
if p.name == "pallas_call":