Cast "axis" arg of tf.concat to tf.int32.

PiperOrigin-RevId: 537478940
This commit is contained in:
Yu Emma Wang 2023-06-02 21:11:54 -07:00 committed by jax authors
parent 5639e194be
commit 983e1c0fd1

View File

@ -1923,7 +1923,7 @@ tf_impl_with_avals[lax.clamp_p] = _clamp
def _concatenate(*operands, dimension):
return tf.concat(operands, axis=dimension)
return tf.concat(operands, axis=tf.cast(dimension, tf.int32))
tf_impl[lax.concatenate_p] = _concatenate