simplify remat partial eval parameterization

The main win here is reducing the number of arguments for the function
that parameterizes _remat_partial_eval (so it can be used both with
remat and invertible ad features).

I also included a fix to _remat_partial_eval that is needed in #3370,
though I don't think it's needed on master. It was easier to include the
fix now.

Both these changes made rebasing #3370 easier!
This commit is contained in:
Matthew Johnson 2020-06-15 18:42:53 -07:00
parent 6bcf0568aa
commit 05e0716f40
3 changed files with 225 additions and 215 deletions

View File

@ -22,7 +22,8 @@ from jax import linear_util as lu
from . import ad
from . import partial_eval as pe
from .partial_eval import (PartialVal, partial_eval_jaxpr,
JaxprTracer, ConstVar, convert_constvars_jaxpr, new_eqn_recipe)
JaxprTracer, ConstVar, convert_constvars_jaxpr,
new_eqn_recipe, _partition_knowns)
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..custom_derivatives import _initial_style_jaxpr, _resolve_kwargs
from ..api_util import flatten_fun_nokwargs
@ -43,24 +44,25 @@ invertible_call_p.def_custom_bind(invertible_call)
invertible_call_p.def_impl(core.call_impl)
invertible_call_p.multiple_results = True
def _invertible_call_make_output_tracers(trace, typed_jaxpr, in_tracers, out_known_pvals, out_unknown_pvals, _, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
# Add dummy arguments representing the outputs to the jaxpr. Those should remain unused in case
# the expression actually ends up being evaluated, but they make it well-formed.
out_known_avals = tuple(raise_to_shaped(get_aval(pval.get_known())) for pval in out_known_pvals)
lifted_jaxpr = _append_invars(lifted_jaxpr, out_known_avals)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
# We also append some dummy outputs that correspond to the known outputs we left in the call_jaxpr
dummy_outputs = [JaxprTracer(trace, pval, core.unit) for pval in out_known_pvals]
def _invertible_call_make_output_tracers(trace, in_tracers, out_tracers, params):
uks = [not t.pval.is_known() for t in out_tracers]
out_tracers_known, out_tracers_unknown = _partition_knowns(out_tracers, uks)
output_constants = [JaxprTracer(trace, pval, ConstVar(pval.get_known())) for pval in out_known_pvals]
eqn = new_eqn_recipe(tuple(it.chain(in_tracers, output_constants)),
dummy_outputs + unknown_output_tracers,
invertible_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers
# Add dummy arguments representing the outputs to the jaxpr. Those should
# remain unused if the expression is evaluated, but they make it well-formed.
out_known_avals = [raise_to_shaped(t.pval.get_aval()) for t in out_tracers_known]
out_consts = [trace.instantiate_const(t) for t in out_tracers_known]
new_jaxpr = _append_invars(params['call_jaxpr'], tuple(out_known_avals))
new_in_tracers = (*in_tracers, *out_consts)
# Append dummy outputs that correspond to known outputs left in the call_jaxpr
dummy_outputs = [JaxprTracer(trace, t.pval, core.unit) for t in out_tracers_known]
new_out_tracers = (*dummy_outputs, *out_tracers_unknown)
eqn = new_eqn_recipe(new_in_tracers, new_out_tracers, invertible_call_p,
dict(params, call_jaxpr=new_jaxpr))
for t in out_tracers_unknown: t.recipe = eqn
return new_out_tracers
pe.call_partial_eval_rules[invertible_call_p] = partial(
pe._remat_partial_eval, _invertible_call_make_output_tracers)
@ -69,7 +71,8 @@ pe.call_partial_eval_rules[invertible_call_p] = partial(
@cache()
def _append_invars(jaxpr, avals):
newvar = core.gensym([jaxpr])
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals), jaxpr.outvars, jaxpr.eqns)
return core.Jaxpr(jaxpr.constvars, jaxpr.invars + map(newvar, avals),
jaxpr.outvars, jaxpr.eqns)
def _invertible_call_transpose(params, call_jaxpr, args, ct, _):

View File

@ -680,7 +680,12 @@ remat_call_p.def_custom_bind(remat_call)
remat_call_p.def_impl(core.call_impl)
remat_call_p.multiple_results = True
def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# We reuse the _remat_partial_eval function both for remat_call and for
# invertible_call, both of which in a sense stage out operations to
# rematerialize values. The two usages differ only in details of what jaxpr eqn
# and output tracers are formed. As a result we parameterize _remat_partial_eval
# by a `process_out` function.
def _remat_partial_eval(process_out, trace, _, f, tracers, params):
concrete = params['concrete']
# Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
@ -699,24 +704,27 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvals = [t.pval for t in instantiated_tracers]
with core.initial_style_staging():
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params))
# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
const_tracers = map(trace.new_instantiated_const, consts)
# Since we traced with everything marked as unknown, but we need to know which
# outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in env_tracers]
+ [raise_to_shaped(pval.get_aval()) for pval in in_pvals])
in_avals = ([raise_to_shaped(t.pval.get_aval()) for t in const_tracers] +
[raise_to_shaped(t.pval.get_aval()) for t in env_tracers] +
[raise_to_shaped(pval.get_aval()) for pval in in_pvals])
out_avals = [raise_to_shaped(abstract_unit if var is unitvar
else get_aval(var.val) if type(var) is Literal
else pval.get_aval())
for var, pval in zip(jaxpr.outvars, eval_out_pvals)]
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
in_unknowns = [not t.is_known() for t in it.chain(env_tracers, tracers)]
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns,
instantiate=False,
trace_type=trace.master.trace_type)
typed_jaxpr = core.TypedJaxpr(jaxpr, (), in_avals, out_avals)
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
typed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.master.trace_type)
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
@ -728,56 +736,52 @@ def _remat_partial_eval(wrap_unknown_pvals, trace, _, f, tracers, params):
# values. For the use case of inverse-mode ad in op-by-op ("eager mode")
# evaluation, all the primal outputs should be concrete (thus not recomputed).
to_compute = [type(pval[0]) is not ConcreteArray
for uk, pval in zip(out_unknowns, eval_out_pvals)
if not uk]
for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
num_outputs = len(jaxpr_unknown.out_avals)
num_res = len(jaxpr_known.out_avals) - num_outputs
jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
_, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*in_consts)
reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)
# Now that we have out_pvals, the rest is similar to JaxprTrace.process_call.
# Known outputs should keep propagating as constants
assert all(pv.is_known() for pv in out_known_pvals)
known_output_tracers = [trace.new_const(pval.get_known()) for pval in out_known_pvals]
known_output_tracers = [trace.new_const(pval.get_known())
for pval in out_known_pvals]
# Unknown outputs get wrapped in tracers with the appropriate recipe
unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
for out_pval in out_unknown_pvals]
out_tracers = _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
# Unknown outputs get wrapped in tracers with the appropriate recipe, as in JaxprTrace.process_call
const_tracers = map(trace.new_instantiated_const, consts)
unknown_output_tracers = wrap_unknown_pvals(
trace,
typed_jaxpr,
tuple(it.chain(const_tracers, env_tracers, instantiated_tracers)),
out_known_pvals,
out_unknown_pvals,
out_unknowns,
params)
in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(typed_jaxpr.jaxpr))
return process_out(trace, in_tracers, out_tracers, new_params)
return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
def _remat_make_output_tracers(_, in_tracers, out_tracers, params):
# dce jaxpr outputs
jaxpr = params['call_jaxpr']
out_unknowns = [not t.pval.is_known() for t in out_tracers]
typed_jaxpr = core.TypedJaxpr(jaxpr, (), [v.aval for v in jaxpr.invars],
[v.aval for v in jaxpr.outvars])
new_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
new_params = dict(params, call_jaxpr=new_jaxpr)
def _remat_make_output_tracers(trace, typed_jaxpr, input_tracers, _, out_unknown_pvals, out_unknowns, params):
unknown_output_tracers = [JaxprTracer(trace, out_pval, None) for out_pval in out_unknown_pvals]
typed_jaxpr = _dce_jaxpr(typed_jaxpr, out_unknowns, drop_outputs=True)
lifted_jaxpr = convert_constvars_jaxpr(typed_jaxpr.jaxpr)
new_params = dict(params, call_jaxpr=lifted_jaxpr)
eqn = new_eqn_recipe(input_tracers,
unknown_output_tracers,
remat_call_p,
new_params)
for t in unknown_output_tracers: t.recipe = eqn
return unknown_output_tracers
# set up eqn for unknown outputs
unknown_out_tracers = [t for t in out_tracers if not t.pval.is_known()]
eqn = new_eqn_recipe(in_tracers, unknown_out_tracers, remat_call_p, new_params)
for t in unknown_out_tracers: t.recipe = eqn
return out_tracers
call_partial_eval_rules[remat_call_p] = partial(
_remat_partial_eval, _remat_make_output_tracers)
call_partial_eval_rules[remat_call_p] = partial(_remat_partial_eval, _remat_make_output_tracers)
def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])
def _partition_knowns(l, unknowns):
return ([e for e, unknown in zip(l, unknowns) if not unknown],
[e for e, unknown in zip(l, unknowns) if unknown])
def _zip_knowns(kl, ul, unknowns):
ul_it = iter(ul)
kl_it = iter(kl)
return [next(ul_it) if unknown else next(kl_it) for unknown in unknowns]
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
def _dce_jaxpr(typed_jaxpr: TypedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> TypedJaxpr:

View File

@ -1266,6 +1266,156 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(ans1, np.cos(2.), check_dtypes=False)
self.assertAllClose(ans2, np.cos(3.), check_dtypes=False)
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = api.jit(lambda x: x)(x)
self.assertIs(x, y)
z1, z2 = api.jit(lambda x: (x, x))(x)
self.assertIs(z1, z2)
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertIs(z1, x2)
self.assertIs(z3, x1)
self.assertEqual(z2, 1)
def test_nested_jit_hoisting(self):
@api.jit
def f(x, y):
z = 2 * x
return y + z, 3
@api.jit
def g(x):
return f(2, x)
jaxpr_subcomp = xla.jaxpr_subcomp
jaxprs = []
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
jaxprs.append(jaxpr)
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
try:
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
ans = g(3)
finally:
xla.jaxpr_subcomp = jaxpr_subcomp
self.assertEqual(ans, (7, 3))
self.assertLen(jaxprs, 2)
outer_jaxpr, inner_jaxpr = jaxprs
self.assertLen(outer_jaxpr.eqns, 1)
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
self.assertLen(inner_jaxpr.eqns, 2)
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
def test_primitive_compilation_cache(self):
with jtu.count_primitive_compiles() as count:
lax.add(1, 2)
lax.add(2, 3)
self.assertEqual(count[0], 1)
def test_arange_jit(self):
# see https://github.com/google/jax/issues/553
def fun(x):
r = jnp.arange(x.shape[0])[x]
return r
jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash
def helper_save_tracer(self, x):
self._saved_tracer = x
return x
def test_escaped_tracers_diffent_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Different traces at same level",
re.DOTALL)):
api.jit(lambda x: self._saved_tracer)(0.)
def test_escaped_tracers_cant_lift_sublevels(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
re.DOTALL)):
api.jit(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
api.grad(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_incompatible_sublevel(self):
def func1(x):
api.jit(self.helper_save_tracer)(0.)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
re.DOTALL)):
api.jit(func1)(2.)
def test_escaped_tracers_cant_lift(self):
def func1(x):
api.grad(self.helper_save_tracer)(0.)
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Can't lift",
re.DOTALL)):
api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self):
def func1(x):
api.grad(self.helper_save_tracer)(x)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer not among input tracers",
re.DOTALL)):
api.jit(func1)(2.)
def test_pmap_static_kwarg_error_message(self):
# https://github.com/google/jax/issues/3007
def f(a, b):
return a + b
g = jax.pmap(f, static_broadcasted_argnums=(1,))
msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
r"called with only 1 positional argument. All static broadcasted "
r"arguments must be passed positionally.")
with self.assertRaisesRegex(ValueError, msg):
g(jnp.ones((1, 1)), b=1)
def test_vmap_unmapped_last(self):
@partial(jax.vmap, out_axes=jax.interpreters.batching.last)
def f(x):
return np.zeros((2,))
f(np.zeros((5,)))
class RematTest(jtu.JaxTestCase):
def test_remat_basic(self):
@api.remat
def g(x):
@ -1622,154 +1772,6 @@ class APITest(jtu.JaxTestCase):
vjp(v)
def test_trivial_computations(self):
x = jnp.array([1, 2, 3])
y = api.jit(lambda x: x)(x)
self.assertIs(x, y)
z1, z2 = api.jit(lambda x: (x, x))(x)
self.assertIs(z1, z2)
x1, x2 = jnp.array([1, 2]), jnp.array([2, 3])
z1, z2, z3 = api.jit(lambda x, y: (y, 1, x))(x1, x2)
self.assertIs(z1, x2)
self.assertIs(z3, x1)
self.assertEqual(z2, 1)
def test_nested_jit_hoisting(self):
@api.jit
def f(x, y):
z = 2 * x
return y + z, 3
@api.jit
def g(x):
return f(2, x)
jaxpr_subcomp = xla.jaxpr_subcomp
jaxprs = []
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
jaxprs.append(jaxpr)
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
try:
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
ans = g(3)
finally:
xla.jaxpr_subcomp = jaxpr_subcomp
self.assertEqual(ans, (7, 3))
self.assertLen(jaxprs, 2)
outer_jaxpr, inner_jaxpr = jaxprs
self.assertLen(outer_jaxpr.eqns, 1)
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
self.assertLen(inner_jaxpr.eqns, 2)
self.assertEqual(inner_jaxpr.eqns[0].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[1].primitive.name, 'add')
def test_primitive_compilation_cache(self):
with jtu.count_primitive_compiles() as count:
lax.add(1, 2)
lax.add(2, 3)
self.assertEqual(count[0], 1)
def test_arange_jit(self):
# see https://github.com/google/jax/issues/553
def fun(x):
r = jnp.arange(x.shape[0])[x]
return r
jit(fun)(jnp.array([0, 1, 2], dtype=jnp.int32)) # doesn't crash
def helper_save_tracer(self, x):
self._saved_tracer = x
return x
def test_escaped_tracers_diffent_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Different traces at same level",
re.DOTALL)):
api.jit(lambda x: self._saved_tracer)(0.)
def test_escaped_tracers_cant_lift_sublevels(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
re.DOTALL)):
api.jit(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_tracer_from_higher_level(self):
api.grad(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer from a higher level",
re.DOTALL)):
api.grad(lambda x: x)(self._saved_tracer)
def test_escaped_tracers_incompatible_sublevel(self):
def func1(x):
api.jit(self.helper_save_tracer)(0.)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Incompatible sublevel",
re.DOTALL)):
api.jit(func1)(2.)
def test_escaped_tracers_cant_lift(self):
def func1(x):
api.grad(self.helper_save_tracer)(0.)
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile("Encountered an unexpected tracer.*Can't lift",
re.DOTALL)):
api.grad(func1)(2.)
def test_escaped_tracers_not_among_input_tracers(self):
def func1(x):
api.grad(self.helper_save_tracer)(x)
# Use the tracer
return x + self._saved_tracer
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Tracer not among input tracers",
re.DOTALL)):
api.jit(func1)(2.)
def test_pmap_static_kwarg_error_message(self):
# https://github.com/google/jax/issues/3007
def f(a, b):
return a + b
g = jax.pmap(f, static_broadcasted_argnums=(1,))
msg = (r"pmapped function has static_broadcasted_argnums=\(1,\) but was "
r"called with only 1 positional argument. All static broadcasted "
r"arguments must be passed positionally.")
with self.assertRaisesRegex(ValueError, msg):
g(jnp.ones((1, 1)), b=1)
def test_vmap_unmapped_last(self):
@partial(jax.vmap, out_axes=jax.interpreters.batching.last)
def f(x):
return np.zeros((2,))
f(np.zeros((5,)))
class JaxprTest(jtu.JaxTestCase):
def test_scalar_literals(self):
@ -2833,6 +2835,7 @@ class CustomVJPTest(jtu.JaxTestCase):
jax.grad(clip_gradient)(1.) # doesn't crash
class InvertibleADTest(jtu.JaxTestCase):
def test_invertible_basic(self):