add jax2tf patch to ignore these primitives

This commit is contained in:
Matthew Johnson 2024-10-02 17:21:56 +00:00
parent b7e26ba3ee
commit 15f9ac4aee

View File

@ -1553,6 +1553,9 @@ tf_not_yet_impl = [
"bitcast",
"repeat",
"roll",
# temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740
"bias_fwd_p",
"bias_bwd_p",
]
tf_impl[random_internal.random_clone_p] = lambda x: x