Mirror of https://github.com/ROCm/jax.git, synced 2025-04-24 05:56:05 +00:00

This follows #26078 and #26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this, I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this parameter was a `Callable`, but in almost all cases it was really `lu.WrappedFun.call_wrapped`.
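
To illustrate why passing the wrapper itself is preferable to passing its bound `call_wrapped` method, here is a minimal, self-contained sketch. The `DebugInfo`, `WrappedFun`, and `wrap_init` definitions below are simplified hypothetical stand-ins written for this example, not the actual `jax._src.linear_util` implementations, and `bwd_rule` is an invented backward function.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class DebugInfo:
    """Hypothetical stand-in for the debug info attached to a wrapped function."""
    traced_for: str
    func_name: str


@dataclass(frozen=True)
class WrappedFun:
    """Hypothetical stand-in for lu.WrappedFun: a callable plus its debug info."""
    f: Callable[..., Any]
    debug_info: DebugInfo

    def call_wrapped(self, *args: Any) -> Any:
        return self.f(*args)


def wrap_init(f: Callable[..., Any], debug_info: DebugInfo) -> WrappedFun:
    """Hypothetical stand-in for lu.wrap_init with a debug_info argument."""
    return WrappedFun(f, debug_info)


def bwd_rule(residual: float, cotangent: float) -> tuple:
    # Invented backward rule, used only to have something to call.
    return (residual * cotangent,)


bwd = wrap_init(bwd_rule, DebugInfo("custom_vjp bwd", bwd_rule.__name__))

# Before: the primitive parameter was a plain Callable, typically the bound
# method. The debug info is no longer exposed directly on that value.
old_param = bwd.call_wrapped

# After: the primitive parameter is the wrapper, so debug info travels with it.
new_param = bwd

old_result = old_param(2.0, 3.0)
new_result = new_param.call_wrapped(2.0, 3.0)
```

In this sketch both forms compute the same result; the difference is that code receiving `new_param` can read `new_param.debug_info` directly, which is the property the change relies on.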