From d19e34fa4a1d89c02a72dc49896013ba6fc99127 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 16 Aug 2022 12:07:27 -0700 Subject: [PATCH] delete old remat implementation moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py --- CHANGELOG.md | 4 + docs/jep/11830-new-remat-checkpoint.md | 8 +- jax/_src/ad_checkpoint.py | 107 ++++++++++++- jax/_src/api.py | 26 +--- jax/_src/config.py | 7 - jax/_src/lax/control_flow/__init__.py | 4 +- jax/_src/lax/control_flow/remat_impl.py | 152 ------------------- jax/experimental/host_callback.py | 10 -- jax/experimental/jax2tf/jax2tf.py | 11 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 17 +-- jax/interpreters/ad.py | 28 ---- jax/interpreters/partial_eval.py | 126 --------------- tests/api_test.py | 53 ------- tests/host_callback_test.py | 2 +- 14 files changed, 131 insertions(+), 424 deletions(-) delete mode 100644 jax/_src/lax/control_flow/remat_impl.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6731ca334..62a81951b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.17 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main). +* Breaking changes + * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports + the `concrete` option, following the previous version's deprecation; see + [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.16 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md index 51c8219ab..da0adaf18 100644 --- a/docs/jep/11830-new-remat-checkpoint.md +++ b/docs/jep/11830-new-remat-checkpoint.md @@ -19,9 +19,7 @@ As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a ne ## How can I disable the change, and go back to the old behavior for now? -In case you have a problem with this change, it will **temporarily** be possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways: - - +In case you have a problem with this change, **through version `jax==0.3.16`** it is possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways: 1. set the shell environment variable `JAX_NEW_CHECKPOINT=0`; 2. execute `jax.config.update('jax_new_checkpoint', False)`; @@ -29,6 +27,10 @@ In case you have a problem with this change, it will **temporarily** be possible If you need to revert to the old implementation, **please reach out** on a GitHub issue so that we can make the new implementation work for you. +As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer +available. If you have an issue, please reach out on [the issue +tracker](https://github.com/google/jax/issues) so we can help fix it! + ## Why are we doing this? 
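A minimal sketch of the opt-out described in the section above, effective only through `jax==0.3.16` (per the JEP text, the flag is removed in `0.3.17`):

```python
# Only works on jax<=0.3.16; the jax_new_checkpoint flag no longer exists in 0.3.17.
import jax
jax.config.update('jax_new_checkpoint', False)  # or set JAX_NEW_CHECKPOINT=0 in the shell
```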
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 9af0731ad..0e784fd26 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -18,19 +18,23 @@ from typing import Callable, Optional, List, Tuple, Sequence, Set, Union, Any import types from absl import logging +import numpy as np import jax from jax import core from jax import linear_util as lu from jax.interpreters import ad from jax.interpreters import batching -from jax.interpreters import partial_eval as pe from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import xla from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util +from jax._src import util from jax._src import source_info_util from jax._src import traceback_util from jax._src.api_util import flatten_fun, shaped_abstractify +from jax._src.lib.mlir.dialects import mhlo from jax._src.traceback_util import api_boundary from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, safe_zip, merge_lists, weakref_lru_cache) @@ -38,9 +42,6 @@ from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) -# TODO(mattjj): before this can be the standard remat implementation, we must: -# [ ] fix up callers who use the 'concrete' option (now removed) - map = safe_map zip = safe_zip @@ -582,6 +583,104 @@ def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn pe.dce_rules[remat_p] = remat_dce +def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool, + differentiated: bool, is_gpu_platform: bool = False, + **_): + assert not jaxpr.constvars + + if differentiated and prevent_cse: + if jax.config.jax_remat_opt_barrier: + translation_rule = _remat_translation_using_opt_barrier + elif is_gpu_platform: + translation_rule = _remat_translation_using_while + else: + translation_rule = _remat_translation_using_cond + else: + translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) + + return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr) + +def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): + args = _optimization_barrier(args) + return core.eval_jaxpr(jaxpr, (), *args) + +# TODO(mattjj): add core utility for 'create dummy value for this type'? 
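`remat_lowering` above picks the CSE-blocking strategy for a differentiated remat body: the `optimization_barrier` lowering when `jax_remat_opt_barrier` is set (the default), falling back to the while-loop trick on GPU and the `cond` trick elsewhere. A quick way to observe the barrier, assuming the AOT `.lower()`/`.compiler_ir()` API available in JAX of this vintage:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

# Only the differentiated body is protected from CSE, so lower the gradient;
# with the default jax_remat_opt_barrier=True an optimization_barrier op
# should appear in the emitted MHLO.
print(jax.jit(jax.grad(f)).lower(1.0).compiler_ir(dialect="mhlo"))
```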
+def _dummy_like(aval: core.AbstractValue) -> Any: + if aval is core.abstract_token: + return jax.lax.create_token() + elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): + return jax.lax.broadcast(jax.lax.empty(aval.dtype), aval.shape) # type: ignore + else: + raise ValueError(aval) + +def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): + # Implements: + # for(counter=0, result=0; counter < rng(1, 2); counter ++) { + # result = eval_jaxpr(*args) + # } + # The loop carry is a tuple: (counter, result, args) + avals_out = tuple(v.aval for v in jaxpr.outvars) + carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) + def cond(carry): + counter, _, _ = carry + unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=()) + return counter < unif + + def body(carry): + counter, _, args = carry + results = core.eval_jaxpr(jaxpr, (), *args) + return (counter + 1, tuple(results), args) + + carry_res = jax.lax.while_loop(cond, body, carry_init) + return carry_res[1] + +def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): + # Implements: + # if(rng(0, 1) < 2) + # return eval_jaxpr(*args) + # else: + # return 0 + avals_out = tuple(v.aval for v in jaxpr.outvars) + + def remat_comp(*args): + return tuple(core.eval_jaxpr(jaxpr, (), *args)) + def dummy_comp(*args): + return tuple(map(_dummy_like, avals_out)) + + unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=()) + return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) + +mlir.register_lowering( + remat_p, mlir.lower_fun(remat_lowering, multiple_results=True)) +mlir.register_lowering( + remat_p, + mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True), + multiple_results=True), + platform="gpu") + +def _optimization_barrier_abstract_eval(*args): + return args + +def _optimization_barrier_lowering_rule(ctx, *args): + barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in) + flat_barrier_types = util.flatten(barrier_types) + flat_args = mlir.flatten_lowering_ir_args(args) + barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args) + return util.unflatten(barrier_op.results, map(len, barrier_types)) + +def _optimization_barrier(arg): + flat_args, treedef = tree_flatten(arg) + return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args)) + +optimization_barrier_p = core.Primitive('optimization_barrier') +optimization_barrier_p.multiple_results = True +optimization_barrier_p.def_impl( + partial(xla.apply_primitive, optimization_barrier_p)) +optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) +mlir.register_lowering(optimization_barrier_p, + _optimization_barrier_lowering_rule) + + def checkpoint_name(x, name): return name_p.bind(x, name=name) diff --git a/jax/_src/api.py b/jax/_src/api.py index 8a3aeb577..9e34aaf8a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3103,27 +3103,11 @@ def checkpoint(fun: Callable, *, " return f(x)\n" " else:\n" " return g(x)\n" - "\n") - if config.jax_new_checkpoint: - raise NotImplementedError(msg) - else: - warn(msg, DeprecationWarning) - - if config.jax_new_checkpoint: - return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy, - static_argnums=static_argnums) - - @wraps(fun) - @api_boundary - def remat_f(*args, **kwargs): - f, args = _remat_static_argnums(fun, static_argnums, args) - args_flat, in_tree = tree_flatten((args, kwargs)) - flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree) - out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__, - 
concrete=concrete, prevent_cse=prevent_cse, - differentiated=False, policy=policy) - return tree_unflatten(out_tree(), out_flat) - return remat_f + "\n" + "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") + raise NotImplementedError(msg) + return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy, + static_argnums=static_argnums) remat = checkpoint # type: ignore diff --git a/jax/_src/config.py b/jax/_src/config.py index 88bc47dec..3f3afcdf6 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -875,13 +875,6 @@ config.define_bool_state( default=(lib.version >= (0, 3, 6)), help=('Enables using optimization-barrier op for lowering remat.')) -# TODO(mattjj): set default to True, then remove -config.define_bool_state( - name='jax_new_checkpoint', - default=True, - upgrade=True, - help='Whether to use the new jax.checkpoint implementation.') - # TODO(b/205307544): Remove flag once coordination service has rolled out. config.define_bool_state( name='jax_coordination_service', diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index e0d21e322..0a57e560e 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -19,8 +19,6 @@ from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p, scan, scan_bind, scan_p, _scan_impl, while_loop, while_p) from jax._src.lax.control_flow.conditionals import cond, cond_p, switch -from jax._src.lax.control_flow.remat_impl import (remat_impl, - optimization_barrier_p) from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root, _custom_linear_solve_impl, linear_solve_p) @@ -32,3 +30,5 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr, _initial_style_jaxpr, _initial_style_jaxprs_with_common_consts, _check_tree_and_avals) +# TODO(mattjj): fix dependent library which expects optimization_barrier_p here +from jax._src.ad_checkpoint import optimization_barrier_p diff --git a/jax/_src/lax/control_flow/remat_impl.py b/jax/_src/lax/control_flow/remat_impl.py deleted file mode 100644 index ad3465edd..000000000 --- a/jax/_src/lax/control_flow/remat_impl.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
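The `NotImplementedError` raised in `api.py` above points users at `static_argnums`. A hypothetical migration for code that relied on `concrete=True`, with `foo`/`is_training` named after the snippet embedded in that error message and `f`/`g` as stand-in branches:

```python
from functools import partial
import jax
import jax.numpy as jnp

def f(x): return jnp.sin(x)  # stand-ins for the two branches
def g(x): return jnp.cos(x)

# Mark the argument that drives Python control flow as static, so tracing
# sees a concrete value; no concrete=True needed:
@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  if is_training:
    return f(x)
  else:
    return g(x)

print(jax.grad(foo)(2., True))  # differentiates through f: cos(2.)
```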
-"""Module for the remat implementation.""" -from functools import partial - -from typing import Optional - -import jax -from jax import core -from jax import lax -from jax.config import config -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.interpreters import xla -from jax.tree_util import tree_flatten, tree_unflatten -from jax._src import ad_checkpoint -from jax._src import util -from jax._src.util import safe_map, wrap_name -from jax._src.lax.control_flow.conditionals import cond -from jax._src.lib.mlir.dialects import mhlo -from jax._src.lax.control_flow.loops import while_loop -import numpy as np - -_map = safe_map - -def _dummy_remat_result(aval: core.AbstractValue): - """A result that will be discarded""" - if aval is core.abstract_token: - return lax.create_token() - else: - return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape) # type: ignore - -def _remat_translation_using_cond(*args, - jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - avals_out = tuple(ov.aval for ov in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(_map(_dummy_remat_result, avals_out)) - - cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2)) - return cond(cond_pred, remat_comp, dummy_comp, *args) - -def _remat_translation_using_while(*args, - jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - avals_out = tuple(ov.aval for ov in jaxpr.outvars) - dummies_like_result = tuple(_map(_dummy_remat_result, avals_out)) - carry_init = (np.int32(0), dummies_like_result, args) - def cond(carry): - counter, _, _ = carry - return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=()) - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = while_loop(cond, body, carry_init) - return carry_res[1] - - -def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): - args = _optimization_barrier(args) - return core.eval_jaxpr(jaxpr, (), *args) - - -def remat_impl(*args, - call_jaxpr: Optional[core.Jaxpr] = None, - jaxpr: Optional[core.Jaxpr] = None, - prevent_cse: bool, differentiated: bool, - policy, - is_gpu_platform: bool = False, - concrete: bool = False, - name: str = "checkpoint"): - # Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat) - # name is not passed for remat2, defaults to "checkpoint" - # TODO: remove call_jaxpr once we drop the remat call primitive - if jaxpr is None: - jaxpr = call_jaxpr - assert jaxpr is not None - assert not jaxpr.constvars - - del concrete, policy # Unused. 
-  if differentiated and prevent_cse:
-    if config.jax_remat_opt_barrier:
-      translation_rule = _remat_translation_using_opt_barrier
-    elif is_gpu_platform:
-      translation_rule = _remat_translation_using_while
-    else:
-      translation_rule = _remat_translation_using_cond
-  else:
-    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)
-
-  return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)
-
-for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p):  # type: ignore
-  mlir.register_lowering(remat_primitive,
-                         mlir.lower_fun(remat_impl, multiple_results=True))
-  mlir.register_lowering(remat_primitive,
-                         mlir.lower_fun(partial(remat_impl,
-                                                is_gpu_platform=True),
-                                        multiple_results=True),
-                         platform="gpu")
-
-
-def _optimization_barrier_abstract_eval(*args):
-  return args
-
-
-def _optimization_barrier_lowering_rule(ctx, *args):
-  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
-  flat_barrier_types = util.flatten(barrier_types)
-
-  flat_args = mlir.flatten_lowering_ir_args(args)
-  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
-  return util.unflatten(barrier_op.results, _map(len, barrier_types))
-
-
-def _optimization_barrier(arg):
-  flat_args, treedef = tree_flatten(arg)
-  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
-
-
-optimization_barrier_p = core.Primitive('optimization_barrier')
-optimization_barrier_p.multiple_results = True
-optimization_barrier_p.def_impl(
-    partial(xla.apply_primitive, optimization_barrier_p))
-optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
-mlir.register_lowering(optimization_barrier_p,
-                       _optimization_barrier_lowering_rule)
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index dc8eddef5..70bc1dab3 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -1615,16 +1615,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 # cased to just pass-through the token
                 in_axes=eqn.params["in_axes"] + (None, None),
                 out_axes=eqn.params["out_axes"] + (0, 0))))
-  elif eqn.primitive is pe.remat_call_p:
-    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
-    eqns.append(
-        eqn.replace(
-            invars=eqn.invars + [input_token_var, input_itoken_var],
-            outvars=eqn.outvars + [output_token_var, output_itoken_var],
-            params=dict(
-                eqn.params,
-                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
-            )))
   elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
     fun_jaxpr = eqn.params["fun_jaxpr"]
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 148522d64..e07b6dc69 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1070,9 +1070,9 @@ class TensorFlowTrace(core.Trace):
 
   def post_process_call(self, call_primitive: core.Primitive,
                         out_tracers: Sequence[TensorFlowTracer], params):
-    # We encountered a call primitive, e.g., remat_call_p, whose result
-    # (out_tracers) include TensorFlowTracer that were not passed through
-    # its arguments (captured from the environment).
+    # We encountered a call primitive whose results (out_tracers) include
+    # TensorFlowTracers that were not passed through its arguments (captured
+    # from the environment).
vals = tuple(t.val for t in out_tracers) main = self.main @@ -1137,8 +1137,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): # Call primitives are inlined -for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, - partial_eval.remat_call_p, maps.xmap_p]: +for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]: tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. @@ -2549,7 +2548,7 @@ tf_impl_with_avals[lax.scan_p] = _convert_jax_impl( extra_name_stack="scan") tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(lax_control_flow.remat_impl, + _convert_jax_impl(partial(ad_checkpoint.remat_lowering, # TODO: jax2tf cannot discriminate by platform is_gpu_platform=False), multiple_results=True, diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 804d98cde..6c4e76cec 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -42,7 +42,7 @@ import jax._src.lib.xla_bridge import numpy as np import tensorflow as tf # type: ignore[import] # pylint: disable=g-direct-tensorflow-import -from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] +from google3.third_party.tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] # pylint: enable=g-direct-tensorflow-import config.parse_flags_with_absl() @@ -766,16 +766,13 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): self.TransformConvertAndCompare(f, arg, "grad") self.TransformConvertAndCompare(f, arg, "grad_vmap") - @parameterized.named_parameters(jtu.cases_from_list( - dict(testcase_name=f"_{flavor}", flavor=flavor) - for flavor in ["old", "new"])) - def test_remat(self, flavor="old"): + def test_remat(self): def f(x1): x2 = jnp.sin(x1) x3 = jnp.sin(x2) x4 = jnp.sin(x3) return x4 - remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f) + remat_f = ad_checkpoint.checkpoint(f) # The computation of grad_f computes "sin" 5 times, 3 for the forward pass # and then to rematerialize "x2" and "x3" in the backward pass. 
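A rough jaxpr-level check of the count in the comment above (a sketch: the expected 5 assumes the default policy-free checkpoint, and the string match depends on jaxpr pretty-printing):

```python
import jax
import jax.numpy as jnp
from jax import ad_checkpoint

def f(x1):
  x2 = jnp.sin(x1)
  x3 = jnp.sin(x2)
  x4 = jnp.sin(x3)
  return x4

remat_f = ad_checkpoint.checkpoint(f)
jaxpr = jax.make_jaxpr(jax.grad(remat_f))(3.)
print(str(jaxpr).count(" sin "))  # expect 5: 3 forward, 2 rematerialized
```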
@@ -783,21 +780,19 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): # Check that we have a Sin under a conditional f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False) f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def() - if flavor == "old": - raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat") if jax.config.jax_remat_opt_barrier: if config.jax2tf_default_experimental_native_lowering: self.assertRegex( str(f_tf_graph), r"mhlo.optimization_barrier") else: self.assertRegex( - str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier") + str(f_tf_graph), r"XlaOptimizationBarrier") elif config.jax_experimental_name_stack: self.assertRegex(str(f_tf_graph), - r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin') + r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') else: self.assertRegex(str(f_tf_graph), - r'remat_checkpoint_/switch_case/indexed_case/Sin') + r'switch_case/indexed_case/Sin') def test_remat_free_var(self): def f(x): diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index ca7919f8c..47acf7cef 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -637,39 +637,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes): primitive_transposes[core.closed_call_p] = _closed_call_transpose -def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, - cotangent_in_avals, reduce_axes): - unknowns = map(is_undefined_primal, primals_in) - primal_jaxpr, tangent_jaxpr, _, _ = \ - pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr), - unknowns=unknowns, instantiate=True) # type: ignore - args, in_tree = tree_flatten((primals_in, cotangents_in)) - transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr, - tangent_jaxpr, reduce_axes) - flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree) - flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params) - return tree_unflatten(out_tree(), flat_cotangents_out) -primitive_transposes[pe.remat_call_p] = remat_transpose - -def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes, - primals_tangents_in, cotangents_in): - primals_in = [x for x in primals_tangents_in if not is_undefined_primal(x)] - tangents_in = [x for x in primals_tangents_in if is_undefined_primal(x)] - res = core.jaxpr_as_fun(primal_jaxpr)(*primals_in) - cotangents_out_ = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (), - (*res, *tangents_in), cotangents_in) - cotangents_out = iter(cotangents_out_[len(res):]) - outs = [next(cotangents_out) if is_undefined_primal(x) else Zero.from_value(x) - for x in primals_tangents_in] - assert next(cotangents_out, None) is None - return outs - @lu.transformation_with_aux def nonzero_outputs(*args, **kwargs): results = yield args, kwargs yield results, [type(r) is not Zero for r in results] - def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 695dc4f4e..51ef8f258 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1118,127 +1118,6 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate): return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals -remat_call_p: Primitive = core.CallPrimitive('remat_call') 
-remat_call = remat_call_p.bind -remat_call_p.def_impl(core.call_impl) - -def _remat_partial_eval(trace, _, f, tracers, params): - concrete = params['concrete'] - - # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of - # the function being called, not just for the unknown parts. To do that, we - # instantiate all the input tracers as constants in the jaxpr being formed. - # Those tracers might have concrete avals, and doing abstract interpretation - # on concrete avals engenders a tradeoff: it allows data-dependent Python - # control flow to work, but it can in some cases lead to redundant FLOPs (done - # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the - # `concrete` parameter to switch this behavior, and if `concrete` is False - # then we raise the avals to the Shaped level. - if concrete: - instantiated_tracers = map(trace.instantiate_const, tracers) - else: - instantiated_tracers = map(trace.instantiate_const_abstracted, tracers) - - # Using the instantiated tracers, run call_bind like JaxprTrace.process_call. - # This way we record all primitives applied to the inputs (all treated as - # unknown/instantiated) to produce the output. In the context of autodiff, - # that means we record primal, residual, and tangent computations (e.g. sine, - # cosine, and multiply). - in_pvals = [t.pval for t in instantiated_tracers] - in_knowns, in_avals, () = partition_pvals(in_pvals) # all are unknown - assert not any(in_knowns) - f = trace_to_subjaxpr_nounits(f, trace.main, True) - f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) - consts = remat_call_p.bind(f, **params) # no known inputs - _, out_avals, jaxpr, env = aux() - if jaxpr.effects: - raise NotImplementedError( - 'Effects not supported in partial-eval of `checkpoint`/`remat`.') - env_tracers = map(trace.full_raise, env) - jaxpr = convert_constvars_jaxpr(jaxpr) - del in_pvals, in_knowns, in_avals, out_avals, f, aux, env - # When concrete=True, we could avoid some redundant computation by extracting - # values from any ConcreteArrays in `out_avals`, but we eschew that - # optimization. - - # We're done with `f`, and in the steps to follow we work with `jaxpr`. To - # that end, we want a list of which inputs to `jaxpr` are known/unknown. 
- in_unknowns = ([False] * len(consts) + - [not t.is_known() for t in it.chain(env_tracers, tracers)]) - - if params['policy']: - # unzip into jaxpr_known and jaxpr_unknown - jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \ - partial_eval_jaxpr_custom(jaxpr, in_unknowns, [True] * len(in_unknowns), - False, False, params['policy']) - jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars)) - _, used_outs_unknown = partition_list(out_inst, out_unknowns) - jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown) - - # compute known outputs and residuals (hoisted out of a remat_call) - _, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers) - if t.pval.is_known()) - _, known_inputs = partition_list(in_used_known, [*consts, *in_consts_]) - outs = core.eval_jaxpr(jaxpr_known, (), *known_inputs) - known_outputs, res = split_list(outs, [len(out_unknowns)-sum(out_unknowns)]) - - # set up unknown outputs with a recipe to call remat - res_tracers = map(trace.new_instantiated_const, res) - const_tracers = map(trace.new_instantiated_const, consts) - in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers, - *instantiated_tracers] - _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers) - unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None) - for x in jaxpr_unknown.outvars] - new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True) - recipe = new_eqn_recipe(in_jaxpr_tracers, unknown_outputs, remat_call_p, - new_params, jaxpr_unknown.effects, source_info_util.current()) - for t in unknown_outputs: t.recipe = recipe - return merge_lists(out_unknowns, known_outputs, unknown_outputs) - else: - # TODO(mattjj): this is an old parallel code path, to be deleted once the - # new path is fully functional - - # Now that we have a `jaxpr` which represents as much as `f` as possible, we - # want to actually compute known output values. To do that, we first extract - # a `jaxpr_known`, and compute which outputs of `jaxpr` are known/unknown. - jaxpr_known_, _, out_unknowns, res_avals = partial_eval_jaxpr_nounits( - core.ClosedJaxpr(jaxpr, ()), in_unknowns, instantiate=False) # type: ignore - jaxpr_known, () = jaxpr_known_.jaxpr, jaxpr_known_.consts - num_res = len(res_avals) - # Next, we need values for known outputs. To get them, we need to evaluate - # jaxpr_known, minus the residual outputs that we don't need. In addition to - # eliminating residual outputs, we should perform DCE to eliminate the - # computation of those residuals; for example, if the primal program - # includes a sine, jaxpr_known includes both the sine and cosine, yet we - # don't want to compute the cosine here. - known_inputs = consts + [t for t in it.chain(env_tracers, tracers) - if t.pval.is_known()] - num_known_outputs = len(out_unknowns) - sum(out_unknowns) - jaxpr_known, kept_inputs = dce_jaxpr( - jaxpr_known, [True] * num_known_outputs + [False] * num_res) - known_inputs = [x for x, kept in zip(known_inputs, kept_inputs) if kept] - known_outputs = core.eval_jaxpr(jaxpr_known, (), *known_inputs) - del jaxpr_known, res_avals, num_res, num_known_outputs, kept_inputs - - # We compute unknown outputs by using the full `jaxpr`, though we can prune - # out of it any known outputs and computations and only keep those - # operations we need to compute the unknown outputs. 
- jaxpr, kept_inputs = dce_jaxpr(jaxpr, out_unknowns) - const_tracers = map(trace.instantiate_const, map(trace.full_raise, consts)) - unknown_inputs = [*const_tracers, *env_tracers, *instantiated_tracers] - unknown_inputs = [x for x, kept in zip(unknown_inputs, kept_inputs) if kept] - unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None) - for x in jaxpr.outvars] - eqn = new_eqn_recipe(unknown_inputs, unknown_outputs, remat_call_p, - dict(params, call_jaxpr=jaxpr, - differentiated=True), - jaxpr.effects, source_info_util.current()) - for t in unknown_outputs: t.recipe = eqn - - return merge_lists(out_unknowns, known_outputs, unknown_outputs) -call_partial_eval_rules[remat_call_p] = _remat_partial_eval - def partial_eval_jaxpr_custom( jaxpr: Jaxpr, in_unknowns: Sequence[bool], @@ -1400,10 +1279,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \ partial_eval_jaxpr_custom_rules[core.named_call_p] = \ partial(call_partial_eval_custom_rule, 'call_jaxpr', lambda _, __, ___, ____, _____, x, y: (x, y)) -partial_eval_jaxpr_custom_rules[remat_call_p] = \ - partial(call_partial_eval_custom_rule, 'call_jaxpr', - lambda _, __, ___, ____, _____, p1, p2: - (p1, dict(p2, differentiated=True))) def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]: @@ -1494,7 +1369,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn return used_inputs, new_eqn dce_rules[core.call_p] = dce_jaxpr_call_rule dce_rules[core.named_call_p] = dce_jaxpr_call_rule -dce_rules[remat_call_p] = dce_jaxpr_call_rule def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn diff --git a/tests/api_test.py b/tests/api_test.py index c3901b7b6..f38a36dae 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3922,55 +3922,6 @@ class RematTest(jtu.JaxTestCase): expected = f_lin_expected(3.) self.assertAllClose(ans, expected, check_dtypes=False) - @unittest.skipIf(config.jax_new_checkpoint, "test is for old remat only") - def test_remat_grad_python_control_flow(self): - # See test `test_remat_grad_python_control_flow_static_argnums` for the - # new recommended way to express this computation. - - def g(x): - if x > 0: - return lax.sin(x), 3. - else: - return lax.cos(x), 4. - - def f(x): - x, _ = g(x) - return x - - with self.assertWarnsRegex(DeprecationWarning, "static_argnums"): - g = api.remat(g, concrete=True) - - ans = f(2.) - expected = np.sin(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(f)(2.) - expected = np.cos(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - @unittest.skipIf(config.jax_new_checkpoint, "new remat raises error here") - def test_remat_concrete_deprecation_warning(self): - def g(x): - if x > 0: - return lax.sin(x), 3. - else: - return lax.cos(x), 4. - - with self.assertWarnsRegex(DeprecationWarning, "static_argnums"): - _ = api.remat(g, concrete=True) - - @unittest.skipIf(not config.jax_new_checkpoint, "old remat warns here") - def test_remat_concrete_deprecation_error(self): - def g(x): - if x > 0: - return lax.sin(x), 3. - else: - return lax.cos(x), 4. - - with self.assertRaisesRegex(NotImplementedError, "static_argnums"): - _ = api.remat(g, concrete=True) - - @unittest.skipIf(not config.jax_new_checkpoint, "old remat different error") def test_remat_concrete_error(self): @api.remat # no static_argnums or concrete def g(x): @@ -4055,7 +4006,6 @@ class RematTest(jtu.JaxTestCase): expected = np.cos(2.) 
self.assertAllClose(ans, expected, check_dtypes=False) - @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here") def test_remat_retracing(self): # This is *not* a very important behavior; remat doesn't need to provide # caching guarantees with the same importance as jit. But even so, in the @@ -4079,7 +4029,6 @@ class RematTest(jtu.JaxTestCase): y.block_until_ready() self.assertEqual(count, 1) - @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here") def test_remat_static_agnums_retracing(self): # This is *not* a super important behavior; remat doesn't need to provide # caching guarantees with the same importance as jit. But even so, in the @@ -4971,7 +4920,6 @@ class RematTest(jtu.JaxTestCase): f_vjp(1.)[0].block_until_ready() self.assertEqual(count[0], 1) # fwd execute_trivial, backward_pass on bwd - @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here") def test_fwd_caching(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) @@ -4981,7 +4929,6 @@ class RematTest(jtu.JaxTestCase): y.block_until_ready() self.assertEqual(count[0], 1) - @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here") def test_fwd_caching_static_argnums(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,)) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 5c3cdd434..de71d4bf5 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1979,7 +1979,7 @@ class HostCallbackTapTest(jtu.JaxTestCase): for grad_func in ["grad", "value_and_grad"] for use_remat in ["old", "new", "none"])) def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): - if config.jax_new_checkpoint and use_remat == "old": raise SkipTest() + if use_remat == "old": raise SkipTest() def f(x): id_print_result = hcb.id_print(x, output_stream=testing_stream)
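Note that the surviving implementation keeps the `policy` machinery intact (see `new_checkpoint(..., policy=policy, ...)` in `api.py` above), including name-based policies built on the `checkpoint_name`/`name_p` plumbing retained in `ad_checkpoint.py`. A minimal sketch of a name-based policy, assuming the `jax.checkpoint_policies.save_only_these_names` helper from this era of JAX:

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Save only the value tagged 'hidden' as a residual; everything else is
# recomputed on the backward pass.
@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.save_only_these_names('hidden'))
def layer(w, x):
  h = checkpoint_name(jnp.tanh(w @ x), 'hidden')
  return jnp.tanh(w @ h)

w = jnp.ones((4, 4))
x = jnp.ones(4)
print(jax.grad(lambda w: layer(w, x).sum())(w).shape)  # (4, 4)
```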