in remat, handle hoisted outputs in out_avals

In remat's partial_eval rule, we form a TypedJaxpr and hence need to set
its output avals (modeling its output type). Since jaxprs are currently
untyped, the logic looks at the Python output values returned by
call bind to get their types. However, some of the output types might
not correspond to value types in the jaxpr language; for example,
instances of ad_util.Zero can be returned from a call bind. (In the past
we've only formed TypedJaxprs in "initial-style" control-flow
primitives, which don't have this possibility because there's no call
bind involved.)

We handle this possibility by inspecting the jaxpr before looking at the
Python-valued outputs: wherever a non-jaxtype value has been hoisted
out, a corresponding unit exists in the jaxpr outvars. So the logic is
now: if the output has an abstract value from partial evaluation, use
that to get the output type; otherwise, if the corresponding jaxpr
outvar is a literal or unitvar, use that to get the output type;
otherwise, look at the const from partial evaluation (it is an error if
that const isn't a valid jaxtype).

fixes #1907
This commit is contained in:
Matthew Johnson 2019-12-23 11:49:01 -08:00 committed by Matthew Johnson
parent a14a05d1f2
commit 0f2e08dd9d
2 changed files with 35 additions and 3 deletions

View File

@ -524,8 +524,11 @@ def _remat_partial_eval(trace, f, tracers, params):
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
+ [raise_to_shaped(pv) for pv in in_pvs])
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
for pv, const in zip(out_pvs, out_pval_consts1)]
out_avals = [raise_to_shaped(pv if pv is not None
else abstract_unit if var is unitvar
else get_aval(var.val) if type(var) is Literal
else get_aval(const))
for var, pv, const in zip(jaxpr.outvars, out_pvs, out_pval_consts1)]
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
@ -565,7 +568,7 @@ def _dce_jaxpr(typed_jaxpr, outputs):
# TODO(mattjj): better DCE
jaxpr = typed_jaxpr.jaxpr
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
out_pairs = [(var, aval) if output else (unitvar, core.abstract_unit)
for var, aval, output in zip(outvars, out_avals, outputs)]
new_outvars, new_out_avals = unzip2(out_pairs)

View File

@ -1533,6 +1533,35 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
def test_remat_symbolic_zeros(self):
  """Check that grad of a scan over a remat-wrapped body does not crash.

  Adapted from https://github.com/google/jax/issues/1907. The test makes no
  numerical assertion; it passes as long as `api.grad(func)(5.0)` returns
  without raising.

  NOTE(review): the constant second output of `move` is presumably what
  produces the symbolic zero the test name refers to -- confirm against the
  issue.
  """
  # code from https://github.com/google/jax/issues/1907
  _, split = jax.random.split(jax.random.PRNGKey(0))
  n = 5

  def func(D0):
    def shift(R, dR, **unused_kwargs):
      return R + dR

    def apply_fn(R):
      return D0 * R

    Rinit = jax.random.uniform(split, (n, 3), minval=0.0, maxval=5.0,
                               dtype=np.float32)

    def move(R, i):
      F = apply_fn(R)
      # The second element is a constant that does not depend on R or D0.
      return shift(R, 0.001 * F), np.array([0.])

    move = api.remat(move)
    # Only the final carry is used; the stacked per-step outputs are dropped.
    R, _ = lax.scan(move, Rinit, np.arange(2))
    return R[0, 0]

  api.grad(func)(5.0)  # doesn't crash
def test_trivial_computations(self):
x = np.array([1, 2, 3])
y = api.jit(lambda x: x)(x)