mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Cast "axis" arg of tf.concat to tf.int32.
PiperOrigin-RevId: 537478940
This commit is contained in:
parent
5639e194be
commit
983e1c0fd1
@ -1923,7 +1923,7 @@ tf_impl_with_avals[lax.clamp_p] = _clamp
|
||||
|
||||
|
||||
def _concatenate(*operands, dimension):
|
||||
return tf.concat(operands, axis=dimension)
|
||||
return tf.concat(operands, axis=tf.cast(dimension, tf.int32))
|
||||
|
||||
|
||||
tf_impl[lax.concatenate_p] = _concatenate
|
||||
|
Loading…
x
Reference in New Issue
Block a user