mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add jax2tf patch to ignore these primitives
This commit is contained in:
parent
b7e26ba3ee
commit
15f9ac4aee
@ -1553,6 +1553,9 @@ tf_not_yet_impl = [
|
||||
"bitcast",
|
||||
"repeat",
|
||||
"roll",
|
||||
# temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740
|
||||
"bias_fwd_p",
|
||||
"bias_bwd_p",
|
||||
]
|
||||
|
||||
tf_impl[random_internal.random_clone_p] = lambda x: x
|
||||
|
Loading…
x
Reference in New Issue
Block a user