mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add padding rule for copy_p
This commit is contained in:
parent
314cf8a439
commit
1d920b51a9
@ -4318,6 +4318,7 @@ copy_p.def_impl(partial(xla.apply_primitive, copy_p))
|
||||
copy_p.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(copy_p, lambda ctx, x: [x])
|
||||
ad.deflinear(copy_p, lambda t: [copy_p.bind(t)])
|
||||
pe.def_trivial_padding(copy_p)
|
||||
batching.defvectorized(copy_p)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user