diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index de539461a..cf1d8b3d2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4321,6 +4321,7 @@ copy_p.def_impl(partial(xla.apply_primitive, copy_p)) copy_p.def_abstract_eval(lambda x: x) mlir.register_lowering(copy_p, lambda ctx, x: [x]) ad.deflinear(copy_p, lambda t: [copy_p.bind(t)]) +pe.def_trivial_padding(copy_p) batching.defvectorized(copy_p)