133 lines
4.8 KiB
Python
Raw Normal View History

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from jax._src import core
from jax._src import api_util
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
from jax._src import linear_util as lu
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import unzip2
JaxVal = Any
register = api_util.register_class_with_attrs
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
class GetAttrPrimitive(core.Primitive):
def bind_with_trace(self, trace, args, params):
() = args
return trace.process_getattr(**params)
getattr_p = GetAttrPrimitive('getattr')
class SetAttrPrimitive(core.Primitive):
def bind_with_trace(self, trace, args, params):
val, = args
return trace.process_setattr(trace.full_raise(val), **params)
setattr_p = SetAttrPrimitive('setattr')
def jax_getattr(obj: Any, attr: str):
return getattr_p.bind(obj=obj, attr=attr)
def jax_setattr(obj: Any, attr: str, val: JaxVal):
setattr_p.bind(val, obj=obj, attr=attr)
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
def _getattr_impl(_, *, obj, attr):
return getattr(obj, attr)
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
core.EvalTrace.process_getattr = _getattr_impl
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
def _setattr_impl(_, val, *, obj, attr):
setattr(obj, attr, val)
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
core.EvalTrace.process_setattr = _setattr_impl
def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
frame = trace.main.jaxpr_stack[-1] # type: ignore
if (obj, attr) not in frame.attrs_tracked:
init_val = getattr(obj, attr)
aval = core.raise_to_shaped(core.get_aval(init_val))
tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
setattr(obj, attr, tracer)
frame.attrs_tracked.append((obj, attr))
frame.attrs_inits.append(init_val)
frame.attrs_vars.append(var)
pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked
def _getattr_staging(trace, *, obj, attr):
trace._ensure_tracked(obj, attr)
return getattr(obj, attr)
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
pe.DynamicJaxprTrace.process_getattr = _getattr_staging
def _setattr_staging(trace, tracer, *, obj, attr):
trace._ensure_tracked(obj, attr)
setattr(obj, attr, tracer)
[attrs] add a jvp function with attrs support See the "autodiff of stateful functions" section of go/jax-oop-proposal (soon to be turned into a JEP so that's visible to all). For now at least, the api is in attrs.py, and the implementation forks a bit of the logic in ad.py rather than extending it in-place. The basic strategy is analogous to what we do with `trace_to_jaxpr_dynamic`, namely we just accumulate an `attrs_tracked` on the `JVPTrace`'s main. Those represent the `(object, attrname) : tuple[Any, str]` pairs that we ever touch with `setattr_p` and a `JVPTracer`. We need not do anything with `getattr_p`, and indeed the `JVPTrace` will never even see it since it doesn't take a data/term-level argument. That handles the perturbations to attrs that happen inside the function being differentiated. To handle the input perturbations, we just stuff `JVPTracer`s in those attributes when we create tracers for ordinary inputs. The JVP rule signature (for entries in ad.primitive_jvps) wasn't general enough because those rules don't take the `JVPTrace` as an argument (and thus had no way to get at the `MainTrace` or the `attrs_tracked`. So I switched `getattr_p` and `setattr_p` to use custom bind rules and call into a `trace.process_getattr` and `trace.process_setattr` instead. The alternative would be generalizing our JVP rule signatures, or inserting some alternative rule path in the standard `JVPTrace.process_primitive`. It seemed simpler and more conventional not to touch that path and insetad just make `process_getattr`/`process_setattr`. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-02-02 23:31:37 -08:00
pe.DynamicJaxprTrace.process_setattr = _setattr_staging
def jvp(f, primals, tangents, tangent_attrs_in):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree_ = tree_flatten(tangents)
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
primals_flat, tangents_flat, tangent_attrs_in)
out_primals = tree_unflatten(out_tree(), out_primals_flat)
out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
return out_primals, out_tangents, tangent_attrs_out
def _jvp(fun: lu.WrappedFun):
return jvpfun2(jvp_subtrace2(fun))
@lu.transformation
def jvpfun2(primals, tangents, tangent_attrs_in):
with core.new_main(ad.JVPTrace) as main:
out_primals, out_tangents, tangent_attrs_out = \
yield (main, primals, tangents, tangent_attrs_in), {}
del main
yield out_primals, out_tangents, tangent_attrs_out
@lu.transformation
def jvp_subtrace2(main, primals, tangents, tangent_attrs_in):
main.attrs_tracked = [] # attrs written to
trace = main.with_cur_sublevel()
for obj, name, tangent in tangent_attrs_in:
primal = jax_getattr(obj, name)
tracer = ad.JVPTracer(trace, primal, tangent)
jax_setattr(obj, name, tracer)
in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
for x, t in zip(primals, tangents)]
ans = yield in_tracers, {}
out_tracers = map(trace.full_raise, ans)
out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
tangent_attrs_out = []
for (obj, name) in main.attrs_tracked:
tracer = trace.full_raise(jax_getattr(obj, name))
jax_setattr(obj, name, tracer.primal)
if type(tracer.tangent) is not ad.Zero:
tangent_attrs_out.append((obj, name, tracer.tangent))
del main.attrs_tracked
yield out_primals, out_tangents, tangent_attrs_out
def _setattr_jvp(trace, tracer, *, obj, attr):
if (obj, attr) not in trace.main.attrs_tracked:
trace.main.attrs_tracked.append((obj, attr))
setattr(obj, attr, tracer)
ad.JVPTrace.process_setattr = _setattr_jvp