7 Commits

Author SHA1 Message Date
Matthew Johnson
c44dda84d0 [attrs] fix tracer lifetime bug, fixes #20082 2024-03-05 12:08:44 -08:00
Matthew Johnson
3736b322b7 [xmap-removal] remove reduce_axes from grad / vjp / backward_pass
The reduce_axes machinery was planned to be used for xmap. It's not needed for
e.g. shard_map, see https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html.
2024-02-25 15:50:54 -08:00
Matthew Johnson
b0b88d87d3 [attrs] add linearize and vjp support 2024-02-23 16:43:49 -08:00
Matthew Johnson
67572d3094 [attrs] simplify input side of jvp internals 2024-02-22 14:06:03 -08:00
Matthew Johnson
a45cc437f4 [attrs] allow passing a jax-attrs object to jit functions
currently we don't get any interesting cache hits; only on object identity
match
2024-02-13 16:53:46 -08:00
Matthew Johnson
cdb466c517 [attrs] add a jvp function with attrs support
See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon
to be turned into a JEP so that's visible to all).

For now at least, the api is in attrs.py, and the implementation forks a bit of
the logic in ad.py rather than extending it in-place.

The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`,
namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those
represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch
with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`,
and indeed the `JVPTrace` will never even see it since it doesn't take a
data/term-level argument.

That handles the perturbations to attrs that happen inside the function being
differentiated. To handle the input perturbations, we just stuff `JVPTracer`s
in those attributes when we create tracers for ordinary inputs.

The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough
because those rules don't take the `JVPTrace` as an argument (and thus had no
way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p`
and `setattr_p` to use custom bind rules and call into a
`trace.process_getattr` and `trace.process_setattr` instead. The alternative
would be generalizing our JVP rule signatures, or inserting some alternative
rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and
more conventional not to touch that path and insetad just make
`process_getattr`/`process_setattr`.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-13 14:19:50 -08:00
Matthew Johnson
4a8babb101 integrate attrs in jax.jit
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00