Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
parent 332d7d0168
commit d19e34fa4a
@@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.17 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main).
* Breaking changes
  * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
    the `concrete` option, following the previous version's deprecation; see
    [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).

## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
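
For readers hit by this breaking change, a minimal migration sketch (illustrative only, not part of this commit): where old code passed `concrete=True` so that Python control flow could branch on an argument's value, the recommended replacement is to mark that argument static via `static_argnums`. The names `step` and `is_training` below are hypothetical.

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Old (no longer supported): jax.remat(step, concrete=True)
    # New: declare the Python-branching argument as static.
    @partial(jax.checkpoint, static_argnums=(1,))
    def step(x, is_training):
      if is_training:            # Python-level branch on a static argument
        return jnp.sin(jnp.sin(x))
      else:
        return jnp.cos(x)

    print(jax.grad(step)(1.0, True))
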
@@ -19,9 +19,7 @@ As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a ne

## How can I disable the change, and go back to the old behavior for now?

In case you have a problem with this change, it will **temporarily** be possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways:

In case you have a problem with this change, **through version `jax==0.3.16`** it is possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways:

1. set the shell environment variable `JAX_NEW_CHECKPOINT=0`;
2. execute `jax.config.update('jax_new_checkpoint', False)`;
@@ -29,6 +27,10 @@ In case you have a problem with this change, it will **temporarily** be possible

If you need to revert to the old implementation, **please reach out** on a GitHub issue so that we can make the new implementation work for you.

As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
tracker](https://github.com/google/jax/issues) so we can help fix it!


## Why are we doing this?
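
A compact illustration of the opt-out paths listed above (only meaningful on `jax<=0.3.16`, where the flag still exists; on 0.3.17 the flag is gone):

    # In the shell, before starting Python:
    #   export JAX_NEW_CHECKPOINT=0
    # Or at the top of a program:
    import jax
    jax.config.update('jax_new_checkpoint', False)  # revert to the old remat
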
@@ -18,19 +18,23 @@ from typing import Callable, Optional, List, Tuple, Sequence, Set, Union, Any
import types

from absl import logging
import numpy as np

import jax
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                           safe_zip, merge_lists, weakref_lru_cache)
@@ -38,9 +42,6 @@ from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# TODO(mattjj): before this can be the standard remat implementation, we must:
#   [ ] fix up callers who use the 'concrete' option (now removed)

map = safe_map
zip = safe_zip
@@ -582,6 +583,104 @@ def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
pe.dce_rules[remat_p] = remat_dce


def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
                   differentiated: bool, is_gpu_platform: bool = False,
                   **_):
  assert not jaxpr.constvars

  if differentiated and prevent_cse:
    if jax.config.jax_remat_opt_barrier:
      translation_rule = _remat_translation_using_opt_barrier
    elif is_gpu_platform:
      translation_rule = _remat_translation_using_while
    else:
      translation_rule = _remat_translation_using_cond
  else:
    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

  return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)

def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  args = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *args)

# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
  if aval is core.abstract_token:
    return jax.lax.create_token()
  elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
    return jax.lax.broadcast(jax.lax.empty(aval.dtype), aval.shape)  # type: ignore
  else:
    raise ValueError(aval)

def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
  # Implements:
  #  for(counter=0, result=0; counter < rng(1, 2); counter ++) {
  #     result = eval_jaxpr(*args)
  #  }
  # The loop carry is a tuple: (counter, result, args)
  avals_out = tuple(v.aval for v in jaxpr.outvars)
  carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
  def cond(carry):
    counter, _, _ = carry
    unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
    return counter < unif

  def body(carry):
    counter, _, args = carry
    results = core.eval_jaxpr(jaxpr, (), *args)
    return (counter + 1, tuple(results), args)

  carry_res = jax.lax.while_loop(cond, body, carry_init)
  return carry_res[1]

def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
  # Implements:
  #  if(rng(0, 1) < 2)
  #    return eval_jaxpr(*args)
  #  else:
  #    return 0
  avals_out = tuple(v.aval for v in jaxpr.outvars)

  def remat_comp(*args):
    return tuple(core.eval_jaxpr(jaxpr, (), *args))
  def dummy_comp(*args):
    return tuple(map(_dummy_like, avals_out))

  unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
  return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)

mlir.register_lowering(
    remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
mlir.register_lowering(
    remat_p,
    mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True),
                   multiple_results=True),
    platform="gpu")

def _optimization_barrier_abstract_eval(*args):
  return args

def _optimization_barrier_lowering_rule(ctx, *args):
  barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)
  flat_args = mlir.flatten_lowering_ir_args(args)
  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  return util.unflatten(barrier_op.results, map(len, barrier_types))

def _optimization_barrier(arg):
  flat_args, treedef = tree_flatten(arg)
  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))

optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
                       _optimization_barrier_lowering_rule)


def checkpoint_name(x, name):
  return name_p.bind(x, name=name)
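
For orientation (not part of the diff), here is a minimal sketch of how the lowering above gets exercised: when a differentiated, checkpointed function is compiled, the rematerialized computation goes through whichever CSE-prevention path `remat_lowering` selects (the optimization barrier when `jax_remat_opt_barrier` is on, which is the default on recent jaxlib). The function `loss` below is a made-up example.

    import jax
    import jax.numpy as jnp

    def loss(x):
      return jnp.sum(jnp.sin(jnp.sin(x)))

    # Differentiating a checkpointed function leaves a remat equation in the
    # jaxpr; compiling it invokes the registered remat_lowering rule.
    grad_fn = jax.jit(jax.grad(jax.checkpoint(loss)))
    print(grad_fn(jnp.ones(3)))
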
@@ -3103,27 +3103,11 @@ def checkpoint(fun: Callable, *,
           "      return f(x)\n"
           "    else:\n"
           "      return g(x)\n"
           "\n")
    if config.jax_new_checkpoint:
      raise NotImplementedError(msg)
    else:
      warn(msg, DeprecationWarning)

  if config.jax_new_checkpoint:
    return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
                          static_argnums=static_argnums)

  @wraps(fun)
  @api_boundary
  def remat_f(*args, **kwargs):
    f, args = _remat_static_argnums(fun, static_argnums, args)
    args_flat, in_tree = tree_flatten((args, kwargs))
    flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
    out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
                             concrete=concrete, prevent_cse=prevent_cse,
                             differentiated=False, policy=policy)
    return tree_unflatten(out_tree(), out_flat)
  return remat_f
           "\n"
           "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
    raise NotImplementedError(msg)
  return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
                        static_argnums=static_argnums)
remat = checkpoint  # type: ignore
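
As context for the `policy=` argument forwarded to `new_checkpoint` above, a short hedged sketch of how a caller supplies a rematerialization policy (illustrative; the `block` function is made up):

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Save only dot-product results from the forward pass; recompute the rest.
    policy = jax.checkpoint_policies.checkpoint_dots

    @partial(jax.checkpoint, policy=policy)
    def block(w, x):
      return jnp.tanh(jnp.dot(w, x))

    y, vjp_fn = jax.vjp(block, jnp.ones((4, 4)), jnp.ones(4))
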
@@ -875,13 +875,6 @@ config.define_bool_state(
    default=(lib.version >= (0, 3, 6)),
    help=('Enables using optimization-barrier op for lowering remat.'))

# TODO(mattjj): set default to True, then remove
config.define_bool_state(
    name='jax_new_checkpoint',
    default=True,
    upgrade=True,
    help='Whether to use the new jax.checkpoint implementation.')

# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(
    name='jax_coordination_service',
@@ -19,8 +19,6 @@ from jax._src.lax.control_flow.loops import (associative_scan, cummax, cummax_p,
                                             scan, scan_bind, scan_p,
                                             _scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.remat_impl import (remat_impl,
                                                  optimization_barrier_p)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
                                              _custom_linear_solve_impl,
                                              linear_solve_p)
@@ -32,3 +30,5 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
                                              _initial_style_jaxpr,
                                              _initial_style_jaxprs_with_common_consts,
                                              _check_tree_and_avals)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.ad_checkpoint import optimization_barrier_p
@@ -1,152 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the remat implementation."""
from functools import partial

from typing import Optional

import jax
from jax import core
from jax import lax
from jax.config import config
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_checkpoint
from jax._src import util
from jax._src.util import safe_map, wrap_name
from jax._src.lax.control_flow.conditionals import cond
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lax.control_flow.loops import while_loop
import numpy as np

_map = safe_map

def _dummy_remat_result(aval: core.AbstractValue):
  """A result that will be discarded"""
  if aval is core.abstract_token:
    return lax.create_token()
  else:
    return lax.broadcast(np.array(0, dtype=aval.dtype), aval.shape)  # type: ignore

def _remat_translation_using_cond(*args,
                                  jaxpr: core.Jaxpr):
  # Implements:
  #  if(rng(0, 1) < 2)
  #    return eval_jaxpr(*args)
  #  else:
  #    return 0
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)

  def remat_comp(*args):
    return tuple(core.eval_jaxpr(jaxpr, (), *args))
  def dummy_comp(*args):
    return tuple(_map(_dummy_remat_result, avals_out))

  cond_pred = (lax.rng_uniform(np.float32(0), np.float32(1), shape=()) < np.float32(2))
  return cond(cond_pred, remat_comp, dummy_comp, *args)

def _remat_translation_using_while(*args,
                                   jaxpr: core.Jaxpr):
  # Implements:
  #  for(counter=0, result=0; counter < rng(1, 2); counter ++) {
  #     result = eval_jaxpr(*args)
  #  }
  # The loop carry is a tuple: (counter, result, args)
  avals_out = tuple(ov.aval for ov in jaxpr.outvars)
  dummies_like_result = tuple(_map(_dummy_remat_result, avals_out))
  carry_init = (np.int32(0), dummies_like_result, args)
  def cond(carry):
    counter, _, _ = carry
    return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

  def body(carry):
    counter, _, args = carry
    results = core.eval_jaxpr(jaxpr, (), *args)
    return (counter + 1, tuple(results), args)

  carry_res = while_loop(cond, body, carry_init)
  return carry_res[1]


def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  args = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *args)


def remat_impl(*args,
               call_jaxpr: Optional[core.Jaxpr] = None,
               jaxpr: Optional[core.Jaxpr] = None,
               prevent_cse: bool, differentiated: bool,
               policy,
               is_gpu_platform: bool = False,
               concrete: bool = False,
               name: str = "checkpoint"):
  # Support either "jaxpr" (for remat2) and "call_jaxpr" (for remat)
  # name is not passed for remat2, defaults to "checkpoint"
  # TODO: remove call_jaxpr once we drop the remat call primitive
  if jaxpr is None:
    jaxpr = call_jaxpr
  assert jaxpr is not None
  assert not jaxpr.constvars

  del concrete, policy  # Unused.
  if differentiated and prevent_cse:
    if config.jax_remat_opt_barrier:
      translation_rule = _remat_translation_using_opt_barrier
    elif is_gpu_platform:
      translation_rule = _remat_translation_using_while
    else:
      translation_rule = _remat_translation_using_cond
  else:
    translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

  return jax.named_call(translation_rule, name=wrap_name(name, "remat"))(*args, jaxpr=jaxpr)

for remat_primitive in (pe.remat_call_p, ad_checkpoint.remat_p):  # type: ignore
  mlir.register_lowering(remat_primitive,
                         mlir.lower_fun(remat_impl, multiple_results=True))
  mlir.register_lowering(remat_primitive,
                         mlir.lower_fun(partial(remat_impl,
                                                is_gpu_platform=True),
                                        multiple_results=True),
                         platform="gpu")


def _optimization_barrier_abstract_eval(*args):
  return args


def _optimization_barrier_lowering_rule(ctx, *args):
  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)

  flat_args = mlir.flatten_lowering_ir_args(args)
  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  return util.unflatten(barrier_op.results, _map(len, barrier_types))


def _optimization_barrier(arg):
  flat_args, treedef = tree_flatten(arg)
  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))


optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
                       _optimization_barrier_lowering_rule)
@@ -1615,16 +1615,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
            # cased to just pass-through the token
            in_axes=eqn.params["in_axes"] + (None, None),
            out_axes=eqn.params["out_axes"] + (0, 0))))
  elif eqn.primitive is pe.remat_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    eqns.append(
        eqn.replace(
            invars=eqn.invars + [input_token_var, input_itoken_var],
            outvars=eqn.outvars + [output_token_var, output_itoken_var],
            params=dict(
                eqn.params,
                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
            )))
  elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
@@ -1070,9 +1070,9 @@ class TensorFlowTrace(core.Trace):

  def post_process_call(self, call_primitive: core.Primitive,
                        out_tracers: Sequence[TensorFlowTracer], params):
    # We encountered a call primitive, e.g., remat_call_p, whose result
    # (out_tracers) include TensorFlowTracer that were not passed through
    # its arguments (captured from the environment).
    # We encountered a call primitive whose result (out_tracers) include
    # TensorFlowTracer that were not passed through its arguments (captured from
    # the environment).
    vals = tuple(t.val for t in out_tracers)
    main = self.main

@@ -1137,8 +1137,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):


# Call primitives are inlined
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p,
                   partial_eval.remat_call_p, maps.xmap_p]:
for unexpected in [core.call_p, core.named_call_p, xla.xla_call_p, maps.xmap_p]:
  tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)

# Primitives that are not yet implemented must be explicitly declared here.

@@ -2549,7 +2548,7 @@ tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
    extra_name_stack="scan")

tf_impl_with_avals[ad_checkpoint.remat_p] = \
    _convert_jax_impl(partial(lax_control_flow.remat_impl,
    _convert_jax_impl(partial(ad_checkpoint.remat_lowering,
                              # TODO: jax2tf cannot discriminate by platform
                              is_gpu_platform=False),
                      multiple_results=True,
@@ -42,7 +42,7 @@ import jax._src.lib.xla_bridge
import numpy as np
import tensorflow as tf  # type: ignore[import]
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
from google3.third_party.tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import

config.parse_flags_with_absl()
@@ -766,16 +766,13 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
    self.TransformConvertAndCompare(f, arg, "grad")
    self.TransformConvertAndCompare(f, arg, "grad_vmap")

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{flavor}", flavor=flavor)
      for flavor in ["old", "new"]))
  def test_remat(self, flavor="old"):
  def test_remat(self):
    def f(x1):
      x2 = jnp.sin(x1)
      x3 = jnp.sin(x2)
      x4 = jnp.sin(x3)
      return x4
    remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f)
    remat_f = ad_checkpoint.checkpoint(f)

    # The computation of grad_f computes "sin" 5 times, 3 for the forward pass
    # and then to rematerialize "x2" and "x3" in the backward pass.
@@ -783,21 +780,19 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
    # Check that we have a Sin under a conditional
    f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
    f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
    if flavor == "old":
      raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat")
    if jax.config.jax_remat_opt_barrier:
      if config.jax2tf_default_experimental_native_lowering:
        self.assertRegex(
            str(f_tf_graph), r"mhlo.optimization_barrier")
      else:
        self.assertRegex(
            str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier")
            str(f_tf_graph), r"XlaOptimizationBarrier")
    elif config.jax_experimental_name_stack:
      self.assertRegex(str(f_tf_graph),
                       r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin')
                       r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin')
    else:
      self.assertRegex(str(f_tf_graph),
                       r'remat_checkpoint_/switch_case/indexed_case/Sin')
                       r'switch_case/indexed_case/Sin')

  def test_remat_free_var(self):
    def f(x):
@@ -637,39 +637,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
primitive_transposes[core.closed_call_p] = _closed_call_transpose


def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, _, _ = \
      pe.partial_eval_jaxpr_nounits(pe.close_jaxpr(call_jaxpr),
                                    unknowns=unknowns, instantiate=True)  # type: ignore
  args, in_tree = tree_flatten((primals_in, cotangents_in))
  transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
                                  tangent_jaxpr, reduce_axes)
  flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree)
  flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
primitive_transposes[pe.remat_call_p] = remat_transpose

def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
  primals_in = [x for x in primals_tangents_in if not is_undefined_primal(x)]
  tangents_in = [x for x in primals_tangents_in if is_undefined_primal(x)]
  res = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
  cotangents_out_ = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                                  (*res, *tangents_in), cotangents_in)
  cotangents_out = iter(cotangents_out_[len(res):])
  outs = [next(cotangents_out) if is_undefined_primal(x) else Zero.from_value(x)
          for x in primals_tangents_in]
  assert next(cotangents_out, None) is None
  return outs

@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
  results = yield args, kwargs
  yield results, [type(r) is not Zero for r in results]


def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
@@ -1118,127 +1118,6 @@ def _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate):
  return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals


remat_call_p: Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)

def _remat_partial_eval(trace, _, f, tracers, params):
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  # This way we record all primitives applied to the inputs (all treated as
  # unknown/instantiated) to produce the output. In the context of autodiff,
  # that means we record primal, residual, and tangent computations (e.g. sine,
  # cosine, and multiply).
  in_pvals = [t.pval for t in instantiated_tracers]
  in_knowns, in_avals, () = partition_pvals(in_pvals)  # all are unknown
  assert not any(in_knowns)
  f = trace_to_subjaxpr_nounits(f, trace.main, True)
  f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
  consts = remat_call_p.bind(f, **params)  # no known inputs
  _, out_avals, jaxpr, env = aux()
  if jaxpr.effects:
    raise NotImplementedError(
        'Effects not supported in partial-eval of `checkpoint`/`remat`.')
  env_tracers = map(trace.full_raise, env)
  jaxpr = convert_constvars_jaxpr(jaxpr)
  del in_pvals, in_knowns, in_avals, out_avals, f, aux, env
  # When concrete=True, we could avoid some redundant computation by extracting
  # values from any ConcreteArrays in `out_avals`, but we eschew that
  # optimization.

  # We're done with `f`, and in the steps to follow we work with `jaxpr`. To
  # that end, we want a list of which inputs to `jaxpr` are known/unknown.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])

  if params['policy']:
    # unzip into jaxpr_known and jaxpr_unknown
    jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
        partial_eval_jaxpr_custom(jaxpr, in_unknowns, [True] * len(in_unknowns),
                                  False, False, params['policy'])
    jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
    _, used_outs_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)

    # compute known outputs and residuals (hoisted out of a remat_call)
    _, in_consts_ = unzip2(t.pval for t in it.chain(env_tracers, tracers)
                           if t.pval.is_known())
    _, known_inputs = partition_list(in_used_known, [*consts, *in_consts_])
    outs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
    known_outputs, res = split_list(outs, [len(out_unknowns)-sum(out_unknowns)])

    # set up unknown outputs with a recipe to call remat
    res_tracers = map(trace.new_instantiated_const, res)
    const_tracers = map(trace.new_instantiated_const, consts)
    in_jaxpr_tracers = [*res_tracers, *const_tracers, *env_tracers,
                        *instantiated_tracers]
    _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
    unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
    new_params = dict(params, call_jaxpr=jaxpr_unknown, differentiated=True)
    recipe = new_eqn_recipe(in_jaxpr_tracers, unknown_outputs, remat_call_p,
                            new_params, jaxpr_unknown.effects, source_info_util.current())
    for t in unknown_outputs: t.recipe = recipe
    return merge_lists(out_unknowns, known_outputs, unknown_outputs)
  else:
    # TODO(mattjj): this is an old parallel code path, to be deleted once the
    # new path is fully functional

    # Now that we have a `jaxpr` which represents as much as `f` as possible, we
    # want to actually compute known output values. To do that, we first extract
    # a `jaxpr_known`, and compute which outputs of `jaxpr` are known/unknown.
    jaxpr_known_, _, out_unknowns, res_avals = partial_eval_jaxpr_nounits(
        core.ClosedJaxpr(jaxpr, ()), in_unknowns, instantiate=False)  # type: ignore
    jaxpr_known, () = jaxpr_known_.jaxpr, jaxpr_known_.consts
    num_res = len(res_avals)
    # Next, we need values for known outputs. To get them, we need to evaluate
    # jaxpr_known, minus the residual outputs that we don't need. In addition to
    # eliminating residual outputs, we should perform DCE to eliminate the
    # computation of those residuals; for example, if the primal program
    # includes a sine, jaxpr_known includes both the sine and cosine, yet we
    # don't want to compute the cosine here.
    known_inputs = consts + [t for t in it.chain(env_tracers, tracers)
                             if t.pval.is_known()]
    num_known_outputs = len(out_unknowns) - sum(out_unknowns)
    jaxpr_known, kept_inputs = dce_jaxpr(
        jaxpr_known, [True] * num_known_outputs + [False] * num_res)
    known_inputs = [x for x, kept in zip(known_inputs, kept_inputs) if kept]
    known_outputs = core.eval_jaxpr(jaxpr_known, (), *known_inputs)
    del jaxpr_known, res_avals, num_res, num_known_outputs, kept_inputs

    # We compute unknown outputs by using the full `jaxpr`, though we can prune
    # out of it any known outputs and computations and only keep those
    # operations we need to compute the unknown outputs.
    jaxpr, kept_inputs = dce_jaxpr(jaxpr, out_unknowns)
    const_tracers = map(trace.instantiate_const, map(trace.full_raise, consts))
    unknown_inputs = [*const_tracers, *env_tracers, *instantiated_tracers]
    unknown_inputs = [x for x, kept in zip(unknown_inputs, kept_inputs) if kept]
    unknown_outputs = [JaxprTracer(trace, PartialVal.unknown(x.aval), None)
                       for x in jaxpr.outvars]
    eqn = new_eqn_recipe(unknown_inputs, unknown_outputs, remat_call_p,
                         dict(params, call_jaxpr=jaxpr,
                              differentiated=True),
                         jaxpr.effects, source_info_util.current())
    for t in unknown_outputs: t.recipe = eqn

    return merge_lists(out_unknowns, known_outputs, unknown_outputs)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval

def partial_eval_jaxpr_custom(
    jaxpr: Jaxpr,
    in_unknowns: Sequence[bool],
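
A side note (not part of the diff): the deleted rule above mirrors what the surviving `remat_p` partial-eval rule in ad_checkpoint.py does, namely split a checkpointed jaxpr into a known part evaluated eagerly and an unknown part kept behind a remat equation. One way to see that split with the new implementation is to print the jaxpr of a differentiated, checkpointed function; `f` below is a made-up example.

    import jax
    import jax.numpy as jnp

    def f(x):
      return jnp.sin(jnp.sin(x))

    # The printed jaxpr contains a remat/checkpoint equation wrapping the
    # pieces that will be recomputed on the backward pass.
    print(jax.make_jaxpr(jax.grad(jax.checkpoint(f)))(1.0))
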
@@ -1400,10 +1279,6 @@ partial_eval_jaxpr_custom_rules[core.call_p] = \
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, _____, p1, p2:
            (p1, dict(p2, differentiated=True)))


def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
@@ -1494,7 +1369,6 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
  return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule


def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
@@ -3922,55 +3922,6 @@ class RematTest(jtu.JaxTestCase):
    expected = f_lin_expected(3.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(config.jax_new_checkpoint, "test is for old remat only")
  def test_remat_grad_python_control_flow(self):
    # See test `test_remat_grad_python_control_flow_static_argnums` for the
    # new recommended way to express this computation.

    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    def f(x):
      x, _ = g(x)
      return x

    with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
      g = api.remat(g, concrete=True)

    ans = f(2.)
    expected = np.sin(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = api.grad(f)(2.)
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(config.jax_new_checkpoint, "new remat raises error here")
  def test_remat_concrete_deprecation_warning(self):
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    with self.assertWarnsRegex(DeprecationWarning, "static_argnums"):
      _ = api.remat(g, concrete=True)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat warns here")
  def test_remat_concrete_deprecation_error(self):
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    with self.assertRaisesRegex(NotImplementedError, "static_argnums"):
      _ = api.remat(g, concrete=True)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat different error")
  def test_remat_concrete_error(self):
    @api.remat  # no static_argnums or concrete
    def g(x):
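
The deleted tests above exercised `concrete=True`, which let `g` branch on the value of its argument during tracing; they point at `test_remat_grad_python_control_flow_static_argnums` as the replacement. A minimal, hypothetical sketch of that style (illustrative only, not the actual test in the repository): branch on an argument the caller declares static.

    from functools import partial
    import jax
    from jax import lax

    @partial(jax.checkpoint, static_argnums=(1,))
    def g(x, positive):
      if positive:               # static Python bool, decided by the caller
        return lax.sin(x), 3.
      else:
        return lax.cos(x), 4.

    ans = jax.grad(lambda x: g(x, True)[0])(2.)   # equals np.cos(2.)
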
@@ -4055,7 +4006,6 @@ class RematTest(jtu.JaxTestCase):
    expected = np.cos(2.)
    self.assertAllClose(ans, expected, check_dtypes=False)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
  def test_remat_retracing(self):
    # This is *not* a very important behavior; remat doesn't need to provide
    # caching guarantees with the same importance as jit. But even so, in the
@@ -4079,7 +4029,6 @@ class RematTest(jtu.JaxTestCase):
      y.block_until_ready()
    self.assertEqual(count, 1)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat retraces here")
  def test_remat_static_agnums_retracing(self):
    # This is *not* a super important behavior; remat doesn't need to provide
    # caching guarantees with the same importance as jit. But even so, in the
@@ -4971,7 +4920,6 @@ class RematTest(jtu.JaxTestCase):
      f_vjp(1.)[0].block_until_ready()
    self.assertEqual(count[0], 1)  # fwd execute_trivial, backward_pass on bwd

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
  def test_fwd_caching(self):
    # see above test also
    identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
@@ -4981,7 +4929,6 @@ class RematTest(jtu.JaxTestCase):
      y.block_until_ready()
    self.assertEqual(count[0], 1)

  @unittest.skipIf(not config.jax_new_checkpoint, "old remat recompiles here")
  def test_fwd_caching_static_argnums(self):
    # see above test also
    identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,))
@@ -1979,7 +1979,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
      for grad_func in ["grad", "value_and_grad"]
      for use_remat in ["old", "new", "none"]))
  def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"):
    if config.jax_new_checkpoint and use_remat == "old": raise SkipTest()
    if use_remat == "old": raise SkipTest()

    def f(x):
      id_print_result = hcb.id_print(x, output_stream=testing_stream)