delete old remat implementation

moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
Matthew Johnson 2022-08-16 12:07:27 -07:00
parent 332d7d0168
commit d19e34fa4a
14 changed files with 131 additions and 424 deletions

CHANGELOG.md

@@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.3.17 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main).
* Breaking changes
* {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
the `concrete` option, following the previous version's deprecation; see
[JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).
## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
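
A minimal migration sketch for the removed option (the function and argument names are illustrative, not from this commit): instead of `concrete=True`, mark the argument that drives Python control flow as static.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Previously, jax.checkpoint(f, concrete=True) allowed Python control flow
# on traced values; the replacement is a static (non-traced) argument.
@partial(jax.checkpoint, static_argnums=(1,))
def apply(x, use_sin):
  if use_sin:  # Python control flow on the static argument
    return jnp.sin(x)
  return jnp.cos(x)

print(jax.grad(apply)(2.0, True))  # cos(2.0); differentiates w.r.t. x only
```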

docs/jep/11830-new-remat-checkpoint.md

@@ -19,9 +19,7 @@ As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a ne
## How can I disable the change, and go back to the old behavior for now?
In case you have a problem with this change, it will **temporarily** be possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways:
In case you have a problem with this change, **through version `jax==0.3.16`** it is possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways:
1. set the shell environment variable `JAX_NEW_CHECKPOINT=0`;
2. execute `jax.config.update('jax_new_checkpoint', False)`;
@@ -29,6 +27,10 @@ In case you have a problem with this change, it will **temporarily** be possible
If you need to revert to the old implementation, **please reach out** on a GitHub issue so that we can make the new implementation work for you.
As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
tracker](https://github.com/google/jax/issues) so we can help fix it!
## Why are we doing this?

jax/_src/ad_checkpoint.py

@@ -18,19 +18,23 @@ from typing import Callable, Optional, List, Tuple, Sequence, Set, Union, Any
import types
from absl import logging
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)
@@ -38,9 +42,6 @@ from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
# TODO(mattjj): before this can be the standard remat implementation, we must:
# [ ] fix up callers who use the 'concrete' option (now removed)
map = safe_map
zip = safe_zip
@@ -582,6 +583,104 @@ def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
pe.dce_rules[remat_p] = remat_dce
def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
differentiated: bool, is_gpu_platform: bool = False,
**_):
assert not jaxpr.constvars
if differentiated and prevent_cse:
if jax.config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while
else:
translation_rule = _remat_translation_using_cond
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
return core.eval_jaxpr(jaxpr, (), *args)
# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
if aval is core.abstract_token:
return jax.lax.create_token()
elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
return jax.lax.broadcast(jax.lax.empty(aval.dtype), aval.shape) # type: ignore
else:
raise ValueError(aval)
def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
# Implements:
# for(counter=0, result=0; counter < rng(1, 2); counter ++) {
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
avals_out = tuple(v.aval for v in jaxpr.outvars)
carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
def cond(carry):
counter, _, _ = carry
unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
return counter < unif
def body(carry):
counter, _, args = carry
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)
carry_res = jax.lax.while_loop(cond, body, carry_init)
return carry_res[1]
def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
# Implements:
# if(rng(0, 1) < 2)
# return eval_jaxpr(*args)
# else:
# return 0
avals_out = tuple(v.aval for v in jaxpr.outvars)
def remat_comp(*args):
return tuple(core.eval_jaxpr(jaxpr, (), *args))
def dummy_comp(*args):
return tuple(map(_dummy_like, avals_out))
unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
mlir.register_lowering(
remat_p,
mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True),
multiple_results=True),
platform="gpu")
def _optimization_barrier_abstract_eval(*args):
return args
def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
flat_barrier_types = util.flatten(barrier_types)
flat_args = mlir.flatten_lowering_ir_args(args)
barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
return util.unflatten(barrier_op.results, map(len, barrier_types))
def _optimization_barrier(arg):
flat_args, treedef = tree_flatten(arg)
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)
def checkpoint_name(x, name):
return name_p.bind(x, name=name)
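
The dispatch in `remat_lowering` above only guards the recomputation when the equation is both `differentiated` and `prevent_cse`, preferring an optimization barrier when `jax_remat_opt_barrier` is set (the default with recent jaxlib, per the config.py hunk below). A sketch for observing that in the lowered IR, assuming a jax version where `Lowered.compiler_ir` is available:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

# Differentiating marks the remat equation differentiated=True, so with
# prevent_cse=True (the default) the lowering should wrap the recomputation
# in mhlo.optimization_barrier.
mhlo_text = str(jax.jit(jax.grad(f)).lower(1.0).compiler_ir(dialect="mhlo"))
print("optimization_barrier" in mhlo_text)
```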

jax/_src/api.py

@@ -3103,27 +3103,11 @@ def checkpoint(fun: Callable, *,
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n")
if config.jax_new_checkpoint:
raise NotImplementedError(msg)
else:
warn(msg, DeprecationWarning)
if config.jax_new_checkpoint:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
@wraps(fun)
@api_boundary
def remat_f(*args, **kwargs):
f, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
concrete=concrete, prevent_cse=prevent_cse,
differentiated=False, policy=policy)
return tree_unflatten(out_tree(), out_flat)
return remat_f
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
remat = checkpoint # type: ignore
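
With the fallback gone, `concrete=True` now fails at decoration time instead of warning; a small sketch matching the updated tests in api_test.py below:

```python
import jax
import jax.numpy as jnp

def g(x):
  return jnp.sin(x)

try:
  jax.remat(g, concrete=True)  # raises immediately, before any call
except NotImplementedError as e:
  print(e)  # the message suggests static_argnums and links JEP 11830
```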

jax/_src/config.py

@@ -875,13 +875,6 @@ config.define_bool_state(
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))
# TODO(mattjj): set default to True, then remove
config.define_bool_state(
name='jax_new_checkpoint',
default=True,
upgrade=True,
help='Whether to use the new jax.checkpoint implementation.')
# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(
name='jax_coordination_service',

jax/_src/lax/control_flow/__init__.py

@@ -19,8 +19,6 @@ from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
scan, scan_bind, scan_p,
_scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.remat_impl import (remat_impl,
optimization_barrier_p)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
_custom_linear_solve_impl,
linear_solve_p)
@@ -32,3 +30,5 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
_initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.ad_checkpoint import optimization_barrier_p
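
Because the shim re-exports the primitive now defined in ad_checkpoint.py, both import paths resolve to the same object; a quick check:

```python
from jax._src.ad_checkpoint import optimization_barrier_p as p1
from jax._src.lax.control_flow import optimization_barrier_p as p2
assert p1 is p2  # the control_flow name is just a re-export
```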

jax/_src/lax/control_flow/remat_impl.py

@@ -1,152 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the remat implementation."""
from functools import partial
from typing import Optional
import jax
from jax import core
from jax import lax
from jax.config import config
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_checkpoint
from jax._src import util
from jax._src.util import safe_map, wrap_name
from jax._src.lax.control_flow.conditionals import cond
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lax.control_flow.loops import while_loop
import numpy as np
_map = safe_map
def _dummy_remat_result(aval: core.AbstractValue):
"""A result that will be discarded"""
if aval is core.abstract_token:
return lax.create_token()
else:
return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape) # type: ignore
def _remat_translation_using_cond(*args,
jaxpr: core.Jaxpr):
# Implements:
# if(rng(0, 1) < 2)
# return eval_jaxpr(*args)
# else:
# return 0
avals_out = tuple(ov.aval for ov in jaxpr.outvars)
def remat_comp(*args):
return tuple(core.eval_jaxpr(jaxpr, (), *args))
def dummy_comp(*args):
return tuple(_map(_dummy_remat_result, avals_out))
cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2))
return cond(cond_pred, remat_comp, dummy_comp, *args)
def _remat_translation_using_while(*args,
jaxpr: core.Jaxpr):
# Implements:
# for(counter=0, result=0; counter < rng(1, 2); counter ++) {
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
avals_out = tuple(ov.aval for ov in jaxpr.outvars)
dummies_like_result = tuple(_map(_dummy_remat_result, avals_out))
carry_init = (np.int32(0), dummies_like_result, args)
def cond(carry):
counter, _, _ = carry
return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())
def body(carry):
counter, _, args = carry
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)
carry_res = while_loop(cond, body, carry_init)
return carry_res[1]
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
return core.eval_jaxpr(jaxpr, (), *args)
def remat_impl(*args,
call_jaxpr: Optional[core.Jaxpr] = None,
jaxpr: Optional[core.Jaxpr] = None,
prevent_cse: bool, differentiated: bool,
policy,
is_gpu_platform: bool = False,
concrete: bool = False,
name: str = "checkpoint"):
# Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat)
# name is not passed for remat2, defaults to "checkpoint"
# TODO: remove call_jaxpr once we drop the remat call primitive
if jaxpr is None:
jaxpr = call_jaxpr
assert jaxpr is not None
assert not jaxpr.constvars
del concrete, policy # Unused.
if differentiated and prevent_cse:
if config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while
else:
translation_rule = _remat_translation_using_cond
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)
for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p): # type: ignore
mlir.register_lowering(remat_primitive,
mlir.lower_fun(remat_impl, multiple_results=True))
mlir.register_lowering(remat_primitive,
mlir.lower_fun(partial(remat_impl,
is_gpu_platform=True),
multiple_results=True),
platform="gpu")
def _optimization_barrier_abstract_eval(*args):
return args
def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
flat_barrier_types = util.flatten(barrier_types)
flat_args = mlir.flatten_lowering_ir_args(args)
barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
return util.unflatten(barrier_op.results, _map(len, barrier_types))
def _optimization_barrier(arg):
flat_args, treedef = tree_flatten(arg)
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)
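
Both deleted translations defeat CSE by gating recomputation behind `lax.rng_uniform`, which always returns 1 for the int32 bounds `(1, 2)` yet is opaque to XLA: the body runs exactly once but cannot be merged with identical computations outside it. A standalone sketch of the while-loop variant (the helper name and the use of `jax.eval_shape` for the dummy carry are my own, and it assumes `f` returns a single array):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

def eval_once_no_cse(f, *args):
  # Build a dummy loop carry with f's output shape/dtype without running f.
  out = jax.eval_shape(f, *args)
  dummy = jnp.zeros(out.shape, out.dtype)

  def cond(carry):
    counter, _ = carry
    # rng_uniform(1, 2) over int32 always yields 1, but XLA cannot prove
    # that, so it cannot hoist or CSE the loop body.
    return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

  def body(carry):
    counter, _ = carry
    return counter + 1, f(*args)

  _, result = lax.while_loop(cond, body, (np.int32(0), dummy))
  return result

print(eval_once_no_cse(jnp.sin, jnp.float32(0.5)))  # same value as jnp.sin(0.5)
```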

jax/experimental/host_callback.py

@@ -1615,16 +1615,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
# cased to just pass-through the token
in_axes=eqn.params["in_axes"] + (None, None),
out_axes=eqn.params["out_axes"] + (0, 0))))
elif eqn.primitive is pe.remat_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
)))
elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
fun_jaxpr = eqn.params["fun_jaxpr"]

jax/experimental/jax2tf/jax2tf.py

@@ -1070,9 +1070,9 @@ class TensorFlowTrace(core.Trace):
def post_process_call(self, call_primitive: core.Primitive,
out_tracers: Sequence[TensorFlowTracer], params):
# We encountered a call primitive, e.g., remat_call_p, whose result
# (out_tracers) include TensorFlowTracer that were not passed through
# its arguments (captured from the environment).
# We encountered a call primitive whose result (out_tracers) include
# TensorFlowTracer that were not passed through its arguments (captured from
# the environment).
vals = tuple(t.val for t in out_tracers)
main = self.main
@@ -1137,8 +1137,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
# Call primitives are inlined
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p,
partial_eval.remat_call_p, maps.xmap_p]:
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
@@ -2549,7 +2548,7 @@ tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
extra_name_stack="scan")
tf_impl_with_avals[ad_checkpoint.remat_p] = \
_convert_jax_impl(partial(lax_control_flow.remat_impl,
_convert_jax_impl(partial(ad_checkpoint.remat_lowering,
# TODO: jax2tf cannot discriminate by platform
is_gpu_platform=False),
multiple_results=True,

jax/experimental/jax2tf/tests/jax2tf_test.py

@@ -42,7 +42,7 @@ import jax._src.lib.xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from google3.third_party.tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
config.parse_flags_with_absl()
@@ -766,16 +766,13 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.TransformConvertAndCompare(f, arg, "grad")
self.TransformConvertAndCompare(f, arg, "grad_vmap")
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{flavor}", flavor=flavor)
for flavor in ["old", "new"]))
def test_remat(self, flavor="old"):
def test_remat(self):
def f(x1):
x2 = jnp.sin(x1)
x3 = jnp.sin(x2)
x4 = jnp.sin(x3)
return x4
remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f)
remat_f = ad_checkpoint.checkpoint(f)
# The computation of grad_f computes "sin" 5 times, 3 for the forward pass
# and then to rematerialize "x2" and "x3" in the backward pass.
@@ -783,21 +780,19 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
# Check that we have a Sin under a conditional
f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
if flavor == "old":
raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat")
if jax.config.jax_remat_opt_barrier:
if config.jax2tf_default_experimental_native_lowering:
self.assertRegex(
str(f_tf_graph), r"mhlo.optimization_barrier")
else:
self.assertRegex(
str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier")
str(f_tf_graph), r"XlaOptimizationBarrier")
elif config.jax_experimental_name_stack:
self.assertRegex(str(f_tf_graph),
r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin')
r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
else:
self.assertRegex(str(f_tf_graph),
r'remat_checkpoint_/switch_case/indexed_case/Sin')
r'switch_case/indexed_case/Sin')
def test_remat_free_var(self):
def f(x):
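
The updated assertions can be reproduced outside the test harness; a sketch using essentially the test's `f` (requires TensorFlow, and assumes `jax_remat_opt_barrier` is on with the non-native lowering path, where the graph should contain `XlaOptimizationBarrier`):

```python
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax import ad_checkpoint
from jax.experimental import jax2tf

def f(x1):
  x2 = jnp.sin(x1)
  x3 = jnp.sin(x2)
  return jnp.sin(x3)

remat_f = ad_checkpoint.checkpoint(f)
f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
graph = str(f_tf.get_concrete_function(tf.constant(1.0)).graph.as_graph_def())
print("XlaOptimizationBarrier" in graph)
```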

jax/interpreters/ad.py

@@ -637,39 +637,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
primitive_transposes[core.closed_call_p] = _closed_call_transpose
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
cotangent_in_avals, reduce_axes):
unknowns = map(is_undefined_primal, primals_in)
primal_jaxpr, tangent_jaxpr, _, _ = \
pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr),
unknowns=unknowns, instantiate=True) # type: ignore
args, in_tree = tree_flatten((primals_in, cotangents_in))
transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
tangent_jaxpr, reduce_axes)
flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree)
flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params)
return tree_unflatten(out_tree(), flat_cotangents_out)
primitive_transposes[pe.remat_call_p] = remat_transpose
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
primals_tangents_in, cotangents_in):
primals_in = [x for x in primals_tangents_in if not is_undefined_primal(x)]
tangents_in = [x for x in primals_tangents_in if is_undefined_primal(x)]
res = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
cotangents_out_ = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
(*res, *tangents_in), cotangents_in)
cotangents_out = iter(cotangents_out_[len(res):])
outs = [next(cotangents_out) if is_undefined_primal(x) else Zero.from_value(x)
for x in primals_tangents_in]
assert next(cotangents_out, None) is None
return outs
@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
results = yield args, kwargs
yield results, [type(r) is not Zero for r in results]
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)

jax/interpreters/partial_eval.py

@@ -1118,127 +1118,6 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals
remat_call_p: Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)
def _remat_partial_eval(trace, _, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
# the function being called, not just for the unknown parts. To do that, we
# instantiate all the input tracers as constants in the jaxpr being formed.
# Those tracers might have concrete avals, and doing abstract interpretation
# on concrete avals engenders a tradeoff: it allows data-dependent Python
# control flow to work, but it can in some cases lead to redundant FLOPs (done
# both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
# `concrete` parameter to switch this behavior, and if `concrete` is False
# then we raise the avals to the Shaped level.
if concrete:
instantiated_tracers = map(trace.instantiate_const, tracers)
else:
instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
# This way we record all primitives applied to the inputs (all treated as
# unknown/instantiated) to produce the output. In the context of autodiff,
# that means we record primal, residual, and tangent computations (e.g. sine,
# cosine, and multiply).
in_pvals = [t.pval for t in instantiated_tracers]
in_knowns, in_avals, () = partition_pvals(in_pvals) # all are unknown
assert not any(in_knowns)
f = trace_to_subjaxpr_nounits(f, trace.main, True)
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
consts = remat_call_p.bind(f, **params) # no known inputs
_, out_avals, jaxpr, env = aux()
if jaxpr.effects:
raise NotImplementedError(
'Effects not supported in partial-eval of `checkpoint`/`remat`.')
env_tracers = map(trace.full_raise, env)
jaxpr = convert_constvars_jaxpr(jaxpr)
del in_pvals, in_knowns, in_avals, out_avals, f, aux, env
# When concrete=True, we could avoid some redundant computation by extracting
# values from any ConcreteArrays in `out_avals`, but we eschew that
# optimization.
# We're done with `f`, and in the steps to follow we work with `jaxpr`. To
# that end, we want a list of which inputs to `jaxpr` are known/unknown.
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
if params['policy']:
# unzip into jaxpr_known and jaxpr_unknown
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
partial_eval_jaxpr_custom(jaxpr, in_unknowns, [True] * len(in_unknowns),
False, False, params['policy'])
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)
# compute known outputs and residuals (hoisted out of a remat_call)
_, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers)
if t.pval.is_known())
_, known_inputs = partition_list(in_used_known, [*consts, *in_consts_])
outs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
known_outputs, res = split_list(outs, [len(out_unknowns)-sum(out_unknowns)])
# set up unknown outputs with a recipe to call remat
res_tracers = map(trace.new_instantiated_const, res)
const_tracers = map(trace.new_instantiated_const, consts)
in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers,
*instantiated_tracers]
_, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
recipe = new_eqn_recipe(in_jaxpr_tracers, unknown_outputs, remat_call_p,
new_params, jaxpr_unknown.effects, source_info_util.current())
for t in unknown_outputs: t.recipe = recipe
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
else:
# TODO(mattjj): this is an old parallel code path, to be deleted once the
# new path is fully functional
# Now that we have a `jaxpr` which represents as much as `f` as possible, we
# want to actually compute known output values. To do that, we first extract
# a `jaxpr_known`, and compute which outputs of `jaxpr` are known/unknown.
jaxpr_known_, _, out_unknowns, res_avals = partial_eval_jaxpr_nounits(
core.ClosedJaxpr(jaxpr, ()), in_unknowns, instantiate=False) # type: ignore
jaxpr_known, () = jaxpr_known_.jaxpr, jaxpr_known_.consts
num_res = len(res_avals)
# Next, we need values for known outputs. To get them, we need to evaluate
# jaxpr_known, minus the residual outputs that we don't need. In addition to
# eliminating residual outputs, we should perform DCE to eliminate the
# computation of those residuals; for example, if the primal program
# includes a sine, jaxpr_known includes both the sine and cosine, yet we
# don't want to compute the cosine here.
known_inputs = consts + [t for t in it.chain(env_tracers, tracers)
if t.pval.is_known()]
num_known_outputs = len(out_unknowns) - sum(out_unknowns)
jaxpr_known, kept_inputs = dce_jaxpr(
jaxpr_known, [True] * num_known_outputs + [False] * num_res)
known_inputs = [x for x, kept in zip(known_inputs, kept_inputs) if kept]
known_outputs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
del jaxpr_known, res_avals, num_res, num_known_outputs, kept_inputs
# We compute unknown outputs by using the full `jaxpr`, though we can prune
# out of it any known outputs and computations and only keep those
# operations we need to compute the unknown outputs.
jaxpr, kept_inputs = dce_jaxpr(jaxpr, out_unknowns)
const_tracers = map(trace.instantiate_const, map(trace.full_raise, consts))
unknown_inputs = [*const_tracers, *env_tracers, *instantiated_tracers]
unknown_inputs = [x for x, kept in zip(unknown_inputs, kept_inputs) if kept]
unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
for x in jaxpr.outvars]
eqn = new_eqn_recipe(unknown_inputs, unknown_outputs, remat_call_p,
dict(params, call_jaxpr=jaxpr,
differentiated=True),
jaxpr.effects, source_info_util.current())
for t in unknown_outputs: t.recipe = eqn
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def partial_eval_jaxpr_custom(
jaxpr: Jaxpr,
in_unknowns: Sequence[bool],
@@ -1400,10 +1279,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, _____, p1, p2:
(p1, dict(p2, differentiated=True)))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
@@ -1494,7 +1369,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule
def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn

tests/api_test.py

@@ -3922,55 +3922,6 @@ class RematTest(jtu.JaxTestCase):
expected = f_lin_expected(3.)
self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skipIf(config.jax_new_checkpoint, "test is for old remat only")
def test_remat_grad_python_control_flow(self):
# See test `test_remat_grad_python_control_flow_static_argnums` for the
# new recommended way to express this computation.
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
def f(x):
x, _ = g(x)
return x
with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
g = api.remat(g, concrete=True)
ans = f(2.)
expected = np.sin(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = api.grad(f)(2.)
expected = np.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skipIf(config.jax_new_checkpoint, "new remat raises error here")
def test_remat_concrete_deprecation_warning(self):
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
_ = api.remat(g, concrete=True)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat warns here")
def test_remat_concrete_deprecation_error(self):
def g(x):
if x > 0:
return lax.sin(x), 3.
else:
return lax.cos(x), 4.
with self.assertRaisesRegex(NotImplementedError, "static_argnums"):
_ = api.remat(g, concrete=True)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat different error")
def test_remat_concrete_error(self):
@api.remat # no static_argnums or concrete
def g(x):
@@ -4055,7 +4006,6 @@ class RematTest(jtu.JaxTestCase):
expected = np.cos(2.)
self.assertAllClose(ans, expected, check_dtypes=False)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
def test_remat_retracing(self):
# This is *not* a very important behavior; remat doesn't need to provide
# caching guarantees with the same importance as jit. But even so, in the
@@ -4079,7 +4029,6 @@ class RematTest(jtu.JaxTestCase):
y.block_until_ready()
self.assertEqual(count, 1)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
def test_remat_static_agnums_retracing(self):
# This is *not* a super important behavior; remat doesn't need to provide
# caching guarantees with the same importance as jit. But even so, in the
@@ -4971,7 +4920,6 @@ class RematTest(jtu.JaxTestCase):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd
@unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
def test_fwd_caching(self):
# see above test also
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
@@ -4981,7 +4929,6 @@ class RematTest(jtu.JaxTestCase):
y.block_until_ready()
self.assertEqual(count[0], 1)
@unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
def test_fwd_caching_static_argnums(self):
# see above test also
identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,))

tests/host_callback_test.py

@@ -1979,7 +1979,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
for grad_func in ["grad", "value_and_grad"]
for use_remat in ["old", "new", "none"]))
def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
if config.jax_new_checkpoint and use_remat == "old": raise SkipTest()
if use_remat == "old": raise SkipTest()
def f(x):
id_print_result = hcb.id_print(x, output_stream=testing_stream)