See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon
to be turned into a JEP so that's visible to all).
For now at least, the api is in attrs.py, and the implementation forks a bit of
the logic in ad.py rather than extending it in-place.
The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`,
namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those
represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch
with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`,
and indeed the `JVPTrace` will never even see it since it doesn't take a
data/term-level argument.
That handles the perturbations to attrs that happen inside the function being
differentiated. To handle the input perturbations, we just stuff `JVPTracer`s
in those attributes when we create tracers for ordinary inputs.
The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough
because those rules don't take the `JVPTrace` as an argument (and thus had no
way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p`
and `setattr_p` to use custom bind rules and call into a
`trace.process_getattr` and `trace.process_setattr` instead. The alternative
would be generalizing our JVP rule signatures, or inserting some alternative
rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and
more conventional not to touch that path and insetad just make
`process_getattr`/`process_setattr`.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>