mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add back support for custom_transforms (#2484)
* add also the tests * mark the old APIs as deprecated
This commit is contained in:
parent
069cb3e2fb
commit
f658eb5bf5
534
jax/api.py
534
jax/api.py
@ -1460,6 +1460,540 @@ def _valid_jaxtype(arg):
|
||||
return True
|
||||
|
||||
|
||||
class CustomTransformsFunction(object):
  """Callable wrapper produced by ``custom_transforms``.

  Pairs the original Python callable with the ``core.Primitive`` that stands
  in for it under JAX transformations; custom rules are attached to ``prim``.
  """
  def __init__(self, fun, prim):
    self.fun = fun    # the original Python callable being wrapped
    self.prim = prim  # the core.Primitive that represents `fun` in traces
    # Copy __name__/__doc__ etc. from `fun` onto this wrapper object.
    wraps(fun)(self)

  def __repr__(self):
    return '<jax.custom_transforms function {fun}>'.format(fun=self.__name__)

  def __call__(self, *args):
    # TODO(mattjj): instead of tracing to a jaxpr, use process_call
    # Flatten the pytree args, trace `fun` to a jaxpr at shape level only,
    # then bind the primitive so any custom rules registered on it apply.
    args_flat, in_tree = tree_flatten(args)
    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
    in_pvals = [pe.PartialVal((raise_to_shaped(core.get_aval(x)), core.unit))
                for x in args_flat]
    jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
    # Closed-over constants are passed as leading primitive arguments.
    outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
                          in_tree=in_tree, out_tree=out_tree(),
                          num_consts=len(consts))
    return tree_unflatten(out_tree(), outs)
|
||||
|
||||
def custom_transforms(fun):
  """Wraps a function so that its transformation behavior can be controlled.

  A primary use case of ``custom_transforms`` is defining custom VJP rules (aka
  custom gradients) for a Python function, while still supporting other
  transformations like ``jax.jit`` and ``jax.vmap``. Custom differentiation
  rules can be supplied using the ``jax.defjvp`` and ``jax.defvjp`` functions.

  The ``custom_transforms`` decorator wraps ``fun`` so that its transformation
  behavior can be overridden, but not all transformation rules need to be
  specified manually. The default behavior is retained for any non-overridden
  rules.

  The function ``fun`` must satisfy the same constraints required for jit
  compilation. In particular the shapes of arrays in the computation of ``fun``
  may depend on the shapes of ``fun``'s arguments, but not their values.
  Value dependent Python control flow is also not yet supported.

  Args:
    fun: a Python callable. Must be functionally pure. Its arguments and return
      value should be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof.

  Returns:
    A Python callable with the same input/output and transformation behavior as
    ``fun``, but for which custom transformation rules can be supplied, e.g.
    using ``jax.defvjp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  -5.4667816
  >>> jax.defvjp(f, lambda g, x: g * x)
  >>> print(jax.grad(f)(3.))
  3.0
  """
  warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
  name = getattr(fun, '__name__', '<unnamed custom_transforms primitive>')
  fun_p = core.Primitive(name)
  fun_p.multiple_results = True

  # Default impl: strip the leading closed-over constants and evaluate the
  # traced jaxpr directly.
  def fun_impl(*args, **params):
    consts, args = split_list(args, [params['num_consts']])
    return core.eval_jaxpr(params['jaxpr'], consts, *args)
  fun_p.def_impl(fun_impl)

  # Default JVP rule: differentiate straight through the impl. Overridden by
  # defjvp/defjvp_all/defvjp/defvjp_all.
  def fun_jvp(primals, tangents, **params):
    return ad.jvp(lu.wrap_init(fun_impl, params)).call_wrapped(primals, tangents)
  ad.primitive_jvps[fun_p] = fun_jvp

  # Default batching (vmap) rule: batch through the impl.
  def fun_batch(args, dims, **params):
    return batching.batch_fun(lu.wrap_init(fun_impl, params), args, dims)
  batching.primitive_batchers[fun_p] = fun_batch

  def fun_abstract_eval(*avals, **params):
    return pe.abstract_eval_fun(fun_impl, *avals, **params)
  fun_p.def_abstract_eval(fun_abstract_eval)

  # Default XLA translation rule: lower the impl.
  def fun_translation(c, *xla_args, **params):
    return xla.lower_fun(fun_impl)(c, *xla_args, **params)
  xla.translations[fun_p] = fun_translation

  return CustomTransformsFunction(fun, fun_p)
|
||||
|
||||
def _check_custom_transforms_type(name, fun):
  """Raise ``TypeError`` unless ``fun`` was built by ``custom_transforms``.

  Args:
    name: name of the calling API, used in the error message.
    fun: the object to validate.
  """
  if type(fun) is CustomTransformsFunction:
    return
  msg = ("{} requires a custom_transforms function as its first argument, "
         "but got type {}.")
  raise TypeError(msg.format(name, type(fun)))
|
||||
|
||||
def defjvp_all(fun, custom_jvp):
  """Define a custom JVP rule for a ``custom_transforms`` function.

  If ``fun`` represents a function with signature ``a -> b``, then
  ``custom_jvp`` represents a function with signature ``(a, T a) -> (b, T b)``,
  where we use ``T x`` to represent a tangent type for the type ``x``.

  In more detail, ``custom_jvp`` must take two arguments, both tuples of length
  equal to the number of positional arguments to ``fun``. The first argument to
  ``custom_jvp`` represents the input primal values, and the second represents
  the input tangent values. ``custom_jvp`` must return a pair where the first
  element represents the output primal value and the second element represents
  the output tangent value.

  Defining a custom JVP rule also affects the default VJP rule, which is derived
  from the JVP rule automatically via transposition.

  Args:
    fun: a custom_transforms function.
    custom_jvp: a Python callable specifying the JVP rule, taking two tuples as
      arguments specifying the input primal values and tangent values,
      respectively. The tuple elements can be arrays, scalars, or (nested)
      standard Python containers (tuple/list/dict) thereof. The output must be a
      pair representing the primal output and tangent output, which can be
      arrays, scalars, or (nested) standard Python containers. Must be
      functionally pure.

  Returns:
    None. A side-effect is that ``fun`` is associated with the JVP rule
    specified by ``custom_jvp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  -10.933563
  >>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  16.0
  """
  warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
  _check_custom_transforms_type("defjvp_all", fun)
  def custom_transforms_jvp(primals, tangents, **params):
    # The primitive's leading arguments are closed-over constants; split them
    # off so the user rule only sees the positional arguments.
    num_consts, in_tree = params['num_consts'], params['in_tree']
    _, args_flat = split_list(primals, [num_consts])
    consts_dot, args_dot_flat = split_list(tangents, [num_consts])
    # Differentiating w.r.t. closed-over values has no user-visible slot in the
    # rule's signature, so it is rejected outright.
    if not all(t is ad_util.zero for t in consts_dot):
      msg = ("Detected differentiation with respect to closed-over values with "
             "custom JVP rule, which isn't supported.")
      raise ValueError(msg)
    args = tree_unflatten(in_tree, args_flat)
    args_dot = tree_unflatten(in_tree, args_dot_flat)
    out, out_dot = custom_jvp(args, args_dot)
    out_flat, out_tree = tree_flatten(out)
    out_dot_flat, out_tree2 = tree_flatten(out_dot)
    # Primal and tangent outputs must share one pytree structure so they can
    # be zipped back together.
    if out_tree != out_tree2:
      msg = ("Custom JVP rule returned different tree structures for primals "
             "and tangents, but they must be equal: {} and {}.")
      raise TypeError(msg.format(out_tree, out_tree2))
    return out_flat, out_dot_flat
  ad.primitive_jvps[fun.prim] = custom_transforms_jvp
|
||||
|
||||
def defjvp(fun, *jvprules):
  """Define JVP rules for each argument separately.

  This function is a convenience wrapper around ``jax.defjvp_all`` for
  separately defining JVP rules for each of the function's arguments. This
  convenience wrapper does not provide a mechanism for depending on anything
  other than the function arguments and its primal output value, though
  depending on intermediate results is possible using ``jax.defjvp_all``.

  The signature of each component JVP rule is ``lambda g, ans, *primals: ...``
  where ``g`` represents the tangent of the corresponding positional argument,
  ``ans`` represents the output primal, and ``*primals`` represents all the
  primal positional arguments.

  Defining a custom JVP rule also affects the default VJP rule, which is derived
  from the JVP rule automatically via transposition.

  Args:
    fun: a custom_transforms function.
    *jvprules: a sequence of functions or Nones specifying the JVP rule for each
      corresponding positional argument. When an element is None, it indicates
      that the Jacobian from the corresponding input to the output is zero.

  Returns:
    None. A side-effect is that ``fun`` is associated with the JVP rule
    specified by ``*jvprules``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  -10.933563
  >>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
  >>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
  >>> print(out_primal)
  0.4121185
  >>> print(out_tangent)
  16.412119
  """
  warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
  _check_custom_transforms_type("defjvp", fun)
  def custom_jvp(primals, tangents):
    ans = fun(*primals)
    # Apply each per-argument rule and sum the contributions; None rules and
    # symbolic-zero tangents contribute nothing.
    tangents_out = [rule(t, ans, *primals) for rule, t in zip(jvprules, tangents)
                    if rule is not None and t is not ad_util.zero]
    return ans, functools.reduce(ad.add_tangents, tangents_out, ad_util.zero)
  defjvp_all(fun, custom_jvp)
|
||||
|
||||
def defvjp_all(fun, custom_vjp):
  """Define a custom VJP rule for a ``custom_transforms`` function.

  If ``fun`` represents a function with signature ``a -> b``, then
  ``custom_vjp`` represents a function with signature ``a -> (b, CT b -> CT a)``
  where we use ``CT x`` to represent a cotangent type for the type ``x``. That
  is, ``custom_vjp`` should take the same arguments as ``fun`` and return a pair
  where the first element represents the primal value of ``fun`` applied to the
  arguments, and the second element is a VJP function that maps from output
  cotangents to input cotangents, returning a tuple with length equal to the
  number of positional arguments supplied to ``fun``.

  The VJP function returned as the second element of the output of
  ``custom_vjp`` can close over intermediate values computed when evaluating the
  primal value of ``fun``. That is, use lexical closure to share work between
  the forward pass and the backward pass of reverse-mode automatic
  differentiation.

  See also ``jax.custom_gradient``.

  Args:
    fun: a custom_transforms function.
    custom_vjp: a Python callable specifying the VJP rule, taking the same
      arguments as ``fun`` and returning a pair where the first element is the
      value of ``fun`` applied to the arguments and the second element is a
      Python callable representing the VJP map from output cotangents to input
      cotangents. The returned VJP function must accept a value with the same
      shape as the value of ``fun`` applied to the arguments and must return a
      tuple with length equal to the number of positional arguments to ``fun``.
      Arguments can be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof. Must be functionally pure.

  Returns:
    None. A side-effect is that ``fun`` is associated with the VJP rule
    specified by ``custom_vjp``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x):
  ...   return np.sin(x ** 2)
  ...
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  -5.4667816
  >>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
  >>> print(f(3.))
  0.4121185
  >>> print(jax.grad(f)(3.))
  3.0

  An example with a function on two arguments, so that the VJP function must
  return a tuple of length two:

  >>> @jax.custom_transforms
  ... def f(x, y):
  ...   return x * y
  ...
  >>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (y, x)))
  >>> print(f(3., 4.))
  12.0
  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
  (4.0, 3.0)
  """
  warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
  _check_custom_transforms_type("defvjp_all", fun)
  def custom_transforms_vjp(*consts_and_args, **params):
    # Split off leading closed-over constants; the user rule sees only args.
    num_consts, in_tree = params['num_consts'], params['in_tree']
    consts, args_flat = split_list(consts_and_args, [num_consts])
    args = tree_unflatten(params['in_tree'], args_flat)
    out, vjp = custom_vjp(*args)
    out_flat, out_tree = tree_flatten(out)
    # The rule's primal output must have the same pytree structure as fun's.
    if out_tree != params['out_tree']:
      msg = (
          "First output of `custom_vjp`: {} doesn't match the structure of "
          "the output of `fun`: {}\n"
          "{}\n"
          "vs\n"
          "{}\n".format(custom_vjp, fun, out_tree, params['out_tree'])
      )
      raise TypeError(msg)
    def vjp_flat(*cts_flat):
      cts = tree_unflatten(out_tree, cts_flat)
      args_cts_flat, in_tree2 = tree_flatten(vjp(cts))
      # The cotangents must mirror fun's positional-argument structure.
      if in_tree != in_tree2:
        msg = (
            "Output of the `vjp`: {} doesn't match the structure of args of "
            "`fun`: {}\n"
            "{}\n"
            "vs\n"
            "{}\n".format(vjp, fun, in_tree2, in_tree)
        )
        raise TypeError(msg)
      # Constants get unit cotangents (no differentiation w.r.t. them here).
      return [core.unit] * num_consts + list(args_cts_flat)
    return out_flat, vjp_flat
  ad.defvjp_all(fun.prim, custom_transforms_vjp)
|
||||
|
||||
def defvjp(fun, *vjprules):
  """Define VJP rules for each argument separately.

  This function is a convenience wrapper around ``jax.defvjp_all`` for
  separately defining VJP rules for each of the function's arguments. This
  convenience wrapper does not provide a mechanism for depending on anything
  other than the function arguments and its primal output value, though
  depending on intermediate results is possible using ``jax.defvjp_all``.

  The signature of each component VJP rule is ``lambda g, ans, *primals: ...``
  where ``g`` represents the output cotangent, ``ans`` represents the output
  primal, and ``*primals`` represents all the primal positional arguments.

  Args:
    fun: a custom_transforms function.
    *vjprules: a sequence of functions or Nones specifying the VJP rule for each
      corresponding positional argument. When an element is None, it indicates
      that the Jacobian from the corresponding input to the output is zero.

  Returns:
    None. A side-effect is that ``fun`` is associated with the VJP rule
    specified by ``*vjprules``.

  For example:

  >>> @jax.custom_transforms
  ... def f(x, y):
  ...   return np.sin(x ** 2 + y)
  ...
  >>> print(f(3., 4.))
  0.42016703
  >>> print(jax.grad(f)(3., 4.))
  5.4446807
  >>> print(jax.grad(f, 1)(3., 4.))
  0.9074468
  >>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
  >>> print(jax.grad(f)(3., 4.))
  0.0
  >>> print(jax.grad(f, 1)(3., 4.))
  8.420167
  """
  warn("custom_transforms is deprecated and replaced by custom_vjp/custom_jvp.")
  _check_custom_transforms_type("defvjp", fun)
  def custom_vjp(*primals):
    ans = fun(*primals)
    # TODO(mattjj): avoid instantiating zeros?
    def vjpfun(ct):
      # None rules yield explicit zero cotangents for their argument.
      return tuple(vjp(ct, ans, *primals) if vjp else ad_util.zeros_like_jaxval(x)
                   for x, vjp in zip(primals, vjprules))
    return ans, vjpfun
  defvjp_all(fun, custom_vjp)
|
||||
|
||||
def custom_gradient(fun):
  """Convenience function for defining custom VJP rules (aka custom gradients).

  While the canonical way to define custom VJP rules is via ``jax.defvjp_all``
  and its convenience wrappers, the ``custom_gradient`` convenience wrapper
  follows TensorFlow's ``tf.custom_gradient`` API. The difference here is that
  ``custom_gradient`` can be used as a decorator on one function that returns
  both the primal value (representing the output of the mathematical function to
  be differentiated) and the VJP (gradient) function.

  See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

  If the mathematical function to be differentiated has type signature
  ``a -> b``, then the Python callable ``fun`` should have signature
  ``a -> (b, CT b -> CT a)`` where we use ``CT x`` to denote a cotangent type
  for ``x``. See the example below. That is, ``fun`` should return a pair where
  the first element represents the value of the mathematical function to be
  differentiated and the second element is a function that represents the custom
  VJP rule.

  The custom VJP function returned as the second element of the output of ``fun``
  can close over intermediate values computed when evaluating the function to be
  differentiated. That is, use lexical closure to share work between the forward
  pass and the backward pass of reverse-mode automatic differentiation.

  Args:
    fun: a Python callable specifying both the mathematical function to be
      differentiated and its reverse-mode differentiation rule. It should return
      a pair consisting of an output value and a Python callable that represents
      the custom gradient function.

  Returns:
    A Python callable with signature ``a -> b``, i.e. that returns the output
    value specified by the first element of ``fun``'s output pair. A side effect
    is that under-the-hood ``jax.defvjp_all`` is called to set up the returned
    Python callable with the custom VJP rule specified by the second element
    of ``fun``'s output pair.

  For example:

  >>> @jax.custom_gradient
  ... def f(x):
  ...   return x ** 2, lambda g: (g * x,)
  ...
  >>> print(f(3.))
  9.0
  >>> print(jax.grad(f)(3.))
  3.0

  An example with a function on two arguments, so that the VJP function must
  return a tuple of length two:

  >>> @jax.custom_gradient
  ... def f(x, y):
  ...   return x * y, lambda g: (y, x)
  ...
  >>> print(f(3., 4.))
  12.0
  >>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
  (4.0, 3.0)
  """
  # The primal wrapper evaluates fun and discards the VJP half of its output.
  def primal_fun(*args, **kwargs):
    ans, _ = fun(*args, **kwargs)
    return ans
  primal_fun = custom_transforms(primal_fun)
  # fun itself already has the a -> (b, CT b -> CT a) shape defvjp_all expects.
  defvjp_all(primal_fun, fun)
  return primal_fun
|
||||
|
||||
|
||||
def jarrett(fun):
  """Wrap ``fun`` with a JVP computed from per-leaf directional derivatives.

  The JVP rule pushes the elementwise standard basis of the tangent pytree
  through ``fun`` with ``vmap`` and recombines the resulting Jacobian slices
  with the actual tangents.

  NOTE(review): this presumes ``fun`` is effectively elementwise so the
  recombination is valid — confirm against callers.
  """
  new_fun = custom_transforms(fun)

  def elementwise_jvp(primals, tangents):
    pushfwd = partial(jvp, fun, primals)
    # One vmapped JVP per basis direction; primal output is shared (out_axes
    # None), Jacobian slices are stacked along axis 0.
    y, jacs = vmap(pushfwd, out_axes=(None, 0))(_elementwise_std_basis(tangents))
    flat_tangents, _ = tree_flatten(tangents)
    out_tangent = sum([t * jac for t, jac in zip(flat_tangents, jacs)])
    return y, out_tangent
  defjvp_all(new_fun, elementwise_jvp)

  return new_fun
|
||||
|
||||
def _elementwise_std_basis(pytree):
  """Build a stacked one-hot basis matching the leaves of *pytree*.

  Row i of the stacked array is all ones over the i-th leaf's flattened slots
  and zeros elsewhere; the result is unraveled back into the pytree structure
  with a leading basis axis.

  Raises:
    TypeError: if the leaves' common dtype is not a floating type.
  """
  leaves, _ = tree_flatten(pytree)
  arity = len(leaves)
  # NOTE(review): indexing dims[j] below assumes this `map` is JAX's
  # list-returning util.map, not the builtin iterator — confirm imports.
  dims = map(onp.size, leaves)
  # TODO(mattjj): use symbolic constants
  dtype = dtypes.result_type(*leaves)
  if not dtypes.issubdtype(dtype, onp.floating):
    msg = ("Jacobian only defined for functions with floating input and output "
           "dtypes (i.e. dtypes that model real numbers), got {}.")
    raise TypeError(msg.format(dtype)) # TODO(mattjj, dougalm): handle complex
  basis_array = onp.stack([onp.concatenate(
      [onp.ones(dims[j], dtype) if i == j else onp.zeros(dims[j], dtype)
       for j in range(arity)]) for i in range(arity)])
  return _unravel_array_into_pytree(pytree, 1, basis_array)
|
||||
|
||||
|
||||
# This function mostly exists for making slides about JAX.
|
||||
def _make_graphviz(fun):
  """Adapts `fun` to return a graphviz dot string of its program representation.

  Args:
    fun: The function whose `jaxpr` is to be rendered into graphviz dot. Its
      positional arguments and return value should be arrays, scalars, or
      standard Python containers (tuple/list/dict) thereof.

  Returns:
    A wrapped version of `fun`, set up to return a graphviz dot string.

  See make_jaxpr for a related function.
  """
  # TODO(mattjj): handle eqn.restructure
  # TODO(mattjj): handle subjaxprs

  def pv_like(x):
    # Shape-only partial value: trace at the abstract level.
    aval = xla.abstractify(x)
    return pe.PartialVal((aval, core.unit))

  # Fresh node-id generator per wrapped function.
  id_names = ("id{}".format(i) for i in it.count())

  def jaxpr_to_graphviz(jaxpr, consts):
    fragment = []

    fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
    fragment.extend(map(constant_node, jaxpr.constvars, consts))

    # One box node per equation, edges in from inputs and out to outputs.
    for eqn in jaxpr.eqns:
      id_name = next(id_names)
      fragment.append(function_node(id_name, eqn.primitive.name))
      fragment.extend(edge(invar, id_name) for invar in eqn.invars)
      fragment.extend(edge(id_name, outvar) for outvar in eqn.outvars)
    for ov in jaxpr.outvars:
      fragment.append(outvar_node(ov, "out"))
    return graph(''.join(fragment))

  # DOT templates: each is a .format bound method used as a tiny constructor.
  edge = '{} -> {} [color=gray30];\n'.format
  function_node = '{} [label="{}", shape=box, color=lightskyblue, style=filled];\n'.format
  invar_node = '{} [rank=2, label="{}", color=mediumspringgreen, style=filled];\n'.format
  outvar_node = '{} [label="{}", fillcolor=indianred1, style="filled,dashed", color=black];\n'.format
  constant_node = '{} [rank=2, label="{}", color=goldenrod1, style=filled];\n'.format
  freevar_node = '{} [rank=2, label="{}", color=palegreen, style=filled];\n'.format
  graph = 'digraph G {{{}}}'.format

  @wraps(fun)
  def graphviz_maker(*args, **kwargs):
    wrapped = lu.wrap_init(fun, kwargs)
    jax_args, in_tree = tree_flatten((args, kwargs))
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
    pvals = map(pv_like, jax_args)
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals)
    return jaxpr_to_graphviz(jaxpr, consts)

  graphviz_maker.__name__ = "make_graphviz({})".format(graphviz_maker.__name__)
  return graphviz_maker
|
||||
|
||||
|
||||
class ShapeDtypeStruct(object):
|
||||
__slots__ = ["shape", "dtype"]
|
||||
def __init__(self, shape, dtype):
|
||||
|
@ -435,6 +435,65 @@ def add_tangents(x, y):
|
||||
return add_jaxvals(x, y)
|
||||
|
||||
|
||||
def defvjp_all(prim, custom_vjp):
  """Attach a custom VJP to *prim* by synthesizing a JVP from the VJP.

  Works by introducing two auxiliary primitives: ``{name}_jvp`` (whose
  partial-eval rule runs ``custom_vjp`` and traces the returned VJP function to
  a jaxpr) and ``{name}_lin`` (whose transpose rule evaluates that jaxpr).
  """
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    # Instantiate symbolic zeros so the auxiliary primitive sees concrete
    # tangents; results come back as [primals..., tangents...].
    ts = map(instantiate_zeros, xs, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      # Unwrap the single-result case back to scalars-in, scalars-out.
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
    # Trace the user's VJP function to a jaxpr now, so transposition later is
    # just jaxpr evaluation.
    jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    # Residuals are not differentiated through: None marks no cotangent.
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose
|
||||
|
||||
def defvjp(prim, *vjps):
  """Attach per-argument VJP rules to *prim*.

  Each rule in *vjps* is called as ``rule(ct, *primals)``; a None entry means
  that argument's cotangent is an explicit zero.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    def vjpfun(ct):
      return [rule(ct, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)
|
||||
|
||||
def defvjp2(prim, *vjps):
  """Attach per-argument VJP rules to *prim*, passing the primal output too.

  Like ``defvjp`` but each rule is called as ``rule(ct, ans, *primals)`` so it
  can reuse the primal answer; a None entry means a zero cotangent.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    def vjpfun(ct):
      return [rule(ct, ans, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)
|
||||
|
||||
|
||||
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
|
||||
assert isinstance(prim, Primitive)
|
||||
lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
|
||||
|
@ -596,6 +596,200 @@ class APITest(jtu.JaxTestCase):
|
||||
  def test_complex_input_jacfwd_raises_error(self):
    # Forward-mode Jacobian of a complex input is rejected with TypeError.
    self.assertRaises(TypeError, lambda: jacfwd(lambda x: np.sin(x))(1 + 2j))
|
||||
|
||||
  def test_defvjp_all(self):
    # ad.defvjp_all on a raw primitive: the custom VJP composes with the
    # surrounding 2.* scale via the chain rule.
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (4 * g * np.sin(x),)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 2 * 3.**2, check_dtypes=False)
    self.assertAllClose(grad_ans, 4 * 2 * onp.sin(3.), check_dtypes=False)
|
||||
|
||||
  def test_defvjp_all_const(self):
    # A VJP rule may return a constant cotangent independent of g.
    foo_p = Primitive('foo')
    def foo(x): return foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (12.,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, 9., check_dtypes=False)
    self.assertAllClose(grad_ans, 12., check_dtypes=True)
|
||||
|
||||
  def test_defvjp_all_higher_order_revmode(self):
    # grad-of-grad through a custom VJP: d/dx [2 * d/dx (x^2)] = 2 * 2.
    foo_p = Primitive('foo')
    def foo(x): return 2. * foo_p.bind(x)

    ad.defvjp_all(foo_p, lambda x: (x**2, lambda g: (g * x ** 2,)))
    ans = api.grad(api.grad(foo))(3.)
    self.assertAllClose(ans, 2 * 2 * 3., check_dtypes=False)
|
||||
|
||||
  def test_defvjp_all_multiple_arguments(self):
    # also tests passing in symbolic zero tangents b/c we differentiate wrt only
    # the first argument in one case

    foo_p = Primitive('foo')
    def foo(x, y): return foo_p.bind(x, y)

    def vjpfun(x, y):
      out = x**2 + y**3
      vjp = lambda g: (g + x + y, g * x * 9.)
      return out, vjp

    ad.defvjp_all(foo_p, vjpfun)
    # Differentiating w.r.t. x only: y's tangent is a symbolic zero.
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    # Differentiating w.r.t. both arguments.
    ans = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)
|
||||
|
||||
  def test_defvjp_all_custom_transforms(self):
    # api.defvjp_all on a custom_transforms function (not a raw primitive).
    @api.custom_transforms
    def foo(x):
      return np.sin(x)

    api.defvjp_all(foo, lambda x: (np.sin(x), lambda g: (g * x,)))
    val_ans, grad_ans = api.value_and_grad(foo)(3.)
    self.assertAllClose(val_ans, onp.sin(3.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3., check_dtypes=False)
|
||||
|
||||
# TODO(mattjj): add defvjp_all test with pytree arguments
|
||||
|
||||
  def test_defvjp(self):
    # Per-argument VJP rules; None for the first argument means zero gradient.
    @api.custom_transforms
    def foo(x, y):
      return np.sin(x * y)

    api.defvjp(foo, None, lambda g, _, x, y: g * x * y)
    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 0., check_dtypes=False)

    ans_0, ans_1 = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans_0, 0., check_dtypes=False)
    self.assertAllClose(ans_1, 3. * 4., check_dtypes=False)
|
||||
|
||||
  def test_defvjp_higher_order(self):
    # The custom rule makes foo's first derivative cos(x), so grad-of-grad
    # matches grad-of-grad of sin.
    @api.custom_transforms
    def foo(x):
      return np.sin(2. * x)

    api.defvjp(foo, lambda g, _, x: g * np.cos(x))
    ans = api.grad(api.grad(foo))(2.)
    expected = api.grad(api.grad(np.sin))(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
  def test_defvjp_use_ans(self):
    # A defvjp rule may use the primal answer (`ans`) in its cotangent.
    @api.custom_transforms
    def foo(x, y):
      return np.sin(x * y)

    api.defvjp(foo, None, lambda g, ans, x, y: g * x * y + np.cos(ans))
    val_ans, grad_ans = api.value_and_grad(foo, 1)(3., 4.)
    self.assertAllClose(val_ans, onp.sin(3. * 4.), check_dtypes=False)
    self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
                        check_dtypes=False)
|
||||
|
||||
# TODO
|
||||
# def test_defjvp_closure_error(self):
|
||||
# def foo(x):
|
||||
# @api.custom_transforms
|
||||
# def bar(y):
|
||||
# return x * y
|
||||
|
||||
# api.defjvp(bar, lambda y_dot, ans, y: x * y)
|
||||
# return bar(x)
|
||||
# jtu.check_raises(
|
||||
# lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
|
||||
# "Detected differentiation with respect to closed-over values with "
|
||||
# "custom JVP rule, which isn't supported.")
|
||||
|
||||
# TODO
|
||||
# def test_defvjp_closure_error(self):
|
||||
# def foo(x):
|
||||
# @api.custom_transforms
|
||||
# def bar(y):
|
||||
# return x * y
|
||||
|
||||
# api.defvjp(bar, lambda g, ans, y: x * y)
|
||||
# return bar(x)
|
||||
# jtu.check_raises(
|
||||
# lambda: grad(foo)(1.,), ValueError,
|
||||
# "Detected differentiation w.r.t. variables from outside "
|
||||
# "the scope of <jax.custom_transforms function bar>, but defvjp and "
|
||||
# "defvjp_all only support differentiation w.r.t. positional arguments.")
|
||||
|
||||
  def test_custom_transforms_eval_with_pytrees(self):
    # Plain evaluation of a custom_transforms function with pytree in/out.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = f((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
|
||||
|
||||
  def test_custom_transforms_jit_with_pytrees(self):
    # jit of a custom_transforms function with pytree in/out.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = jit(f)((1, 2))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
|
||||
|
||||
  def test_custom_transforms_jit_with_pytrees_consts(self):
    # The purpose of this test is to exercise the custom_transforms default
    # translation rule in how it deals with constants that are too large to be
    # treated as literals (at the time of writing).
    z = onp.arange(10.)

    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': z * b}

    ans = jit(f)((1, 2))
    self.assertAllClose(ans, {'hi': 2 * 1, 'bye': z * 2}, check_dtypes=False)
|
||||
|
||||
  def test_custom_transforms_jvp_with_pytrees(self):
    # Default JVP rule through a custom_transforms function with pytree in/out.
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans, out_tangent = api.jvp(f, ((1, 2),), ((3, 4),))
    self.assertEqual(ans, {'hi': 2 * 1, 'bye': 2 * 2})
    self.assertEqual(out_tangent, {'hi': 2 * 3, 'bye': 2 * 4})
|
||||
|
||||
  def test_custom_transforms_vmap_with_pytrees(self):
    # Skipped: exercises the deprecated custom_transforms batching rule.
    raise unittest.SkipTest("Test deprecated custom_transforms")
    @api.custom_transforms
    def f(x):
      a, b = x[0], x[1]
      return {'hi': 2 * a, 'bye': 2 * b}

    ans = api.vmap(f)((onp.arange(3), onp.ones((3, 2))))
    expected = {'hi': 2 * onp.arange(3), 'bye': 2 * onp.ones((3, 2))}
    self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
  def test_custom_transforms_jvp_with_closure(self):
    # The default JVP rule should differentiate through closed-over x:
    # g(y) = x * y with y = x gives d/dx x^2 = 2x = 2 at x = 1.
    def f(x):
      @api.custom_transforms
      def g(y):
        return x * y
      return g(x)

    ans = api.grad(f)(1.)
    expected = 2.
    self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
  def test_custom_gradient(self):
    # custom_gradient decorator: primal is x**2, declared gradient is x.
    @api.custom_gradient
    def f(x):
      return x ** 2, lambda g: (g * x,)

    self.assertAllClose(f(3.), 9., check_dtypes=False)
    self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False)
|
||||
|
||||
def test_legacy_devicearray_repr(self):
|
||||
dx = device_put(3.)
|
||||
str(dx.item()) # doesn't crash
|
||||
@ -956,6 +1150,55 @@ class APITest(jtu.JaxTestCase):
|
||||
check_warning(lambda: np.tri(2, dtype="float64"),
|
||||
lambda: np.tri(2, dtype="float32"))
|
||||
|
||||
  def test_custom_vjp_zeros(self):
    # Differentiating only one output of a multi-output custom VJP should not
    # crash on the symbolic-zero cotangent of the other output.
    @api.custom_transforms
    def f(x, y):
      return 2 * x, 3 * y

    def f_vjp(x, y):
      return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])

    api.defvjp_all(f, f_vjp, )
    api.grad(lambda x, y: f(x, y)[0])(1., 2.)  # doesn't crash
|
||||
|
||||
  def test_custom_transforms_vjp_nones(self):
    # issue raised by jsnoek@ and jumper@
    @jax.custom_transforms
    def solve(a, b):
      return np.dot(np.linalg.inv(a), b)
    # print(solve(a, b))

    def solve_vjp(a, b):
      x = solve(a, b)
      def vjp(x_tangent):
        dx = np.dot(solve(a, x_tangent), x.T)
        out = (dx, b * 0.)
        return out
      return x, vjp
    jax.defvjp_all(solve, solve_vjp)
    gf = grad(lambda a,b: np.sum(solve(a, b)))

    # Build a well-conditioned SPD system to solve.
    n = 3
    a_in = np.linspace(0, 1, n)[:, None]
    a = np.dot(a_in, a_in.T) + np.eye(n) * 0.1
    real_x = onp.random.RandomState(0).randn(n)
    b = np.dot(a + np.eye(a.shape[0]), real_x)
    print(gf(a, b))  # doesn't crash
|
||||
|
||||
  def test_vmap_in_axes_list(self):
    # https://github.com/google/jax/issues/2367
    # in_axes given as a list should behave like the equivalent tuple.
    dictionary = {'a': 5., 'b': np.ones(2)}
    x = np.zeros(3)
    y = np.arange(3.)


    def f(dct, x, y):
      return dct['a'] + dct['b'] + x + y

    out1 = api.vmap(f, (None, 0, 0))(dictionary, x, y)
    out2 = api.vmap(f, [None, 0, 0])(dictionary, x, y)
    self.assertAllClose(out1, out2, check_dtypes=True)
|
||||
|
||||
def test_vmap_in_axes_tree_prefix_error(self):
|
||||
# https://github.com/google/jax/issues/795
|
||||
self.assertRaisesRegex(
|
||||
|
Loading…
x
Reference in New Issue
Block a user