mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add defvjp functions for custom VJPs
c.f. #116, which won't be closed until we add documentation
This commit is contained in:
parent
d5efe88d0c
commit
85755820bb
@ -178,3 +178,12 @@ array_types = [onp.ndarray, onp.float64, onp.float32, onp.float16,
|
||||
for t in array_types:
|
||||
core.pytype_aval_mappings[t] = ConcreteArray
|
||||
ad_util.jaxval_zeros_likers[t] = zeros_like_array
|
||||
|
||||
|
||||
def raise_to_shaped(aval):
|
||||
if type(aval) is core.AbstractTuple:
|
||||
return core.AbstractTuple(map(raise_to_shaped, aval))
|
||||
elif isinstance(aval, ShapedArray):
|
||||
return ShapedArray(aval.shape, aval.dtype)
|
||||
else:
|
||||
raise TypeError(type(aval))
|
||||
|
@ -23,6 +23,7 @@ from .. import core as core
|
||||
from ..core import JaxTuple, Trace, Tracer, new_master, get_aval, pack, call_p, Primitive
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, zero, Zero)
|
||||
from ..abstract_arrays import raise_to_shaped
|
||||
from ..util import unzip2, unzip3, safe_map, safe_zip, partial
|
||||
from ..tree_util import process_pytree, build_tree, register_pytree_node, tree_map
|
||||
from ..linear_util import thunk, staged, transformation, transformation_with_aux, wrap_init
|
||||
@ -307,7 +308,6 @@ def deflinear(primitive, transpose_rule):
|
||||
primitive_jvps[primitive] = partial(linear_jvp, primitive)
|
||||
primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
|
||||
|
||||
|
||||
def linear_jvp(primitive, primals, tangents, **params):
|
||||
val_out = primitive.bind(*primals, **params)
|
||||
if all(tangent is zero for tangent in tangents):
|
||||
@ -316,7 +316,6 @@ def linear_jvp(primitive, primals, tangents, **params):
|
||||
tangents = map(instantiate_zeros, primals, tangents)
|
||||
return val_out, primitive.bind(*tangents, **params)
|
||||
|
||||
|
||||
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
|
||||
return zero if cotangent is zero else transpose_rule(cotangent, **kwargs)
|
||||
|
||||
@ -332,19 +331,16 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params):
|
||||
if rule is not None and t is not zero]
|
||||
return val_out, reduce(add_tangents, tangents_out, zero)
|
||||
|
||||
|
||||
def defjvp2(primitive, *jvprules):
|
||||
assert isinstance(primitive, Primitive)
|
||||
primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)
|
||||
|
||||
|
||||
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
|
||||
val_out = primitive.bind(*primals, **params)
|
||||
tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents)
|
||||
if rule is not None and t is not zero)
|
||||
return val_out, reduce(add_tangents, tangents_out, zero)
|
||||
|
||||
|
||||
def add_tangents(x, y):
|
||||
if x is zero:
|
||||
return y
|
||||
@ -354,6 +350,57 @@ def add_tangents(x, y):
|
||||
return add_jaxvals(x, y)
|
||||
|
||||
|
||||
def defvjp_all(prim, custom_vjp):
|
||||
name = prim.name
|
||||
|
||||
def fun_jvp(xs, ts):
|
||||
ts = map(instantiate_zeros, xs, ts) # TODO(mattjj): avoid instantiation?
|
||||
primal_out, tangent_out = fun_jvp_p.bind(pack(xs), pack(ts))
|
||||
return primal_out, tangent_out
|
||||
primitive_jvps[prim] = fun_jvp
|
||||
|
||||
fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
|
||||
def fun_jvp_partial_eval(trace, *tracers):
|
||||
primals_tracer, tangents_tracer = tracers
|
||||
primal_out, vjp_py = custom_vjp(*primals_tracer)
|
||||
|
||||
in_aval = raise_to_shaped(get_aval(primal_out))
|
||||
ct_pval = pe.PartialVal((in_aval, core.unit))
|
||||
vjp_jaxpr, out_pval, residuals = pe.trace_unwrapped_to_jaxpr(
|
||||
lambda ct: pack(vjp_py(ct)), (ct_pval,))
|
||||
out_pv, out_const = out_pval
|
||||
tangent_out = fun_lin_p.bind(out_const, pack(residuals), tangents_tracer,
|
||||
in_aval=in_aval, out_pv=out_pv, vjp_jaxpr=vjp_jaxpr)
|
||||
|
||||
return pack((primal_out, tangent_out))
|
||||
pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
|
||||
|
||||
fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
|
||||
fun_lin_p.def_abstract_eval(lambda c, r, ts, in_aval, out_pv, vjp_jaxpr: in_aval)
|
||||
def fun_lin_transpose(ct, out_const, residuals, ts, in_aval, out_pv, vjp_jaxpr):
|
||||
assert ts is None and out_const is not None and residuals is not None
|
||||
ans = core.eval_jaxpr(vjp_jaxpr, residuals, (), ct)
|
||||
out = pe.merge_pvals(ans, pe.PartialVal((out_pv, out_const)))
|
||||
return [None, None, out]
|
||||
primitive_transposes[fun_lin_p] = fun_lin_transpose
|
||||
|
||||
def defvjp(prim, *vjps):
|
||||
def vjpmaker(*primals):
|
||||
ans = prim.bind(*primals)
|
||||
vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
|
||||
for x, vjp in zip(primals, vjps)]
|
||||
return ans, vjpfun
|
||||
defvjp_all(prim, vjpmaker)
|
||||
|
||||
def defvjp2(prim, *vjps):
|
||||
def vjpmaker(*primals):
|
||||
ans = prim.bind(*primals)
|
||||
vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
|
||||
for x, vjp in zip(primals, vjps)]
|
||||
return ans, vjpfun
|
||||
defvjp_all(prim, vjpmaker)
|
||||
|
||||
|
||||
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
|
||||
assert isinstance(prim, Primitive)
|
||||
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
|
||||
@ -362,7 +409,6 @@ def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
|
||||
primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
|
||||
defbilinear = partial(defbilinear_broadcasting, lambda g, x: g)
|
||||
|
||||
|
||||
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
|
||||
assert (x is None) ^ (y is None)
|
||||
if x is None:
|
||||
@ -377,7 +423,6 @@ def defjvp_zero(primitive):
|
||||
assert isinstance(primitive, Primitive)
|
||||
primitive_jvps[primitive] = partial(zero_jvp, primitive)
|
||||
|
||||
|
||||
def zero_jvp(primitive, primals, tangents, **params):
|
||||
return primitive.bind(*primals, **params), zero
|
||||
|
||||
|
@ -26,7 +26,7 @@ from six.moves import reduce
|
||||
|
||||
from .. import core
|
||||
from ..core import Trace, Tracer, new_master, pack, AbstractTuple, JaxTuple
|
||||
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types
|
||||
from ..abstract_arrays import ShapedArray, make_shaped_array, array_types, raise_to_shaped
|
||||
from ..ad_util import add_jaxvals_p, zeros_like_p, zeros_like_jaxval
|
||||
from ..linear_util import transformation, transformation_with_aux, wrap_init
|
||||
from ..tree_util import register_pytree_node
|
||||
@ -160,14 +160,6 @@ def shaped_aval(x):
|
||||
except KeyError:
|
||||
raise TypeError("{} is not a valid type for batching".format(type(x)))
|
||||
|
||||
def raise_to_shaped(aval):
|
||||
if type(aval) is AbstractTuple:
|
||||
return AbstractTuple(map(raise_to_shaped, aval))
|
||||
elif isinstance(aval, ShapedArray):
|
||||
return ShapedArray(aval.shape, aval.dtype)
|
||||
else:
|
||||
raise TypeError(type(aval))
|
||||
|
||||
def remove_batch_dim_from_aval(bdim, aval):
|
||||
t = type(aval)
|
||||
if t is AbstractTuple:
|
||||
|
@ -36,7 +36,7 @@ from .. import linear_util as lu
|
||||
from ..config import flags
|
||||
from ..core import Primitive
|
||||
from ..abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
||||
array_types, make_shaped_array)
|
||||
array_types, make_shaped_array, raise_to_shaped)
|
||||
from ..api_util import (pytree_fun_to_jaxtupletree_fun, pytree_to_jaxtupletree,
|
||||
pytree_fun_to_flatjaxtuple_fun, pytree_to_flatjaxtuple)
|
||||
from ..interpreters import partial_eval as pe
|
||||
@ -3951,8 +3951,5 @@ def subvals(lst, replace):
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
# abstractify wrapper used internally for primitives like while_loop
|
||||
if isinstance(x, core.Tracer):
|
||||
return pe.PartialVal((xla.abstractify(x.aval), core.unit))
|
||||
else:
|
||||
return pe.PartialVal((xla.abstractify(x), core.unit))
|
||||
# used internally for initial-style higher-order primitives
|
||||
return pe.PartialVal((raise_to_shaped(core.get_aval(x)), core.unit))
|
||||
|
@ -26,7 +26,7 @@ import jax.numpy as np
|
||||
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
|
||||
from jax import api
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters.ad import defjvp
|
||||
from jax.interpreters.ad import defjvp, defvjp, defvjp2, defvjp_all
|
||||
from jax.interpreters.xla import DeviceArray
|
||||
from jax.abstract_arrays import concretization_err_msg
|
||||
|
||||
@ -466,6 +466,77 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_complex_input_jacfwd_raises_error(self):
|
||||
self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))
|
||||
|
||||
def test_defvjp_all(self):
|
||||
foo_p = Primitive('foo')
|
||||
def foo(x): return 2. * foo_p.bind(x)
|
||||
|
||||
defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
|
||||
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
||||
self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
|
||||
|
||||
def test_defvjp_all_const(self):
|
||||
foo_p = Primitive('foo')
|
||||
def foo(x): return foo_p.bind(x)
|
||||
|
||||
defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
|
||||
val_ans, grad_ans = api.value_and_grad(foo)(3.)
|
||||
self.assertAllClose(val_ans, 9., check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 12., check_dtypes=True)
|
||||
|
||||
def test_defvjp_all_higher_order_revmode(self):
|
||||
foo_p = Primitive('foo')
|
||||
def foo(x): return 2. * foo_p.bind(x)
|
||||
|
||||
defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
|
||||
ans = api.grad(api.grad(foo))(3.)
|
||||
self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
|
||||
|
||||
def test_defvjp_all_multiple_arguments(self):
|
||||
# also tests passing in symbolic zero tangents b/c we differentiate wrt only
|
||||
# the first argument in one case
|
||||
|
||||
foo_p = Primitive('foo')
|
||||
def foo(x, y): return foo_p.bind(x, y)
|
||||
|
||||
def vjpfun(x, y):
|
||||
out = x**2 + y**3
|
||||
vjp = lambda g: (g + x + y, g * x * 9.)
|
||||
return out, vjp
|
||||
|
||||
defvjp_all(foo_p, vjpfun)
|
||||
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
|
||||
self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)
|
||||
|
||||
ans = api.grad(foo, (0, 1))(3., 4.)
|
||||
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
|
||||
|
||||
def test_defvjp(self):
|
||||
@api.custom_transforms
|
||||
def foo(x, y):
|
||||
return np.sin(x * y)
|
||||
|
||||
defvjp(foo.primitive, None, lambda g, x, y: g * x * y)
|
||||
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
|
||||
self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 0., check_dtypes=False)
|
||||
|
||||
ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
|
||||
self.assertAllClose(ans_0, 0., check_dtypes=False)
|
||||
self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
|
||||
|
||||
def test_defvjp2(self):
|
||||
@api.custom_transforms
|
||||
def foo(x, y):
|
||||
return np.sin(x * y)
|
||||
|
||||
defvjp2(foo.primitive, None, lambda g, ans, x, y: g * x * y + np.cos(ans))
|
||||
val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
|
||||
self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
|
||||
self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
|
||||
check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user