mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
in remat, handle hoisted outputs in out_avals
In remat's partial_eval rule, we form a TypedJaxpr and hence need to set its output avals (modeling its output type). Since jaxprs are currently untyped, the logic looks at the Python output values returned by call bind to get their types. However, some of the output types might not correspond to value types in the jaxpr language; for example, instances of ad_util.Zero can be returned from a call bind. (In the past we've only formed TypedJaxprs in "initial-style" control-flow primitives, which don't have this possibility because there's no call bind involved.) We handle this possibility just by inspecting the jaxpr first before looking at the Python-valued outputs: wherever a non-jaxtype value has been hoisted out, a corresponding unit exists in the jaxpr outvars. So the logic is now: if the output has an abstract value from partial evaluation, use that to get the output type; otherwise if the output is a literal or unitvar, use that to get the output type; otherwise look at the const from partial evaluation (and it's an error if that const isn't a valid jaxtype). fixes #1907
This commit is contained in:
parent
a14a05d1f2
commit
0f2e08dd9d
@ -524,8 +524,11 @@ def _remat_partial_eval(trace, f, tracers, params):
|
||||
jaxpr_converted = convert_freevars_jaxpr(jaxpr)
|
||||
in_avals = ([raise_to_shaped(t.pval[0]) for t in env]
|
||||
+ [raise_to_shaped(pv) for pv in in_pvs])
|
||||
out_avals = [raise_to_shaped(pv if pv is not None else core.get_aval(const))
|
||||
for pv, const in zip(out_pvs, out_pval_consts1)]
|
||||
out_avals = [raise_to_shaped(pv if pv is not None
|
||||
else abstract_unit if var is unitvar
|
||||
else get_aval(var.val) if type(var) is Literal
|
||||
else get_aval(const))
|
||||
for var, pv, const in zip(jaxpr.outvars, out_pvs, out_pval_consts1)]
|
||||
typed_jaxpr = core.TypedJaxpr(jaxpr_converted, consts, in_avals, out_avals)
|
||||
in_unknowns = [t.pval[0] is not None for t in it.chain(env, tracers)]
|
||||
jaxpr_1, jaxpr_2, out_unknowns = partial_eval_jaxpr(typed_jaxpr, in_unknowns, False)
|
||||
@ -565,7 +568,7 @@ def _dce_jaxpr(typed_jaxpr, outputs):
|
||||
# TODO(mattjj): better DCE
|
||||
jaxpr = typed_jaxpr.jaxpr
|
||||
outvars, out_avals = jaxpr.outvars, typed_jaxpr.out_avals
|
||||
out_pairs = [(var, aval) if output else (core.unitvar, core.abstract_unit)
|
||||
out_pairs = [(var, aval) if output else (unitvar, core.abstract_unit)
|
||||
for var, aval, output in zip(outvars, out_avals, outputs)]
|
||||
new_outvars, new_out_avals = unzip2(out_pairs)
|
||||
|
||||
|
@ -1533,6 +1533,35 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f1(x), f2(x), check_dtypes=False)
|
||||
self.assertAllClose(api.grad(f1)(x), api.grad(f2)(x), check_dtypes=False)
|
||||
|
||||
def test_remat_symbolic_zeros(self):
|
||||
# code from https://github.com/google/jax/issues/1907
|
||||
test_remat = True
|
||||
test_scan = True
|
||||
|
||||
key = jax.random.PRNGKey(0)
|
||||
key, split = jax.random.split(key)
|
||||
n = 5
|
||||
|
||||
def func(D0):
|
||||
def shift(R, dR, **unused_kwargs):
|
||||
return R + dR
|
||||
|
||||
def apply_fn(R):
|
||||
return D0 * R
|
||||
|
||||
Rinit = jax.random.uniform(split, (n,3), minval=0.0, maxval=5.0,
|
||||
dtype=np.float32)
|
||||
|
||||
def move(R,i):
|
||||
F = apply_fn(R)
|
||||
return shift(R, 0.001 * F), np.array([0.])
|
||||
|
||||
move = api.remat(move)
|
||||
R, temp = lax.scan(move, Rinit, np.arange(2))
|
||||
return R[0, 0]
|
||||
|
||||
api.grad(func)(5.0) # doesn't crash
|
||||
|
||||
def test_trivial_computations(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = api.jit(lambda x: x)(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user