Add back support for custom_transforms (#2484)

* also add the tests
* mark the old APIs as deprecated
George Necula 2020-03-22 19:50:06 +01:00 committed by GitHub
parent 069cb3e2fb
commit f658eb5bf5
3 changed files with 836 additions and 0 deletions

@@ -1460,6 +1460,540 @@ def _valid_jaxtype(arg):
return True
class CustomTransformsFunction(object):
def __init__(self, fun, prim):
self.fun = fun
self.prim = prim
wraps(fun)(self)
def __repr__(self):
return '<jax.custom_transforms function {fun}>'.format(fun=self.__name__)
def __call__(self, *args):
# TODO(mattjj): instead of tracing to a jaxpr, use process_call
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_pvals = [pe.PartialVal((raise_to_shaped(core.get_aval(x)), core.unit))
for x in args_flat]
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
num_consts=len(consts))
return tree_unflatten(out_tree(), outs)
def custom_transforms(fun):
"""Wraps a function so that its transformation behavior can be controlled.
A primary use case of ``custom_transforms`` is defining custom VJP rules (aka
custom gradients) for a Python function, while still supporting other
transformations like ``jax.jit`` and ``jax.vmap``. Custom differentiation
rules can be supplied using the ``jax.defjvp`` and ``jax.defvjp`` functions.
The ``custom_transforms`` decorator wraps ``fun`` so that its transformation
behavior can be overridden, but not all transformation rules need to be
specified manually. The default behavior is retained for any non-overridden
rules.
The function ``fun`` must satisfy the same constraints required for jit
compilation. In particular, the shapes of arrays in the computation of ``fun``
may depend on the shapes of ``fun``'s arguments, but not on their values.
Value-dependent Python control flow is also not yet supported.
Args:
fun: a Python callable. Must be functionally pure. Its arguments and return
value should be arrays, scalars, or (nested) standard Python containers
(tuple/list/dict) thereof.
Returns:
A Python callable with the same input/output and transformation behavior as
``fun``, but for which custom transformation rules can be supplied, e.g.
using ``jax.defvjp``.
For example:
>>> @jax.custom_transforms
... def f(x):
... return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp(f, lambda g, ans, x: g * x)
>>> print(jax.grad(f)(3.))
3.0
"""
warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
name = getattr(fun, '__name__', '<unnamed custom_transforms primitive>')
fun_p = core.Primitive(name)
fun_p.multiple_results = True
def fun_impl(*args, **params):
consts, args = split_list(args, [params['num_consts']])
return core.eval_jaxpr(params['jaxpr'], consts, *args)
fun_p.def_impl(fun_impl)
def fun_jvp(primals, tangents, **params):
return ad.jvp(lu.wrap_init(fun_impl, params)).call_wrapped(primals, tangents)
ad.primitive_jvps[fun_p] = fun_jvp
def fun_batch(args, dims, **params):
return batching.batch_fun(lu.wrap_init(fun_impl, params), args, dims)
batching.primitive_batchers[fun_p] = fun_batch
def fun_abstract_eval(*avals, **params):
return pe.abstract_eval_fun(fun_impl, *avals, **params)
fun_p.def_abstract_eval(fun_abstract_eval)
def fun_translation(c, *xla_args, **params):
return xla.lower_fun(fun_impl)(c, *xla_args, **params)
xla.translations[fun_p] = fun_translation
return CustomTransformsFunction(fun, fun_p)
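
# Illustrative usage sketch (`_sq_sin` is a hypothetical example, not part of
# the API): until a rule is overridden, the wrapped function behaves as usual
# under jit and vmap, since the impl, batching, and translation rules
# registered above all fall back to evaluating the traced jaxpr.
import jax
import jax.numpy as np

@jax.custom_transforms
def _sq_sin(x):
  return np.sin(x ** 2)

print(jax.jit(_sq_sin)(3.))           # 0.4121185, via the default translation rule
print(jax.vmap(_sq_sin)(np.ones(3)))  # [0.841471 0.841471 0.841471], default batching rule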
def _check_custom_transforms_type(name, fun):
if type(fun) is not CustomTransformsFunction:
msg = ("{} requires a custom_transforms function as its first argument, "
"but got type {}.")
raise TypeError(msg.format(name, type(fun)))
def defjvp_all(fun, custom_jvp):
"""Define a custom JVP rule for a ``custom_transforms`` function.
If ``fun`` represents a function with signature ``a -> b``, then
``custom_jvp`` represents a function with signature ``(a, T a) -> (b, T b)``,
where we use ``T x`` to represent a tangent type for the type ``x``.
In more detail, ``custom_jvp`` must take two arguments, both tuples of length
equal to the number of positional arguments to ``fun``. The first argument to
``custom_jvp`` represents the input primal values, and the second represents
the input tangent values. ``custom_jvp`` must return a pair where the first
element represents the output primal value and the second element represents
the output tangent value.
Defining a custom JVP rule also affects the default VJP rule, which is derived
from the JVP rule automatically via transposition.
Args:
fun: a custom_transforms function.
custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
arguments specifying the input primal values and tangent values,
respectively. The tuple elements can be arrays, scalars, or (nested)
standard Python containers (tuple/list/dict) thereof. The output must be a
pair representing the primal output and tangent output, which can be
arrays, scalars, or (nested) standard Python containers. Must be
functionally pure.
Returns:
None. A side-effect is that ``fun`` is associated with the JVP rule
specified by ``custom_jvp``.
For example:
>>> @jax.custom_transforms
... def f(x):
... return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.0
"""
warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
_check_custom_transforms_type("defjvp_all", fun)
def custom_transforms_jvp(primals, tangents, **params):
num_consts, in_tree = params['num_consts'], params['in_tree']
_, args_flat = split_list(primals, [num_consts])
consts_dot, args_dot_flat = split_list(tangents, [num_consts])
if not all(t is ad_util.zero for t in consts_dot):
msg = ("Detected differentiation with respect to closed-over values with "
"custom JVP rule, which isn't supported.")
raise ValueError(msg)
args = tree_unflatten(in_tree, args_flat)
args_dot = tree_unflatten(in_tree, args_dot_flat)
out, out_dot = custom_jvp(args, args_dot)
out_flat, out_tree = tree_flatten(out)
out_dot_flat, out_tree2 = tree_flatten(out_dot)
if out_tree != out_tree2:
msg = ("Custom JVP rule returned different tree structures for primals "
"and tangents, but they must be equal: {} and {}.")
raise TypeError(msg.format(out_tree, out_tree2))
return out_flat, out_dot_flat
ad.primitive_jvps[fun.prim] = custom_transforms_jvp
def defjvp(fun, *jvprules):
"""Definine JVP rules for each argument separately.
This function is a convenience wrapper around ``jax.defjvp_all`` for
separately defining JVP rules for each of the function's arguments. This
convenience wrapper does not provide a mechanism for depending on anything
other than the function's arguments and primal output value, though
depending on intermediate results is possible using ``jax.defjvp_all``.
The signature of each component JVP rule is ``lambda g, ans, *primals: ...``
where ``g`` represents the tangent of the corresponding positional argument,
``ans`` represents the output primal, and ``*primals`` represents all the
primal positional arguments.
Defining a custom JVP rule also affects the default VJP rule, which is derived
from the JVP rule automatically via transposition.
Args:
fun: a custom_transforms function.
*jvprules: a sequence of functions or Nones specifying the JVP rule for each
corresponding positional argument. When an element is None, it indicates
that the Jacobian from the corresponding input to the output is zero.
Returns:
None. A side-effect is that ``fun`` is associated with the JVP rule
specified by ``*jvprules``.
For example:
>>> @jax.custom_transforms
... def f(x):
... return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.412119
"""
warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
_check_custom_transforms_type("defjvp", fun)
def custom_jvp(primals, tangents):
ans = fun(*primals)
tangents_out = [rule(t, ans, *primals) for rule, t in zip(jvprules, tangents)
if rule is not None and t is not ad_util.zero]
return ans, functools.reduce(ad.add_tangents, tangents_out, ad_util.zero)
defjvp_all(fun, custom_jvp)
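
# A short sketch (`_prod` is a hypothetical example): per-argument JVP rules,
# where None declares that the Jacobian with respect to that argument is zero.
import jax

@jax.custom_transforms
def _prod(x, y):
  return x * y

jax.defjvp(_prod, None, lambda g, ans, x, y: g * x)  # only y contributes
print(jax.jvp(_prod, (3., 4.), (1., 1.)))  # (12.0, 3.0)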
def defvjp_all(fun, custom_vjp):
"""Define a custom VJP rule for a ``custom_transforms`` function.
If ``fun`` represents a function with signature ``a -> b``, then
``custom_vjp`` represents a function with signature ``a -> (b, CT b -> CT a)``
where we use ``CT x`` to represent a cotangent type for the type ``x``. That
is, ``custom_vjp`` should take the same arguments as ``fun`` and return a pair
where the first element represents the primal value of ``fun`` applied to the
arguments, and the second element is a VJP function that maps from output
cotangents to input cotangents, returning a tuple with length equal to the
number of positional arguments supplied to ``fun``.
The VJP function returned as the second element of the output of
``custom_vjp`` can close over intermediate values computed when evaluating the
primal value of ``fun``. That is, use lexical closure to share work between
the forward pass and the backward pass of reverse-mode automatic
differentiation.
See also ``jax.custom_gradient``.
Args:
fun: a custom_transforms function.
custom_vjp: a Python callable specifying the VJP rule, taking the same
arguments as ``fun`` and returning a pair where the first element is the
value of ``fun`` applied to the arguments and the second element is a
Python callable representing the VJP map from output cotangents to input
cotangents. The returned VJP function must accept a value with the same
shape as the value of ``fun`` applied to the arguments and must return a
tuple with length equal to the number of positional arguments to ``fun``.
Arguments can be arrays, scalars, or (nested) standard Python containers
(tuple/list/dict) thereof. Must be functionally pure.
Returns:
None. A side-effect is that ``fun`` is associated with the VJP rule
specified by ``custom_vjp``.
For example:
>>> @jax.custom_transforms
... def f(x):
... return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
3.0
An example with a function on two arguments, so that the VJP function must
return a tuple of length two:
>>> @jax.custom_transforms
... def f(x, y):
... return x * y
...
>>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (y, x)))
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
"""
warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
_check_custom_transforms_type("defvjp_all", fun)
def custom_transforms_vjp(*consts_and_args, **params):
num_consts, in_tree = params['num_consts'], params['in_tree']
consts, args_flat = split_list(consts_and_args, [num_consts])
args = tree_unflatten(in_tree, args_flat)
out, vjp = custom_vjp(*args)
out_flat, out_tree = tree_flatten(out)
if out_tree != params['out_tree']:
msg = (
"First output of `custom_vjp`: {} doesn't match the structure of "
"the output of `fun`: {}\n"
"{}\n"
"vs\n"
"{}\n".format(custom_vjp, fun, out_tree, params['out_tree'])
)
raise TypeError(msg)
def vjp_flat(*cts_flat):
cts = tree_unflatten(out_tree, cts_flat)
args_cts_flat, in_tree2 = tree_flatten(vjp(cts))
if in_tree != in_tree2:
msg = (
"Output of the `vjp`: {} doesn't match the structure of args of "
"`fun`: {}\n"
"{}\n"
"vs\n"
"{}\n".format(vjp, fun, in_tree2, in_tree)
)
raise TypeError(msg)
return [core.unit] * num_consts + list(args_cts_flat)
return out_flat, vjp_flat
ad.defvjp_all(fun.prim, custom_transforms_vjp)
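
# Sketch of the work-sharing pattern described above (`_logistic` is a
# hypothetical example): the VJP closes over the forward-pass value `y` so
# the backward pass does not recompute it.
import jax
import jax.numpy as np

@jax.custom_transforms
def _logistic(x):
  return 1. / (1. + np.exp(-x))

def _logistic_vjp(x):
  y = 1. / (1. + np.exp(-x))  # computed once, reused in the VJP
  return y, lambda g: (g * y * (1. - y),)

jax.defvjp_all(_logistic, _logistic_vjp)
print(jax.grad(_logistic)(0.))  # 0.25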
def defvjp(fun, *vjprules):
"""Define VJP rules for each argument separately.
This function is a convenience wrapper around ``jax.defvjp_all`` for
separately defining VJP rules for each of the function's arguments. This
convenience wrapper does not provide a mechanism for depending on anything
other than the function's arguments and primal output value, though
depending on intermediate results is possible using ``jax.defvjp_all``.
The signature of each component VJP rule is ``lambda g, ans, *primals: ...``
where ``g`` represents the output cotangent, ``ans`` represents the output
primal, and ``*primals`` represents all the primal positional arguments.
Args:
fun: a custom_transforms function.
*vjprules: a sequence of functions or Nones specifying the VJP rule for each
corresponding positional argument. When an element is None, it indicates
that the Jacobian from the corresponding input to the output is zero.
Returns:
None. A side-effect is that ``fun`` is associated with the VJP rule
specified by ``*vjprules``.
For example:
>>> @jax.custom_transforms
... def f(x, y):
... return np.sin(x ** 2 + y)
...
>>> print(f(3., 4.))
0.42016703
>>> print(jax.grad(f)(3., 4.))
5.4446807
>>> print(jax.grad(f, 1)(3., 4.))
0.9074468
>>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
>>> print(jax.grad(f)(3., 4.))
0.0
>>> print(jax.grad(f, 1)(3., 4.))
8.420167
"""
warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
_check_custom_transforms_type("defvjp", fun)
def custom_vjp(*primals):
ans = fun(*primals)
# TODO(mattjj): avoid instantiating zeros?
def vjpfun(ct):
return tuple(vjp(ct, ans, *primals) if vjp else ad_util.zeros_like_jaxval(x)
for x, vjp in zip(primals, vjprules))
return ans, vjpfun
defvjp_all(fun, custom_vjp)
def custom_gradient(fun):
"""Convenience function for defining custom VJP rules (aka custom gradients).
While the canonical way to define custom VJP rules is via ``jax.defvjp_all``
and its convenience wrappers, the ``custom_gradient`` convenience wrapper
follows TensorFlow's ``tf.custom_gradient`` API. The difference here is that
``custom_gradient`` can be used as a decorator on one function that returns
both the primal value (representing the output of the mathematical function to
be differentiated) and the VJP (gradient) function.
See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
If the mathematical function to be differentiated has type signature
``a -> b``, then the Python callable ``fun`` should have signature
``a -> (b, CT b -> CT a)`` where we use ``CT x`` to denote a cotangent type
for ``x``. See the example below. That is, ``fun`` should return a pair where
the first element represents the value of the mathematical function to be
differentiated and the second element is a function that represents the custom
VJP rule.
The custom VJP function returned as the second element of the output of ``fun``
can close over intermediate values computed when evaluating the function to be
differentiated. That is, use lexical closure to share work between the forward
pass and the backward pass of reverse-mode automatic differentiation.
Args:
fun: a Python callable specifying both the mathematical function to be
differentiated and its reverse-mode differentiation rule. It should return
a pair consisting of an output value and a Python callable that represents
the custom gradient function.
Returns:
A Python callable with signature ``a -> b``, i.e. that returns the output
value specified by the first element of ``fun``'s output pair. A side effect
is that under-the-hood ``jax.defvjp_all`` is called to set up the returned
Python callable with the custom VJP rule specified by the second element
of ``fun``'s output pair.
For example:
>>> @jax.custom_gradient
... def f(x):
... return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0
An example with a function on two arguments, so that the VJP function must
return a tuple of length two:
>>> @jax.custom_gradient
... def f(x, y):
... return x * y, lambda g: (y, x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
"""
def primal_fun(*args, **kwargs):
ans, _ = fun(*args, **kwargs)
return ans
primal_fun = custom_transforms(primal_fun)
defvjp_all(primal_fun, fun)
return primal_fun
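# jarrett wraps an elementwise function so that its JVP is computed by pushing
# one all-ones basis vector per input leaf through `fun` with vmap, recovering
# the per-leaf (diagonal) Jacobians, which are then multiplied elementwise by
# the flattened input tangents and summed.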
def jarrett(fun):
new_fun = custom_transforms(fun)
def elementwise_jvp(primals, tangents):
pushfwd = partial(jvp, fun, primals)
y, jacs = vmap(pushfwd, out_axes=(None, 0))(_elementwise_std_basis(tangents))
flat_tangents, _ = tree_flatten(tangents)
out_tangent = sum([t * jac for t, jac in zip(flat_tangents, jacs)])
return y, out_tangent
defjvp_all(new_fun, elementwise_jvp)
return new_fun
def _elementwise_std_basis(pytree):
leaves, _ = tree_flatten(pytree)
arity = len(leaves)
dims = map(onp.size, leaves)
# TODO(mattjj): use symbolic constants
dtype = dtypes.result_type(*leaves)
if not dtypes.issubdtype(dtype, onp.floating):
msg = ("Jacobian only defined for functions with floating input and output "
"dtypes (i.e. dtypes that model real numbers), got {}.")
raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
basis_array = onp.stack([onp.concatenate(
[onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
for j in range(arity)]) for i in range(arity)])
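# e.g. for a pytree of two scalar leaves, basis_array is the 2x2 identity;
# in general row i is the 0/1 indicator of leaf i's block of elements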
return _unravel_array_into_pytree(pytree, 1, basis_array)
# This function mostly exists for making slides about JAX.
def _make_graphviz(fun):
"""Adapts `fun` to return a graphviz dot string of its program representation.
Args:
fun: The function whose `jaxpr` is to be rendered into graphviz dot. Its
positional arguments and return value should be arrays, scalars, or
standard Python containers (tuple/list/dict) thereof.
Returns:
A wrapped version of `fun`, set up to return a graphviz dot string.
See make_jaxpr for a related function.
"""
# TODO(mattjj): handle eqn.restructure
# TODO(mattjj): handle subjaxprs
def pv_like(x):
aval = xla.abstractify(x)
return pe.PartialVal((aval, core.unit))
id_names = ("id{}".format(i) for i in it.count())
def jaxpr_to_graphviz(jaxpr, consts):
fragment = []
fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
fragment.extend(map(constant_node, jaxpr.constvars, consts))
for eqn in jaxpr.eqns:
id_name = next(id_names)
fragment.append(function_node(id_name, eqn.primitive.name))
fragment.extend(edge(invar, id_name) for invar in eqn.invars)
fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
for ov in jaxpr.outvars:
fragment.append(outvar_node(ov, "out"))
return graph(''.join(fragment))
edge = '{} -> {} [color=gray30];\n'.format
function_node = '{} [label="{}", shape=box, color=lightskyblue, style=filled];\n'.format
invar_node = '{} [rank=2, label="{}", color=mediumspringgreen, style=filled];\n'.format
outvar_node = '{} [label="{}", fillcolor=indianred1, style="filled,dashed", color=black];\n'.format
constant_node = '{} [rank=2, label="{}", color=goldenrod1, style=filled];\n'.format
freevar_node = '{} [rank=2, label="{}", color=palegreen, style=filled];\n'.format
graph = 'digraph G {{{}}}'.format
@wraps(fun)
def graphviz_maker(*args, **kwargs):
wrapped = lu.wrap_init(fun, kwargs)
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
pvals = map(pv_like, jax_args)
jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
return jaxpr_to_graphviz(jaxpr, consts)
graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)
return graphviz_maker
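
# Hypothetical usage sketch:
#   print(_make_graphviz(lambda x: x * 2. + 1.)(3.))  # 'digraph G {...}'
# with one node per equation in the traced jaxpr.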
class ShapeDtypeStruct(object):
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):

@@ -435,6 +435,65 @@ def add_tangents(x, y):
return add_jaxvals(x, y)
def defvjp_all(prim, custom_vjp):
# see https://github.com/google/jax/pull/636
name = prim.name
def fun_jvp(xs, ts, **params):
ts = map(instantiate_zeros, xs, ts)
primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
if prim.multiple_results:
return primals, tangents
else:
primal, = primals
tangent, = tangents
return primal, tangent
primitive_jvps[prim] = fun_jvp
fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
fun_jvp_p.multiple_results = True
def fun_jvp_partial_eval(trace, *tracers, **params):
primals, tangents = split_list(tracers, [len(tracers) // 2])
primals_out, vjp_py = custom_vjp(*primals, **params)
if not prim.multiple_results:
primals_out = [primals_out]
out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
num_res=len(res), out_avals=out_avals)
return primals_out + tangents_out
pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval
fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
fun_lin_p.multiple_results = True
fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
def fun_lin_transpose(cts, *args, **kwargs):
num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
res, _ = split_list(args, [num_res])
cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
return [None] * num_res + outs
primitive_transposes[fun_lin_p] = fun_lin_transpose
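# In outline: fun_jvp_p defers differentiation to partial evaluation; its
# partial-eval rule calls custom_vjp on the primal values, traces the returned
# Python VJP function to a jaxpr, and stages out fun_lin_p, a linear primitive
# whose transpose rule evaluates that jaxpr on the output cotangents.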
def defvjp(prim, *vjps):
def vjpmaker(*primals):
ans = prim.bind(*primals)
vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
for x, vjp in zip(primals, vjps)]
return ans, vjpfun
defvjp_all(prim, vjpmaker)
def defvjp2(prim, *vjps):
def vjpmaker(*primals):
ans = prim.bind(*primals)
vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
for x, vjp in zip(primals, vjps)]
return ans, vjpfun
defvjp_all(prim, vjpmaker)
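# defvjp2 differs from defvjp above only in that each per-argument rule also
# receives the primal output `ans` as its second argument.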
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
assert isinstance(prim, Primitive)
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)

@@ -596,6 +596,200 @@ class APITest(jtu.JaxTestCase):
def test_complex_input_jacfwd_raises_error(self):
self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))
def test_defvjp_all(self):
foo_p = Primitive('foo')
def foo(x): return 2. * foo_p.bind(x)
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
val_ans, grad_ans = api.value_and_grad(foo)(3.)
self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
def test_defvjp_all_const(self):
foo_p = Primitive('foo')
def foo(x): return foo_p.bind(x)
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
val_ans, grad_ans = api.value_and_grad(foo)(3.)
self.assertAllClose(val_ans, 9., check_dtypes=False)
self.assertAllClose(grad_ans, 12., check_dtypes=True)
def test_defvjp_all_higher_order_revmode(self):
foo_p = Primitive('foo')
def foo(x): return 2. * foo_p.bind(x)
ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
ans = api.grad(api.grad(foo))(3.)
self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
def test_defvjp_all_multiple_arguments(self):
# also tests passing in symbolic zero tangents b/c we differentiate wrt only
# the first argument in one case
foo_p = Primitive('foo')
def foo(x, y): return foo_p.bind(x, y)
def vjpfun(x, y):
out = x**2 + y**3
vjp = lambda g: (g + x + y, g * x * 9.)
return out, vjp
ad.defvjp_all(foo_p, vjpfun)
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)
ans = api.grad(foo, (0, 1))(3., 4.)
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
def test_defvjp_all_custom_transforms(self):
@api.custom_transforms
def foo(x):
return np.sin(x)
api.defvjp_all(foo, lambda x: (np.sin(x), lambda g: (g * x,)))
val_ans, grad_ans = api.value_and_grad(foo)(3.)
self.assertAllClose(val_ans, onp.sin(3.), check_dtypes=False)
self.assertAllClose(grad_ans, 3., check_dtypes=False)
# TODO(mattjj): add defvjp_all test with pytree arguments
def test_defvjp(self):
@api.custom_transforms
def foo(x, y):
return np.sin(x * y)
api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
self.assertAllClose(grad_ans, 0., check_dtypes=False)
ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
self.assertAllClose(ans_0, 0., check_dtypes=False)
self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
def test_defvjp_higher_order(self):
@api.custom_transforms
def foo(x):
return np.sin(2. * x)
api.defvjp(foo, lambda g, _, x: g * np.cos(x))
ans = api.grad(api.grad(foo))(2.)
expected = api.grad(api.grad(np.sin))(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_defvjp_use_ans(self):
@api.custom_transforms
def foo(x, y):
return np.sin(x * y)
api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + np.cos(ans))
val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
check_dtypes=False)
# TODO
# def test_defjvp_closure_error(self):
# def foo(x):
# @api.custom_transforms
# def bar(y):
# return x * y
# api.defjvp(bar, lambda y_dot, ans, y: x * y)
# return bar(x)
# jtu.check_raises(
# lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
# "Detected differentiation with respect to closed-over values with "
# "custom JVP rule, which isn't supported.")
# TODO
# def test_defvjp_closure_error(self):
# def foo(x):
# @api.custom_transforms
# def bar(y):
# return x * y
# api.defvjp(bar, lambda g, ans, y: x * y)
# return bar(x)
# jtu.check_raises(
# lambda: grad(foo)(1.,), ValueError,
# "Detected differentiation w.r.t. variables from outside "
# "the scope of <jax.custom_transforms function bar>, but defvjp and "
# "defvjp_all only support differentiation w.r.t. positional arguments.")
def test_custom_transforms_eval_with_pytrees(self):
@api.custom_transforms
def f(x):
a, b = x[0], x[1]
return {'hi': 2 * a, 'bye': 2 * b}
ans = f((1, 2))
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
def test_custom_transforms_jit_with_pytrees(self):
@api.custom_transforms
def f(x):
a, b = x[0], x[1]
return {'hi': 2 * a, 'bye': 2 * b}
ans = jit(f)((1, 2))
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
def test_custom_transforms_jit_with_pytrees_consts(self):
# The purpose of this test is to exercise the custom_transforms default
# translation rule in how it deals with constants that are too large to be
# treated as literals (at the time of writing).
z = onp.arange(10.)
@api.custom_transforms
def f(x):
a, b = x[0], x[1]
return {'hi': 2 * a, 'bye': z * b}
ans = jit(f)((1, 2))
self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)
def test_custom_transforms_jvp_with_pytrees(self):
@api.custom_transforms
def f(x):
a, b = x[0], x[1]
return {'hi': 2 * a, 'bye': 2 * b}
ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})
def test_custom_transforms_vmap_with_pytrees(self):
raise unittest.SkipTest("Test deprecated custom_transforms")
@api.custom_transforms
def f(x):
a, b = x[0], x[1]
return {'hi': 2 * a, 'bye': 2 * b}
ans = api.vmap(f)((onp.arange(3), onp.ones((3, 2))))
expected = {'hi': 2 * onp.arange(3), 'bye': 2 * onp.ones((3, 2))}
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_transforms_jvp_with_closure(self):
def f(x):
@api.custom_transforms
def g(y):
return x * y
return g(x)
ans = api.grad(f)(1.)
expected = 2.
self.assertAllClose(ans, expected, check_dtypes=False)
def test_custom_gradient(self):
@api.custom_gradient
def f(x):
return x ** 2, lambda g: (g * x,)
self.assertAllClose(f(3.), 9., check_dtypes=False)
self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
def test_legacy_devicearray_repr(self):
dx = device_put(3.)
str(dx.item()) # doesn't crash
@@ -956,6 +1150,55 @@
check_warning(lambda: np.tri(2, dtype="float64"),
lambda: np.tri(2, dtype="float32"))
def test_custom_vjp_zeros(self):
@api.custom_transforms
def f(x, y):
return 2 * x, 3 * y
def f_vjp(x, y):
return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])
api.defvjp_all(f, f_vjp)
api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash
def test_custom_transforms_vjp_nones(self):
# issue raised by jsnoek@ and jumper@
@jax.custom_transforms
def solve(a, b):
return np.dot(np.linalg.inv(a), b)
def solve_vjp(a, b):
x = solve(a, b)
def vjp(x_tangent):
dx = np.dot(solve(a, x_tangent), x.T)
out = (dx, b * 0.)
return out
return x, vjp
jax.defvjp_all(solve, solve_vjp)
gf = grad(lambda a, b: np.sum(solve(a, b)))
n = 3
a_in = np.linspace(0, 1, n)[:, None]
a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
real_x = onp.random.RandomState(0).randn(n)
b = np.dot(a + np.eye(a.shape[0]), real_x)
print(gf(a, b)) # doesn't crash
def test_vmap_in_axes_list(self):
# https://github.com/google/jax/issues/2367
dictionary = {'a': 5., 'b': np.ones(2)}
x = np.zeros(3)
y = np.arange(3.)
def f(dct, x, y):
return dct['a'] + dct['b'] + x + y
out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
self.assertAllClose(out1, out2, check_dtypes=True)
def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
self.assertRaisesRegex(