Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 04:16:07 +00:00

* Dropping support for the special AD handling of hcb.id_tap and hcb.id_print. From now on, only the primals are tapped. The old behavior can be restored, for a limited time, by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable or the --jax_host_callback_ad_transforms flag. Also added documentation on how to implement the old behavior using the JAX custom AD APIs. This change allows a significant cleanup of the internals.
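A minimal sketch of the custom-AD approach mentioned above follows. It is not the documentation added in this change. It wraps a value in a `jax.custom_vjp` identity so that, besides the primal, the cotangent flowing back through that value is also recorded under `jax.grad`. Because `jax.experimental.host_callback` has since been deprecated, the sketch assumes `jax.debug.callback` as a stand-in for `hcb.id_tap`. The names `tap_with_cotangent`, `_record` and `tapped` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Host-side log filled by the hypothetical tap below, as (kind, value) pairs.
tapped = []


def _record(kind):
    # Build a host callback that stores a NumPy copy of its argument.
    return lambda v: tapped.append((kind, np.asarray(v)))


@jax.custom_vjp
def tap_with_cotangent(x):
    """Identity on x that taps the primal (stand-in for hcb.id_tap)."""
    jax.debug.callback(_record("primal"), x)
    return x


def _tap_fwd(x):
    # Forward pass under jax.grad / jax.vjp: tap the primal; no residuals needed.
    jax.debug.callback(_record("primal"), x)
    return x, None


def _tap_bwd(_, ct):
    # Backward pass: bwd is ordinary (non-transposed) code, so a callback on the
    # cotangent is allowed here. The VJP of the identity passes ct through.
    jax.debug.callback(_record("cotangent"), ct)
    return (ct,)


tap_with_cotangent.defvjp(_tap_fwd, _tap_bwd)


def f(x):
    return jnp.sin(tap_with_cotangent(2.0 * x))


g = jax.grad(f)(1.0)
jax.effects_barrier()  # wait until all pending callbacks have run

print(g)  # 2 * cos(2.0), about -0.832
print(tapped)
```

Putting the cotangent tap in the `bwd` rule, instead of applying a callback to a tangent, avoids the need to transpose the callback itself. This is the kind of special-case handling that host_callback used to perform internally.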