add defvjp functions for custom VJPs

c.f. #116, which won't be closed until we add documentation
This commit is contained in:
Matthew Johnson 2019-04-23 17:47:28 -07:00
parent d5efe88d0c
commit 85755820bb
5 changed files with 137 additions and 23 deletions

View File

@ -178,3 +178,12 @@ array_types = [onp.ndarray, onp.float64, onp.float32, onp.float16,
for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
ad_util.jaxval_zeros_likers[t] = zeros_like_array
def raise_to_shaped(aval):
  """Rebuild `aval` as a `ShapedArray`, recursing through abstract tuples.

  An `AbstractTuple` (matched by exact type) is rebuilt with each element
  converted by this same function. Any `ShapedArray` instance is replaced by
  a new `ShapedArray` constructed from only its `shape` and `dtype`.

  Raises:
    TypeError: if `aval` is neither of the above; the argument of the
      exception is the offending type.
  """
  if type(aval) is core.AbstractTuple:
    converted = map(raise_to_shaped, aval)
    return core.AbstractTuple(converted)
  if isinstance(aval, ShapedArray):
    shape, dtype = aval.shape, aval.dtype
    return ShapedArray(shape, dtype)
  raise TypeError(type(aval))

View File

@ -23,6 +23,7 @@ from .. import core as core
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, zero, Zero)
from ..abstract_arrays import raise_to_shaped
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
@ -307,7 +308,6 @@ def deflinear(primitive, transpose_rule):
primitive_jvps[primitive] = partial(linear_jvp, primitive)
primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
def linear_jvp(primitive, primals, tangents, **params):
val_out = primitive.bind(*primals, **params)
if all(tangent is zero for tangent in tangents):
@ -316,7 +316,6 @@ def linear_jvp(primitive, primals, tangents, **params):
tangents = map(instantiate_zeros, primals, tangents)
return val_out, primitive.bind(*tangents, **params)
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Transpose a linear primitive by delegating to `transpose_rule`.

  A `zero` cotangent is returned unchanged without calling the rule. The
  positional `args` are accepted but not forwarded; only `kwargs` are passed
  on to `transpose_rule`.
  """
  if cotangent is zero:
    return zero
  return transpose_rule(cotangent, **kwargs)
@ -332,19 +331,16 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params):
if rule is not None and t is not zero]
return val_out, reduce(add_tangents, tangents_out, zero)
def defjvp2(primitive, *jvprules):
  """Register per-argument JVP rules for `primitive` via `standard_jvp2`.

  The rules are bound, together with the primitive, into `standard_jvp2`
  and stored in `primitive_jvps`.
  """
  assert isinstance(primitive, Primitive)
  jvp_rule = partial(standard_jvp2, jvprules, primitive)
  primitive_jvps[primitive] = jvp_rule
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Evaluate `primitive` and sum the tangent contributions of its rules.

  Each rule is called as `rule(t, val_out, *primals, **params)`, i.e. it
  receives the primal output in addition to the primal inputs. Arguments
  whose rule is `None` or whose tangent is `zero` contribute nothing.

  Returns:
    A pair `(val_out, tangent_out)`; `tangent_out` is `zero` when no rule
    contributed.
  """
  val_out = primitive.bind(*primals, **params)
  tangent_out = zero
  for rule, t in zip(jvprules, tangents):
    if rule is None or t is zero:
      continue
    # Accumulate lazily, one contribution at a time.
    tangent_out = add_tangents(tangent_out,
                               rule(t, val_out, *primals, **params))
  return val_out, tangent_out
def add_tangents(x, y):
if x is zero:
return y
@ -354,6 +350,57 @@ def add_tangents(x, y):
return add_jaxvals(x, y)
def defvjp_all(prim, custom_vjp):
  """Define differentiation of `prim` from a single user-supplied VJP function.

  Two helper primitives, named `<prim.name>_jvp` and `<prim.name>_lin`, are
  created and wired into the JVP, partial-evaluation and transpose rule
  tables.

  Args:
    prim: the primitive to differentiate; only its `name` attribute is read
      here, and it is used as the key in `primitive_jvps`.
    custom_vjp: callable invoked as `custom_vjp(*primals)` that returns a pair
      `(primal_out, vjp_py)`. `vjp_py` is called with a cotangent and its
      result is passed to `pack` (presumably a sequence holding one cotangent
      per primal input -- confirm against callers).
  """
  name = prim.name

  # JVP rule for `prim`: defer all work to the `<name>_jvp` primitive.
  def fun_jvp(xs, ts):
    ts = map(instantiate_zeros, xs, ts)  # TODO(mattjj): avoid instantiation?
    primal_out, tangent_out = fun_jvp_p.bind(pack(xs), pack(ts))
    return primal_out, tangent_out
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  # Partial-eval rule for `<name>_jvp`: run the user's function on the primal
  # tracers, trace the resulting Python VJP to a jaxpr, and bind `<name>_lin`.
  def fun_jvp_partial_eval(trace, *tracers):
    primals_tracer, tangents_tracer = tracers
    primal_out, vjp_py = custom_vjp(*primals_tracer)

    # The cotangent fed to the VJP has the abstract value of the primal
    # output; it is marked unknown (`core.unit` const) for tracing.
    in_aval = raise_to_shaped(get_aval(primal_out))
    ct_pval = pe.PartialVal((in_aval, core.unit))
    vjp_jaxpr, out_pval, residuals = pe.trace_unwrapped_to_jaxpr(
        lambda ct: pack(vjp_py(ct)), (ct_pval,))
    out_pv, out_const = out_pval
    tangent_out = fun_lin_p.bind(out_const, pack(residuals), tangents_tracer,
                                 in_aval=in_aval, out_pv=out_pv, vjp_jaxpr=vjp_jaxpr)

    return pack((primal_out, tangent_out))
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  # Abstractly, the output of `<name>_lin` has the aval stored in `in_aval`.
  fun_lin_p.def_abstract_eval(lambda c, r, ts, in_aval, out_pv, vjp_jaxpr: in_aval)
  # Transpose rule for `<name>_lin`: evaluate the traced VJP jaxpr on the
  # incoming cotangent and merge the result with the constants in `out_const`.
  def fun_lin_transpose(ct, out_const, residuals, ts, in_aval, out_pv, vjp_jaxpr):
    # NOTE(review): `ts` being None presumably marks it as the argument being
    # transposed -- confirm against the transpose-rule calling convention.
    assert ts is None and out_const is not None and residuals is not None
    ans = core.eval_jaxpr(vjp_jaxpr, residuals, (), ct)
    out = pe.merge_pvals(ans, pe.PartialVal((out_pv, out_const)))
    # One entry per positional operand (out_const, residuals, ts); only `ts`
    # receives a cotangent.
    return [None, None, out]
  primitive_transposes[fun_lin_p] = fun_lin_transpose
def defvjp(prim, *vjps):
  """Define a custom VJP for `prim` from one rule per positional argument.

  Each entry of `vjps` is either a callable invoked as `vjp(ct, *primals)`, or
  a falsy placeholder such as `None`, in which case that argument's cotangent
  is `zeros_like_jaxval` of the corresponding primal. The rules are assembled
  into a single function and registered through `defvjp_all`.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      cotangents = []
      for x, vjp in zip(primals, vjps):
        if vjp:
          cotangents.append(vjp(ct, *primals))
        else:
          cotangents.append(zeros_like_jaxval(x))
      return cotangents

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
def defvjp2(prim, *vjps):
  """Like `defvjp`, but each rule also receives the primal output.

  Each entry of `vjps` is either a callable invoked as
  `vjp(ct, ans, *primals)`, where `ans` is the result of binding `prim`, or a
  falsy placeholder such as `None`, in which case that argument's cotangent is
  `zeros_like_jaxval` of the corresponding primal. Registration is delegated
  to `defvjp_all`.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      cotangents = []
      for x, vjp in zip(primals, vjps):
        if vjp:
          cotangents.append(vjp(ct, ans, *primals))
        else:
          cotangents.append(zeros_like_jaxval(x))
      return cotangents

    return ans, vjpfun

  defvjp_all(prim, vjpmaker)
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
assert isinstance(prim, Primitive)
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
@ -362,7 +409,6 @@ def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
assert (x is None) ^ (y is None)
if x is None:
@ -377,7 +423,6 @@ def defjvp_zero(primitive):
assert isinstance(primitive, Primitive)
primitive_jvps[primitive] = partial(zero_jvp, primitive)
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule that evaluates `primitive` and pairs it with a `zero` tangent.

  The incoming `tangents` are ignored.
  """
  primal_out = primitive.bind(*primals, **params)
  return primal_out, zero

View File

@ -26,7 +26,7 @@ from six.moves import reduce
from .. import core
from ..core import Trace, Tracer, new_master, pack, AbstractTuple, JaxTuple
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types, raise_to_shaped
from ..ad_util import add_jaxvals_p, zeros_like_p, zeros_like_jaxval
from ..linear_util import transformation, transformation_with_aux, wrap_init
from ..tree_util import register_pytree_node
@ -160,14 +160,6 @@ def shaped_aval(x):
except KeyError:
raise TypeError("{} is not a valid type for batching".format(type(x)))
def raise_to_shaped(aval):
  """Convert an abstract value to `ShapedArray` form, recursing into tuples.

  Values whose type is exactly `AbstractTuple` are rebuilt element by element;
  `ShapedArray` instances are rebuilt from their `shape` and `dtype` alone.

  Raises:
    TypeError: for any other kind of abstract value, with the offending type
      as the exception argument.
  """
  if type(aval) is AbstractTuple:
    elements = map(raise_to_shaped, aval)
    return AbstractTuple(elements)
  if not isinstance(aval, ShapedArray):
    raise TypeError(type(aval))
  return ShapedArray(aval.shape, aval.dtype)
def remove_batch_dim_from_aval(bdim, aval):
t = type(aval)
if t is AbstractTuple:

View File

@ -36,7 +36,7 @@ from .. import linear_util as lu
from ..config import flags
from ..core import Primitive
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
array_types, make_shaped_array)
array_types, make_shaped_array, raise_to_shaped)
from ..api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
from ..interpreters import partial_eval as pe
@ -3951,8 +3951,5 @@ def subvals(lst, replace):
def _abstractify(x):
  # abstractify wrapper used internally for primitives like while_loop
  # Tracers are abstractified through their `aval`; any other value is passed
  # to `xla.abstractify` directly. Either way the result is wrapped in a
  # `PartialVal` whose constant part is `core.unit`.
  if isinstance(x, core.Tracer):
    return pe.PartialVal((xla.abstractify(x.aval), core.unit))
  else:
    return pe.PartialVal((xla.abstractify(x), core.unit))
  # used internally for initial-style higher-order primitives
  # NOTE(review): as written, the return below is unreachable because both
  # branches above already return. It looks like the intended replacement for
  # the if/else body -- confirm which version should remain.
  return pe.PartialVal((raise_to_shaped(core.get_aval(x)), core.unit))

View File

@ -26,7 +26,7 @@ import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
from jax import api
from jax.core import Primitive
from jax.interpreters.ad import defjvp
from jax.interpreters.ad import defjvp, defvjp, defvjp2, defvjp_all
from jax.interpreters.xla import DeviceArray
from jax.abstract_arrays import concretization_err_msg
@ -466,6 +466,77 @@ class APITest(jtu.JaxTestCase):
def test_complex_input_jacfwd_raises_error(self):
  """`jacfwd` applied at a complex input must raise `TypeError`."""
  with self.assertRaises(TypeError):
    jacfwd(lambda x: np.sin(x))(1 + 2j)
def test_defvjp_all(self):
  """`value_and_grad` uses the rule registered through `defvjp_all`."""
  foo_p = Primitive('foo')

  def custom_vjp(x):
    # Primal output paired with a pullback that returns a 1-tuple.
    return x**2, lambda g: (4 * g * np.sin(x),)

  defvjp_all(foo_p, custom_vjp)

  def foo(x):
    return 2. * foo_p.bind(x)

  val_ans, grad_ans = api.value_and_grad(foo)(3.)
  self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
  self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
def test_defvjp_all_const(self):
  """The pullback may ignore its cotangent and return a constant."""
  foo_p = Primitive('foo')

  def custom_vjp(x):
    return x**2, lambda g: (12.,)

  defvjp_all(foo_p, custom_vjp)

  def foo(x):
    return foo_p.bind(x)

  val_ans, grad_ans = api.value_and_grad(foo)(3.)
  self.assertAllClose(val_ans, 9., check_dtypes=False)
  self.assertAllClose(grad_ans, 12., check_dtypes=True)
def test_defvjp_all_higher_order_revmode(self):
  """Reverse-mode can be applied twice through a `defvjp_all` rule."""
  foo_p = Primitive('foo')

  def custom_vjp(x):
    return x**2, lambda g: (g * x ** 2,)

  defvjp_all(foo_p, custom_vjp)

  def foo(x):
    return 2. * foo_p.bind(x)

  second_derivative = api.grad(api.grad(foo))
  ans = second_derivative(3.)
  self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
def test_defvjp_all_multiple_arguments(self):
  """A two-argument primitive with a pullback returning two cotangents."""
  # Differentiating with respect to only the first argument also exercises
  # the path where a symbolic zero tangent is passed in for the second.
  foo_p = Primitive('foo')

  def custom_vjp(x, y):
    def pullback(g):
      return (g + x + y, g * x * 9.)
    return x**2 + y**3, pullback

  defvjp_all(foo_p, custom_vjp)

  def foo(x, y):
    return foo_p.bind(x, y)

  val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
  self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
  self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

  ans = api.grad(foo, (0, 1))(3., 4.)
  self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
def test_defvjp(self):
  """`defvjp` with a `None` rule yields a zero gradient for that argument."""
  @api.custom_transforms
  def foo(x, y):
    return np.sin(x * y)

  def vjp_y(g, x, y):
    return g * x * y

  # No rule for the first argument; a custom rule for the second.
  defvjp(foo.primitive, None, vjp_y)

  val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
  self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
  self.assertAllClose(grad_ans, 0., check_dtypes=False)

  ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
  self.assertAllClose(ans_0, 0., check_dtypes=False)
  self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
def test_defvjp2(self):
  """`defvjp2` rules additionally receive the primal output `ans`."""
  @api.custom_transforms
  def foo(x, y):
    return np.sin(x * y)

  def vjp_y(g, ans, x, y):
    return g * x * y + np.cos(ans)

  # No rule for the first argument; a custom rule for the second.
  defvjp2(foo.primitive, None, vjp_y)

  # Differentiate with respect to the second argument (argnums=1).
  val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
  self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
  self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
                      check_dtypes=False)
if __name__ == '__main__':
absltest.main()