mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add experimental rematerialization decorator
We want to allow users to control how reverse-mode autodiff saves values from the forward pass. In particular, we want it to be easy to signal that a function shouldn't have any of its intermediate residuals stored for the backward pass, and instead those values should be recomputed from the function's saved inputs. (This feature is especially handy for accelerators on which memory access is much more expensive than FLOPs are.) In JAX terms, since we implement reverse-mode as a composition of forward-mode, partial evaluation, and transposition, we want users to control how partial evaluation behaves. See https://github.com/google/jax/pull/1749 for more. Co-authored-by: Dougal Maclaurin <dougalm@google.com>
This commit is contained in:
parent
c42722838e
commit
9a8523603c
@ -104,7 +104,6 @@ Operators
|
||||
scatter
|
||||
scatter_add
|
||||
select
|
||||
shaped_identity
|
||||
shift_left
|
||||
shift_right_arithmetic
|
||||
shift_right_logical
|
||||
|
@ -22,7 +22,7 @@ import six
|
||||
from . import core
|
||||
from . import ad_util
|
||||
from . import dtypes
|
||||
from . util import prod
|
||||
from . util import prod, partialmethod
|
||||
|
||||
|
||||
def concretization_err_msg(fun):
|
||||
@ -145,6 +145,9 @@ class ShapedArray(UnshapedArray):
|
||||
def strip_weak_type(self):
|
||||
return ShapedArray(self.shape, self.dtype) if self.weak_type else self
|
||||
|
||||
def _forward_to_value(self, fun, ignored_tracer, *args):
|
||||
return fun(self.val, *args)
|
||||
|
||||
class ConcreteArray(ShapedArray):
|
||||
__slots__ = ['val']
|
||||
array_abstraction_level = 0
|
||||
@ -185,6 +188,15 @@ class ConcreteArray(ShapedArray):
|
||||
def strip_weak_type(self):
|
||||
return ConcreteArray(self.val) if self.weak_type else self
|
||||
|
||||
_bool = _nonzero = partialmethod(_forward_to_value, bool)
|
||||
_float = partialmethod(_forward_to_value, float)
|
||||
_int = partialmethod(_forward_to_value, int)
|
||||
if six.PY2:
|
||||
_long = partialmethod(_forward_to_value, long) # noqa: F821
|
||||
_complex = partialmethod(_forward_to_value, complex)
|
||||
_hex = partialmethod(_forward_to_value, hex)
|
||||
_oct = partialmethod(_forward_to_value, oct)
|
||||
|
||||
class AbstractToken(core.AbstractValue): pass
|
||||
|
||||
abstract_token = AbstractToken()
|
||||
|
13
jax/api.py
13
jax/api.py
@ -55,7 +55,7 @@ from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
|
||||
from .lib import xla_bridge as xb
|
||||
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
|
||||
host_id, host_ids, host_count)
|
||||
from .abstract_arrays import ShapedArray, raise_to_shaped
|
||||
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from .interpreters import partial_eval as pe
|
||||
from .interpreters import xla
|
||||
from .interpreters import pxla
|
||||
@ -1982,3 +1982,14 @@ def eval_shape(fun, *args, **kwargs):
|
||||
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
|
||||
out = [ShapeDtypeStruct(x.shape, x.dtype) for x in out]
|
||||
return tree_unflatten(out_tree(), out)
|
||||
|
||||
|
||||
def checkpoint(fun, concrete=False):
|
||||
@wraps(fun)
|
||||
def fun_remat(*args, **kwargs):
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
out_flat = pe.remat_call(flat_fun, *args_flat, concrete=concrete)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
return fun_remat
|
||||
remat = checkpoint
|
||||
|
@ -597,7 +597,8 @@ def call_bind(primitive, f, *args, **params):
|
||||
|
||||
|
||||
def call_impl(f, *args, **params):
|
||||
return f.call_wrapped(*args, **params)
|
||||
del params # params parameterize the call primitive, not the function
|
||||
return f.call_wrapped(*args)
|
||||
|
||||
|
||||
call_p = Primitive('call')
|
||||
|
@ -140,6 +140,9 @@ def unpair_pval(pval):
|
||||
return (aval_1, const_1), (aval_2, const_2)
|
||||
|
||||
def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
if all(ct is zero for ct in cotangents_in):
|
||||
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
|
||||
|
||||
def write_cotangent(v, ct):
|
||||
# assert v not in primal_env
|
||||
if ct is not None:
|
||||
@ -159,13 +162,46 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
primal_env[v] = val
|
||||
|
||||
primal_env = {}
|
||||
write_primal(core.unitvar, core.unit)
|
||||
map(write_primal, jaxpr.constvars, consts)
|
||||
map(write_primal, jaxpr.freevars, freevar_vals)
|
||||
map(write_primal, jaxpr.invars, args)
|
||||
|
||||
def is_linear(var):
|
||||
if type(var) is Literal:
|
||||
return False
|
||||
else:
|
||||
return primal_env.get(var, undefined_primal) is undefined_primal
|
||||
|
||||
linear_eqns = []
|
||||
for eqn in jaxpr.eqns:
|
||||
if not eqn.bound_subjaxprs:
|
||||
if any(is_linear(v) for v in eqn.invars):
|
||||
linear_eqns.append(eqn)
|
||||
else:
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
ans = eqn.primitive.bind(*in_vals, **eqn.params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
|
||||
if any(is_linear(v) for v in it.chain(eqn.invars, const_vars, bound_vars)):
|
||||
linear_eqns.append(eqn)
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
|
||||
ct_env = {}
|
||||
map(write_cotangent, jaxpr.outvars, cotangents_in)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
for eqn in linear_eqns[::-1]:
|
||||
invals = map(read_primal, eqn.invars)
|
||||
if eqn.primitive.multiple_results:
|
||||
cts_in = map(read_cotangent, eqn.outvars)
|
||||
@ -187,6 +223,51 @@ def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
return freevar_cts, cotangents_out
|
||||
|
||||
def _eval_primals(jaxpr, consts, freevar_vals, args):
|
||||
primal_env = {}
|
||||
|
||||
def read_primal(v):
|
||||
if type(v) is Literal:
|
||||
return v.val
|
||||
else:
|
||||
return primal_env.get(v, undefined_primal)
|
||||
|
||||
def write_primal(v, val):
|
||||
if val is not undefined_primal:
|
||||
primal_env[v] = val
|
||||
|
||||
def is_linear(var):
|
||||
if type(var) is Literal:
|
||||
return False
|
||||
else:
|
||||
return primal_env.get(var, undefined_primal) is undefined_primal
|
||||
|
||||
write_primal(core.unitvar, core.unit)
|
||||
map(write_primal, jaxpr.constvars, consts)
|
||||
map(write_primal, jaxpr.freevars, freevar_vals)
|
||||
map(write_primal, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
if not eqn.bound_subjaxprs:
|
||||
if not any(is_linear(v) for v in eqn.invars):
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
ans = eqn.primitive.bind(*in_vals, **eqn.params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
else:
|
||||
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
|
||||
sub_consts = map(read_primal, const_vars)
|
||||
sub_freevar_vals = map(read_primal, bound_vars)
|
||||
in_vals = map(read_primal, eqn.invars)
|
||||
all_args, in_tree_def = tree_flatten((sub_consts, sub_freevar_vals, in_vals))
|
||||
fun = hashable_partial(wrap_init(_eval_primals), subjaxpr)
|
||||
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
||||
out_flat = eqn.primitive.bind(fun, *all_args, **eqn.params)
|
||||
ans = tree_unflatten(out_tree(), out_flat)
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
return map(read_primal, jaxpr.outvars)
|
||||
|
||||
class UndefinedPrimal(object):
|
||||
def __repr__(self): return '_'
|
||||
undefined_primal = UndefinedPrimal()
|
||||
@ -460,6 +541,7 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
|
||||
out_flat = primitive.bind(fun, *all_args, **params)
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
|
||||
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
|
||||
|
||||
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
|
||||
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
|
||||
|
@ -24,7 +24,7 @@ import numpy as onp
|
||||
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from ..abstract_arrays import ShapedArray, ConcreteArray
|
||||
from ..abstract_arrays import ShapedArray, ConcreteArray, raise_to_shaped
|
||||
from ..linear_util import thunk, transformation, transformation_with_aux
|
||||
from ..util import unzip2, safe_zip, safe_map, toposort, partial, split_list
|
||||
from ..core import (Trace, Tracer, new_master, Jaxpr, Literal, get_aval,
|
||||
@ -104,6 +104,8 @@ class JaxprTrace(Trace):
|
||||
return out_tracer
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
if call_primitive in call_partial_eval_rules:
|
||||
return call_partial_eval_rules[call_primitive](self, f, tracers, params)
|
||||
if call_primitive in map_primitives:
|
||||
return self.process_map(call_primitive, f, tracers, params)
|
||||
in_pvs, in_consts = unzip2([t.pval for t in tracers])
|
||||
@ -188,7 +190,6 @@ class JaxprTrace(Trace):
|
||||
return out_tracers
|
||||
return out, todo
|
||||
|
||||
|
||||
def _mapped_aval(aval):
|
||||
if aval is core.abstract_unit:
|
||||
return aval
|
||||
@ -207,6 +208,8 @@ def _unmapped_aval(size, aval):
|
||||
raise TypeError(aval)
|
||||
|
||||
map_primitives = set()
|
||||
custom_partial_eval_rules = {}
|
||||
call_partial_eval_rules = {}
|
||||
|
||||
|
||||
def partial_eval(f, trace, pvs):
|
||||
@ -450,4 +453,105 @@ def partial_eval_jaxpr(jaxpr, unknowns, instantiate):
|
||||
def _split_aval(unknown, aval):
|
||||
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
|
||||
|
||||
custom_partial_eval_rules = {}
|
||||
|
||||
remat_call_p = core.Primitive('remat_call')
|
||||
remat_call = partial(core.call_bind, remat_call_p)
|
||||
remat_call_p.def_custom_bind(remat_call)
|
||||
remat_call_p.def_impl(core.call_impl)
|
||||
remat_call_p.multiple_results = True
|
||||
|
||||
def _remat_partial_eval(trace, f, tracers, params):
|
||||
concrete = params['concrete']
|
||||
|
||||
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
|
||||
# the function being called, not just for the unknown parts. To do that, we
|
||||
# instantiate all the input tracers as constants in the jaxpr being formed.
|
||||
# Those tracers might have concrete avals, and doing abstract interpretation
|
||||
# on concrete avals engenders a tradeoff: it allows data-dependent Python
|
||||
# control flow to work, but it can in some cases lead to redundant FLOPs (done
|
||||
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
|
||||
# `concrete` parameter to switch this behavior, and if `concrete` is False
|
||||
# then we raise the avals to the Shaped level.
|
||||
instantiated_tracers = map(trace.instantiate_const, tracers)
|
||||
if not concrete:
|
||||
instantiated_tracers = [
|
||||
JaxprTracer(trace, PartialVal((raise_to_shaped(t.pval[0]), unit)), t.recipe)
|
||||
if type(t.pval[0]) is ConcreteArray else t for t in instantiated_tracers]
|
||||
|
||||
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
|
||||
in_pvs, in_consts = unzip2(t.pval for t in instantiated_tracers)
|
||||
fun, aux = partial_eval(f, trace, in_pvs)
|
||||
out_flat = remat_call_p.bind(fun, *in_consts, **params)
|
||||
out_pvs, jaxpr, env = aux()
|
||||
out_pval_consts1, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
|
||||
out_pvals1 = [PartialVal((pv, const)) for pv, const in zip(out_pvs, out_pval_consts1)]
|
||||
|
||||
# Since we traced with everything marked as unknown, but we need to know which
|
||||
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
|
||||
in_avals = [raise_to_shaped(pv) for pv in in_pvs]
|
||||
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
|
||||
for pv, const in zip(out_pvs, out_pval_consts1)]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
|
||||
in_unknowns = [t.pval[0] is not None for t in tracers]
|
||||
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
|
||||
num_res = len(jaxpr_1.out_avals) - len(jaxpr_2.out_avals)
|
||||
|
||||
# First, we revise the jaxpr to be staged out not to output too much.
|
||||
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns)
|
||||
|
||||
# Next, we need values for the outputs that should be known. Since consts
|
||||
# weren't passed through Python for evaluation, we need to evaluate jaxpr_1,
|
||||
# minus the residual outputs that we don't need. When `concrete=True`, as an
|
||||
# optimization we can avoid redoing *some* redundant FLOPs, namely those that
|
||||
# produced concrete avals at the output, simply by using those as computed
|
||||
# values. For the use case of reverse-mode ad, all the primal outputs should
|
||||
# be concrete (thus not recomputed).
|
||||
to_compute = [not uk and type(pv) is not ConcreteArray
|
||||
for uk, pv in zip(out_unknowns, out_pvs)]
|
||||
jaxpr_1 = _dce_jaxpr(jaxpr_1, to_compute + [False] * num_res)
|
||||
_, in_consts = unzip2(t.pval for t in tracers)
|
||||
out_pval_consts2 = core.jaxpr_as_fun(jaxpr_1)(*in_consts)[:-num_res or None]
|
||||
out_pvals = map(_reconstruct_pval, out_pvals1, out_pval_consts2, out_unknowns)
|
||||
|
||||
# Now that we have out_pvals, the rest is just like JaxprTrace.process_call.
|
||||
const_tracers = map(trace.new_instantiated_const, consts)
|
||||
bound_subjaxpr = (jaxpr, const_tracers, map(trace.full_raise, env))
|
||||
out_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_pvals]
|
||||
eqn = new_eqn_recipe(instantiated_tracers, out_tracers, remat_call_p,
|
||||
(bound_subjaxpr,), params)
|
||||
for t in out_tracers:
|
||||
t.recipe = eqn
|
||||
return out_tracers
|
||||
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
|
||||
|
||||
def _dce_jaxpr(typed_jaxpr, outputs):
|
||||
# This dead-code elimination is pretty rudimentary, and in particular doesn't
|
||||
# nontrivially DCE through scan or other higher-order primitives.
|
||||
jaxpr = typed_jaxpr.jaxpr
|
||||
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
new_outvars, new_out_avals = unzip2(out_pairs)
|
||||
|
||||
needed_vars = set(new_outvars)
|
||||
new_eqns = []
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
if set(eqn.outvars) & needed_vars:
|
||||
new_eqns.append(eqn)
|
||||
needed_vars.update(eqn.invars)
|
||||
new_eqns = new_eqns[::-1]
|
||||
|
||||
new_jaxpr = core.Jaxpr(jaxpr.constvars, jaxpr.freevars, jaxpr.invars,
|
||||
new_outvars, new_eqns)
|
||||
return core.TypedJaxpr(new_jaxpr, typed_jaxpr.literals, typed_jaxpr.in_avals,
|
||||
new_out_avals)
|
||||
|
||||
def _reconstruct_pval(pval1, const2, unknown):
|
||||
pv1, const1 = pval1
|
||||
if unknown or pv1 is None:
|
||||
return pval1
|
||||
else:
|
||||
if type(pv1) is ConcreteArray:
|
||||
return PartialVal((None, pv1.val))
|
||||
else:
|
||||
return PartialVal((None, const2))
|
||||
|
@ -762,6 +762,32 @@ pe.custom_partial_eval_rules[device_put_p] = lambda trace, x, **params: x
|
||||
ad.deflinear(device_put_p, lambda cotangent, **kwargs: [cotangent])
|
||||
|
||||
|
||||
def _remat_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes, in_nodes,
|
||||
backend=None, device=None, concrete=None):
|
||||
# This looks a lot like _xla_call_translation_rule, except for a widget we use
|
||||
# to foil CSE.
|
||||
del device, concrete # Unused.
|
||||
subc = xb.make_computation_builder("remat_call_subcomputation")
|
||||
consts = [subc.ParameterWithShape(c.GetShape(n)) for n in const_nodes]
|
||||
freevars = [subc.ParameterWithShape(c.GetShape(n)) for n in freevar_nodes]
|
||||
args = [subc.ParameterWithShape(c.GetShape(n)) for n in in_nodes]
|
||||
args = [_foil_cse(subc, x) for x in args]
|
||||
out_nodes = jaxpr_subcomp(subc, jaxpr, backend, axis_env, consts, freevars, *args)
|
||||
subc = subc.Build(subc.Tuple(*out_nodes))
|
||||
return c.Call(subc, list(const_nodes) + list(freevar_nodes) + list(in_nodes))
|
||||
call_translations[pe.remat_call_p] = _remat_translation_rule
|
||||
|
||||
def _foil_cse(c, x):
|
||||
rng = c.RngNormal(c.Constant(onp.array(0, dtype=onp.float32)),
|
||||
c.Constant(onp.array(1, dtype=onp.float32)),
|
||||
[])
|
||||
pred = c.Lt(rng, c.Constant(onp.finfo(onp.float32).max))
|
||||
xla_shape = c.GetShape(x)
|
||||
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
|
||||
zero = c.Broadcast(c.Constant(onp.array(0, dtype=dtype)), shape)
|
||||
return c.Select(pred, x, zero)
|
||||
|
||||
|
||||
### lazy constants
|
||||
|
||||
class DeviceConstant(DeviceArray):
|
||||
|
@ -1035,9 +1035,6 @@ def sort_key_val(keys, values, dimension=-1):
|
||||
def tie_in(x, y):
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
def shaped_identity(x):
|
||||
return shaped_identity_p.bind(x, shape=x.shape)
|
||||
|
||||
|
||||
def full(shape, fill_value, dtype=None):
|
||||
"""Returns an array of `shape` filled with `fill_value`.
|
||||
@ -1472,7 +1469,7 @@ _complex_basetype = lambda dtype: onp.abs(onp.zeros((), dtype)).dtype
|
||||
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
||||
prim = Primitive(name)
|
||||
prim.def_impl(partial(xla.apply_primitive, prim))
|
||||
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
|
||||
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
|
||||
xla.translations[prim] = translation_rule or partial(standard_translate, name)
|
||||
return prim
|
||||
|
||||
@ -1480,17 +1477,17 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
||||
def standard_reduction_primitive(shape_rule, dtype_rule, name, translation_rule=None):
|
||||
prim = Primitive(name)
|
||||
prim.def_impl(partial(xla.apply_primitive, prim))
|
||||
prim.def_abstract_eval(partial(standard_abstract_eval, shape_rule, dtype_rule))
|
||||
prim.def_abstract_eval(partial(standard_abstract_eval, prim, shape_rule, dtype_rule))
|
||||
xla.reduction_translations[prim] = translation_rule or partial(standard_translate, name)
|
||||
return prim
|
||||
|
||||
|
||||
def standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs):
|
||||
def standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs):
|
||||
assert all(isinstance(arg, UnshapedArray) for arg in args), args
|
||||
least_specialized = _max(
|
||||
map(type, args), key=operator.attrgetter('array_abstraction_level'))
|
||||
if least_specialized is ConcreteArray:
|
||||
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
|
||||
return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
|
||||
elif least_specialized is ShapedArray:
|
||||
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
|
||||
elif least_specialized is UnshapedArray:
|
||||
@ -3933,7 +3930,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
|
||||
|
||||
|
||||
def _sort_key_val_abstract_eval(keys, values, dimension):
|
||||
return keys, values
|
||||
return raise_to_shaped(keys), raise_to_shaped(values)
|
||||
|
||||
def _sort_key_val_jvp(primals, tangents, dimension):
|
||||
# NOTE(mattjj): this re-sorts three times, but if we had a variadic
|
||||
@ -3991,7 +3988,7 @@ def _sort_key_val_batch_rule(batched_args, batch_dims, dimension):
|
||||
new_dimension = dimension + (keys_bdim <= dimension)
|
||||
return sort_key_val(keys, new_values, new_dimension), (keys_bdim, keys_bdim)
|
||||
else:
|
||||
raise Exception # unreachable
|
||||
assert False # unreachable
|
||||
|
||||
sort_key_val_p = Primitive('sort_key_val')
|
||||
sort_key_val_p.multiple_results = True
|
||||
@ -4013,21 +4010,13 @@ def _tie_in_batch_rule(batched_args, batch_dims):
|
||||
|
||||
tie_in_p = Primitive('tie_in')
|
||||
tie_in_p.def_impl(lambda x, y: y)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: y)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
|
||||
xla.translations[tie_in_p] = lambda c, x, y: y
|
||||
ad.deflinear(tie_in_p, _tie_in_transpose_rule)
|
||||
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
|
||||
masking.shape_rules[tie_in_p] = lambda shape_exprs: shape_exprs[1]
|
||||
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
|
||||
|
||||
shaped_identity_p = Primitive('shape_id')
|
||||
shaped_identity_p.def_impl(lambda x, shape: x)
|
||||
shaped_identity_p.def_abstract_eval(lambda x, shape: x)
|
||||
xla.translations[shaped_identity_p] = lambda c, x, shape: x
|
||||
ad.deflinear(shaped_identity_p, lambda t, shape: [shaped_identity(t)])
|
||||
batching.primitive_batchers[shaped_identity_p] = \
|
||||
lambda a, d, shape: (shaped_identity(a[0]), d[0])
|
||||
|
||||
|
||||
### constants
|
||||
|
||||
|
@ -1084,7 +1084,7 @@ def custom_root(f, initial_guess, solve, tangent_solve):
|
||||
|
||||
|
||||
def _root_abstract_eval(*args, **kwargs):
|
||||
return args[sum(kwargs['const_lengths']):]
|
||||
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
|
||||
|
||||
|
||||
def _root_impl(*args, **kwargs):
|
||||
@ -1253,7 +1253,7 @@ def custom_linear_solve(
|
||||
|
||||
|
||||
def _linear_solve_abstract_eval(*args, **kwargs):
|
||||
return args[sum(kwargs['const_lengths']):]
|
||||
return _map(raise_to_shaped, args[sum(kwargs['const_lengths']):])
|
||||
|
||||
|
||||
def _custom_linear_solve_impl(*args, **kwargs):
|
||||
|
@ -148,8 +148,8 @@ class WrappedFun(object):
|
||||
gen = gen(*(gen_args + tuple(args)), **kwargs)
|
||||
args, kwargs = next(gen)
|
||||
stack.append((gen, out_store))
|
||||
gen = None
|
||||
|
||||
del gen
|
||||
ans = self.f(*args, **dict(self.params, **kwargs))
|
||||
del args
|
||||
while stack:
|
||||
|
@ -1326,6 +1326,163 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans1, onp.cos(2.), check_dtypes=False)
|
||||
self.assertAllClose(ans2, onp.cos(3.), check_dtypes=False)
|
||||
|
||||
def test_remat_basic(self):
|
||||
@api.remat
|
||||
def g(x):
|
||||
return lax.sin(x), 3.
|
||||
|
||||
def f(x):
|
||||
x, _ = g(x)
|
||||
return x
|
||||
|
||||
ans = f(2.)
|
||||
expected = onp.sin(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans, f_lin = api.linearize(f, 2.)
|
||||
expected = onp.sin(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = f_lin(3.)
|
||||
expected = onp.cos(2.) * 3.
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = api.make_jaxpr(f_lin)(3.)
|
||||
self.assertIn('sin', str(jaxpr))
|
||||
|
||||
def test_remat_grad_python_control_flow(self):
|
||||
@partial(api.remat, concrete=True)
|
||||
def g(x):
|
||||
if x > 0:
|
||||
return lax.sin(x), 3.
|
||||
else:
|
||||
return lax.cos(x), 4.
|
||||
|
||||
def f(x):
|
||||
x, _ = g(x)
|
||||
return x
|
||||
|
||||
ans = f(2.)
|
||||
expected = onp.sin(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(f)(2.)
|
||||
expected = onp.cos(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_remat_jit(self):
|
||||
@api.remat
|
||||
def g(x):
|
||||
return lax.sin(lax.sin(x))
|
||||
|
||||
def f_(x):
|
||||
return g(x)
|
||||
f = api.jit(f_)
|
||||
|
||||
ans = f(2.)
|
||||
expected = onp.sin(onp.sin(2.))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(f)(2.)
|
||||
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.jit(api.grad(f_))(2.)
|
||||
expected = onp.cos(onp.sin(2.)) * onp.cos(2.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_remat_vmap(self):
|
||||
@api.remat
|
||||
def g(x):
|
||||
return lax.sin(lax.sin(x))
|
||||
|
||||
x = onp.arange(3.)
|
||||
|
||||
ans = api.vmap(g)(x)
|
||||
expected = onp.sin(onp.sin(x))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.jacfwd(g)(x)
|
||||
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.jacrev(g)(x)
|
||||
expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_remat_higher_order_autodiff(self):
|
||||
def f(x):
|
||||
return lax.cos(lax.sin(x))
|
||||
g = api.remat(f)
|
||||
|
||||
ans = api.grad(api.grad(g))(3.)
|
||||
expected = api.grad(api.grad(f))(3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_remat_scan(self):
|
||||
to_scan = lambda c, x: (np.sin(c), None)
|
||||
|
||||
def f_noremat(x):
|
||||
y, _ = lax.scan(to_scan, x, onp.arange(3.))
|
||||
return y
|
||||
|
||||
def f_yesremat(x):
|
||||
y, _ = lax.scan(api.remat(to_scan), x, onp.arange(3.))
|
||||
return y
|
||||
|
||||
ans = f_yesremat(4.)
|
||||
expected = f_noremat(4.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
ans = api.grad(f_yesremat)(4.)
|
||||
expected = api.grad(f_noremat)(4.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = api.make_jaxpr(api.linearize(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
self.assertIn(' sin ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
jaxpr = api.make_jaxpr(api.vjp(f_yesremat, 4.)[1])(1.)
|
||||
scan_eqn, = jaxpr.eqns
|
||||
self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
|
||||
|
||||
def test_remat_no_redundant_flops(self):
|
||||
# see https://github.com/google/jax/pull/1749#issuecomment-558267584
|
||||
|
||||
@api.jit
|
||||
def g(x):
|
||||
return f(2., x)
|
||||
|
||||
@api.remat
|
||||
def f(x, y):
|
||||
return np.sin(x) * y
|
||||
|
||||
# We swap out sin_p's impl rule to count how many times it's invoked
|
||||
called = []
|
||||
sin_impl = lax.sin_p.impl
|
||||
try:
|
||||
lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x))
|
||||
api.grad(g)(3.)
|
||||
finally:
|
||||
lax.sin_p.def_impl(sin_impl)
|
||||
num_calls = len(called)
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
def test_remat_binomial_checkpointing(self):
|
||||
def binom_checkpoint(funs):
|
||||
if len(funs) == 1:
|
||||
return funs[0]
|
||||
else:
|
||||
f1 = binom_checkpoint(funs[:len(funs)//2])
|
||||
f2 = binom_checkpoint(funs[len(funs)//2:])
|
||||
return api.remat(lambda x: f1(f2(x)))
|
||||
|
||||
f1 = binom_checkpoint([np.sin, np.sin, np.sin, np.sin])
|
||||
f2 = lambda x: np.sin(np.sin(np.sin(np.sin(x))))
|
||||
x = 4.
|
||||
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
|
||||
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user