Add experimental rematerialization decorator

We want to allow users to control how reverse-mode autodiff saves values
from the forward pass. In particular, we want it to be easy to signal
that a function shouldn't have any of its intermediate residuals stored
for the backward pass, and instead those values should be recomputed
from the function's saved inputs. (This feature is especially handy for
accelerators on which memory access is much more expensive than FLOPs
are.) In JAX terms, since we implement reverse-mode as a composition of
forward-mode, partial evaluation, and transposition, we want users to
control how partial evaluation behaves.

See https://github.com/google/jax/pull/1749 for more.
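
For example, a sketch of the intended usage (mirroring the tests added in
this PR), with `remat` the decorator defined in jax/api.py below:

    import jax.numpy as np
    from jax import api, lax

    @api.remat
    def g(x):
      # No residuals from these sines are saved for the backward pass; both
      # are recomputed from the saved input x when the gradient is evaluated.
      return lax.sin(lax.sin(x))

    print(api.grad(g)(2.))  # cos(sin(2.)) * cos(2.)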

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
Matthew Johnson, 2019-11-22 10:53:11 -08:00
commit 9a8523603c (parent c42722838e)
11 changed files with 410 additions and 29 deletions


@@ -104,7 +104,6 @@ Operators
scatter
scatter_add
select
-shaped_identity
shift_left
shift_right_arithmetic
shift_right_logical


@@ -22,7 +22,7 @@ import six
from . import core
from . import ad_util
from . import dtypes
-from .util import prod
+from .util import prod, partialmethod
def concretization_err_msg(fun):
@@ -145,6 +145,9 @@ class ShapedArray(UnshapedArray):
def strip_weak_type(self):
return ShapedArray(self.shape, self.dtype) if self.weak_type else self
def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)
class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
@@ -185,6 +188,15 @@ class ConcreteArray(ShapedArray):
def strip_weak_type(self):
return ConcreteArray(self.val) if self.weak_type else self
_bool = _nonzero = partialmethod(_forward_to_value, bool)
_float = partialmethod(_forward_to_value, float)
_int = partialmethod(_forward_to_value, int)
if six.PY2:
_long = partialmethod(_forward_to_value, long) # noqa: F821
_complex = partialmethod(_forward_to_value, complex)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
class AbstractToken(core.AbstractValue): pass
abstract_token = AbstractToken()


@@ -55,7 +55,7 @@ from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
-from .abstract_arrays import ShapedArray, raise_to_shaped
+from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
@@ -1982,3 +1982,14 @@ def eval_shape(fun, *args, **kwargs):
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
return tree_unflatten(out_tree(), out)
def checkpoint(fun, concrete=False):
@wraps(fun)
def fun_remat(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
return tree_unflatten(out_tree(), out_flat)
return fun_remat
remat = checkpoint
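
As a sketch of the effect (test_remat_basic below asserts this pattern), the
jaxpr of a function linearized through `remat` still contains the primal
`sin`, since it is recomputed on the backward pass rather than saved:

    from jax import api, lax

    _, f_lin = api.linearize(api.remat(lax.sin), 2.)
    print(api.make_jaxpr(f_lin)(3.))  # 'sin' appears in the printed jaxpr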


@@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params):
def call_impl(f, *args, **params):
-return f.call_wrapped(*args, **params)
+del params  # params parameterize the call primitive, not the function
+return f.call_wrapped(*args)
call_p = Primitive('call')


@@ -140,6 +140,9 @@ def unpair_pval(pval):
return (aval_1, const_1), (aval_2, const_2)
def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
if all(ct is zero for ct in cotangents_in):
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
def write_cotangent(v, ct):
# assert v not in primal_env
if ct is not None:
@@ -159,13 +162,46 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
primal_env[v] = val
primal_env = {}
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal
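# Forward pass: evaluate the nonlinear (primal) equations, collecting the
# linear equations to be transposed in reverse order below.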
linear_eqns = []
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
else:
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)):
linear_eqns.append(eqn)
else:
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)
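# Backward pass: propagate cotangents through just the linear equations, in
# reverse order.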
ct_env = {}
map(write_cotangent, jaxpr.outvars, cotangents_in)
-for eqn in jaxpr.eqns[::-1]:
+for eqn in linear_eqns[::-1]:
invals = map(read_primal, eqn.invars)
if eqn.primitive.multiple_results:
cts_in = map(read_cotangent, eqn.outvars)
@@ -187,6 +223,51 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out
def _eval_primals(jaxpr, consts, freevar_vals, args):
primal_env = {}
def read_primal(v):
if type(v) is Literal:
return v.val
else:
return primal_env.get(v, undefined_primal)
def write_primal(v, val):
if val is not undefined_primal:
primal_env[v] = val
def is_linear(var):
if type(var) is Literal:
return False
else:
return primal_env.get(var, undefined_primal) is undefined_primal
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
if not any(is_linear(v) for v in eqn.invars):
in_vals = map(read_primal, eqn.invars)
ans = eqn.primitive.bind(*in_vals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
in_vals = map(read_primal, eqn.invars)
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
ans = tree_unflatten(out_tree(), out_flat)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)
class UndefinedPrimal(object):
def __repr__(self): return '_'
undefined_primal = UndefinedPrimal()
@@ -460,6 +541,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))


@@ -24,7 +24,7 @@ import numpy as onp
from .. import core
from .. import linear_util as lu
-from ..abstract_arrays import ShapedArray, ConcreteArray
+from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
from ..linear_util import thunk, transformation, transformation_with_aux
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
@@ -104,6 +104,8 @@ class JaxprTrace(Trace):
return out_tracer
def process_call(self, call_primitive, f, tracers, params):
if call_primitive in call_partial_eval_rules:
return call_partial_eval_rules[call_primitive](self, f, tracers, params)
if call_primitive in map_primitives:
return self.process_map(call_primitive, f, tracers, params)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
@@ -188,7 +190,6 @@ class JaxprTrace(Trace):
return out_tracers
return out, todo
def _mapped_aval(aval):
if aval is core.abstract_unit:
return aval
@@ -207,6 +208,8 @@ def _unmapped_aval(size, aval):
raise TypeError(aval)
map_primitives = set()
custom_partial_eval_rules = {}
call_partial_eval_rules = {}
def partial_eval(f, trace, pvs):
@@ -450,4 +453,105 @@ def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
def _split_aval(unknown, aval):
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
-custom_partial_eval_rules = {}
remat_call_p = core.Primitive('remat_call')
remat_call = partial(core.call_bind, remat_call_p)
remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
def _remat_partial_eval(trace, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
# the function being called, not just for the unknown parts. To do that, we
# instantiate all the input tracers as constants in the jaxpr being formed.
# Those tracers might have concrete avals, and doing abstract interpretation
# on concrete avals engenders a tradeoff: it allows data-dependent Python
# control flow to work, but it can in some cases lead to redundant FLOPs (done
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
# `concrete` parameter to switch this behavior, and if `concrete` is False
# then we raise the avals to the Shaped level.
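# For example (as in test_remat_grad_python_control_flow in the tests below),
# `concrete=True` lets a user write data-dependent Python control flow:
#   @partial(remat, concrete=True)
#   def g(x):
#     return lax.sin(x) if x > 0 else lax.cos(x)
# at the cost of possibly redundant forward-pass FLOPs.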
instantiated_tracers = map(trace.instantiate_const, tracers)
if not concrete:
instantiated_tracers = [
JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
fun, aux = partial_eval(f, trace, in_pvs)
out_flat = remat_call_p.bind(fun, *in_consts, **params)
out_pvs, jaxpr, env = aux()
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
# We traced with everything marked as unknown, but we need to know which
# outputs are actually known vs. unknown, so we use partial_eval_jaxpr to get out_unknowns.
in_avals = [raise_to_shaped(pv) for pv in in_pvs]
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
for pv, const in zip(out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in tracers]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
# First, we prune the jaxpr to be staged out so it doesn't emit outputs we don't need.
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
# Next, we need values for the outputs that should be known. Since consts
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
# minus the residual outputs that we don't need. When `concrete=True`, as an
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
# produced concrete avals at the output, simply by using those as computed
# values. For the use case of reverse-mode ad, all the primal outputs should
# be concrete (thus not recomputed).
to_compute = [not uk and type(pv) is not ConcreteArray
for uk, pv in zip(out_unknowns, out_pvs)]
jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
_, in_consts = unzip2(t.pval for t in tracers)
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None]
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
const_tracers = map(trace.new_instantiated_const, consts)
bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env))
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
(bound_subjaxpr,), params)
for t in out_tracers:
t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def _dce_jaxpr(typed_jaxpr, outputs):
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan or other higher-order primitives.
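# We walk the equations in reverse, keeping an equation iff any of its outputs
# is still needed; pruned outputs are replaced by unit.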
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)
needed_vars = set(new_outvars)
new_eqns = []
for eqn in jaxpr.eqns[::-1]:
if set(eqn.outvars) & needed_vars:
new_eqns.append(eqn)
needed_vars.update(eqn.invars)
new_eqns = new_eqns[::-1]
new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
new_outvars, new_eqns)
return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
new_out_avals)
def _reconstruct_pval(pval1, const2, unknown):
pv1, const1 = pval1
if unknown or pv1 is None:
return pval1
else:
if type(pv1) is ConcreteArray:
return PartialVal((None, pv1.val))
else:
return PartialVal((None, const2))


@@ -762,6 +762,32 @@ pe.custom_partial_eval_rules[device_put_p] = lambda trace, x, **params: x
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
backend=None, device=None, concrete=None):
# This looks a lot like _xla_call_translation_rule, except for a widget we use
# to foil CSE.
del device, concrete # Unused.
subc = xb.make_computation_builder("remat_call_subcomputation")
consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes]
args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
args = [_foil_cse(subc, x) for x in args]
out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args)
subc = subc.Build(subc.Tuple(*out_nodes))
return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes))
call_translations[pe.remat_call_p] = _remat_translation_rule
def _foil_cse(c, x):
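# The rng-based predicate below is always True at runtime, but the compiler
# can't prove that, so the Select can't be folded away. Routing each argument
# through such a Select makes the rematerialized computation's inputs look
# distinct from the forward pass's, which defeats CSE between the two.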
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
c.Constant(onp.array(1, dtype=onp.float32)),
[])
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
xla_shape = c.GetShape(x)
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
return c.Select(pred, x, zero)
### lazy constants
class DeviceConstant(DeviceArray):


@@ -1035,9 +1035,6 @@ def sort_key_val(keys, values, dimension=-1):
def tie_in(x, y):
return tie_in_p.bind(x, y)
-def shaped_identity(x):
-return shaped_identity_p.bind(x, shape=x.shape)
def full(shape, fill_value, dtype=None):
"""Returns an array of `shape` filled with `fill_value`.
@@ -1472,7 +1469,7 @@ _complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
-prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
+prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.translations[prim] = translation_rule or partial(standard_translate, name)
return prim
@@ -1480,17 +1477,17 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
prim = Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
-prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
+prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
return prim
-def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
+def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
assert all(isinstance(arg, UnshapedArray) for arg in args), args
least_specialized = _max(
map(type, args), key=operator.attrgetter('array_abstraction_level'))
if least_specialized is ConcreteArray:
-return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
+return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
elif least_specialized is ShapedArray:
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
elif least_specialized is UnshapedArray:
@@ -3933,7 +3930,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_key_val_abstract_eval(keys, values, dimension):
-return keys, values
+return raise_to_shaped(keys), raise_to_shaped(values)
def _sort_key_val_jvp(primals, tangents, dimension):
# NOTE(mattjj): this re-sorts three times, but if we had a variadic
@@ -3991,7 +3988,7 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
new_dimension = dimension + (keys_bdim <= dimension)
return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
else:
-raise Exception # unreachable
+assert False # unreachable
sort_key_val_p = Primitive('sort_key_val')
sort_key_val_p.multiple_results = True
@@ -4013,21 +4010,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):
tie_in_p = Primitive('tie_in')
tie_in_p.def_impl(lambda x, y: y)
-tie_in_p.def_abstract_eval(lambda x, y: y)
+tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1]
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
-shaped_identity_p = Primitive('shape_id')
-shaped_identity_p.def_impl(lambda x, shape: x)
-shaped_identity_p.def_abstract_eval(lambda x, shape: x)
-xla.translations[shaped_identity_p] = lambda c, x, shape: x
-ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)])
-batching.primitive_batchers[shaped_identity_p] = \
-lambda a, d, shape: (shaped_identity(a[0]), d[0])
### constants


@@ -1084,7 +1084,7 @@ def custom_root(f, initial_guess, solve, tangent_solve):
def _root_abstract_eval(*args, **kwargs):
-return args[sum(kwargs['const_lengths']):]
+return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
def _root_impl(*args, **kwargs):
@@ -1253,7 +1253,7 @@ def custom_linear_solve(
def _linear_solve_abstract_eval(*args, **kwargs):
-return args[sum(kwargs['const_lengths']):]
+return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
def _custom_linear_solve_impl(*args, **kwargs):


@@ -148,8 +148,8 @@ class WrappedFun(object):
gen = gen(*(gen_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
-gen = None
+del gen
ans = self.f(*args, **dict(self.params, **kwargs))
del args
while stack:


@@ -1326,6 +1326,163 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
def test_remat_basic(self):
@api.remat
def g(x):
return lax.sin(x), 3.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = onp.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans, f_lin = api.linearize(f, 2.)
expected = onp.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = f_lin(3.)
expected = onp.cos(2.) * 3.
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(f_lin)(3.)
self.assertIn('sin', str(jaxpr))
def test_remat_grad_python_control_flow(self):
@partial(api.remat, concrete=True)
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
def f(x):
x, _ = g(x)
return x
ans = f(2.)
expected = onp.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_jit(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))
def f_(x):
return g(x)
f = api.jit(f_)
ans = f(2.)
expected = onp.sin(onp.sin(2.))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jit(api.grad(f_))(2.)
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_vmap(self):
@api.remat
def g(x):
return lax.sin(lax.sin(x))
x = onp.arange(3.)
ans = api.vmap(g)(x)
expected = onp.sin(onp.sin(x))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jacfwd(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.jacrev(g)(x)
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_higher_order_autodiff(self):
def f(x):
return lax.cos(lax.sin(x))
g = api.remat(f)
ans = api.grad(api.grad(g))(3.)
expected = api.grad(api.grad(f))(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def test_remat_scan(self):
to_scan = lambda c, x: (np.sin(c), None)
def f_noremat(x):
y, _ = lax.scan(to_scan, x, onp.arange(3.))
return y
def f_yesremat(x):
y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
return y
ans = f_yesremat(4.)
expected = f_noremat(4.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f_yesremat)(4.)
expected = api.grad(f_noremat)(4.)
self.assertAllClose(ans, expected, check_dtypes=False)
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
scan_eqn, = jaxpr.eqns
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
def test_remat_no_redundant_flops(self):
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
@api.jit
def g(x):
return f(2., x)
@api.remat
def f(x, y):
return np.sin(x) * y
# We swap out sin_p's impl rule to count how many times it's invoked
called = []
sin_impl = lax.sin_p.impl
try:
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
api.grad(g)(3.)
finally:
lax.sin_p.def_impl(sin_impl)
num_calls = len(called)
self.assertEqual(num_calls, 1)
def test_remat_binomial_checkpointing(self):
def binom_checkpoint(funs):
if len(funs) == 1:
return funs[0]
else:
f1 = binom_checkpoint(funs[:len(funs)//2])
f2 = binom_checkpoint(funs[len(funs)//2:])
return api.remat(lambda x: f1(f2(x)))
f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x))))
x = 4.
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
if __name__ == '__main__':
absltest.main()