mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[remove-units] remove units from while partial eval
This commit is contained in:
parent
cdb4a8428e
commit
9359cc3e53
@ -470,30 +470,29 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
|
||||
def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int,
|
||||
cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int,
|
||||
body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]:
|
||||
"""An implementation of partial evaluation for while.
|
||||
As long as some carry (and hence output) are known and the output
|
||||
of `cond_jaxpr` is known, we use a portion of the loop body to compute the known
|
||||
outputs of the `while_loop`. For the unknown outputs we generate Jaxpr to run
|
||||
the whole while, including recomputing the known parts.
|
||||
# As long as some carry (and hence output) are known and the output of
|
||||
# `cond_jaxpr` is known, we use a portion of the loop body to compute the
|
||||
# known outputs of the `while_loop`. For the unknown outputs we generate a
|
||||
# jaxpr to run the whole while, including recomputing the known parts,
|
||||
# basically like building in checkpointing/rematerialization. This means that
|
||||
# we don't actually save any computation by partial evaluation if there are
|
||||
# unknown outputs.
|
||||
#
|
||||
# What this achieves is twofold: jax.linearize works, and we can give a proper
|
||||
# error for reverse differentiation of `while`.
|
||||
|
||||
This means that we don't actually save any computation by partial
|
||||
evaluation if there are unknown outputs.
|
||||
|
||||
What this achieves is that we can give a proper error for reverse
|
||||
differentiation of `while`, because in that use of partial evaluation the
|
||||
primal inputs are considered "known", and only the tangent computation is
|
||||
unknown (see issue #2129).
|
||||
"""
|
||||
unknowns = [not t.pval.is_known() for t in tracers]
|
||||
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
|
||||
|
||||
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
|
||||
# Fixpoint computation of unknown carry. Each iteration promotes
|
||||
# at least one carry to unknown. We need one last iteration to prepare the jaxpr.
|
||||
cond_consts_uk, body_consts_uk, carry_init_uk = \
|
||||
split_list(unknowns, [cond_nconsts, body_nconsts])
|
||||
|
||||
# Fixpoint computation of unknown carry. Each iteration promotes at least one
|
||||
# carry to unknown. We need one last iteration to prepare the jaxpr.
|
||||
carry_uk = carry_init_uk
|
||||
for _ in range(1 + len(carry_uk)):
|
||||
body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr( # type: ignore
|
||||
body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits( # type: ignore
|
||||
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
|
||||
if carry_out_uk == carry_uk:
|
||||
break
|
||||
@ -502,7 +501,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
else:
|
||||
assert False, "Fixpoint not reached"
|
||||
|
||||
cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr( # type: ignore
|
||||
cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( # type: ignore
|
||||
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
|
||||
|
||||
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
|
||||
@ -510,33 +509,24 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
# just do the default processing.
|
||||
return trace.default_process_primitive(while_p, tracers, params)
|
||||
|
||||
# Run the known part of the while. Prepare the inputs, as constants (if known), or
|
||||
# as core.unit.
|
||||
in_consts = [ core.unit if uk else t.pval.get_known()
|
||||
for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk,
|
||||
tracers)]
|
||||
# There should be no residuals for the cond_jaxpr_known
|
||||
assert 1 == len(cond_jaxpr_known.out_avals)
|
||||
# We ignore the residuals from the body_jaxpr_known, so the type of inputs matches
|
||||
# the type of outputs; residuals are at the end
|
||||
if len(body_jaxpr_known.out_avals) > len(body_jaxpr.out_avals):
|
||||
# TODO(necula): this is not quite enough; we should drop the residual computations also
|
||||
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:len(body_jaxpr.out_avals)]
|
||||
# Run the known part of the while.
|
||||
in_consts = [t.pval.get_known() for uk, t in
|
||||
zip(cond_consts_uk + body_consts_uk + carry_uk, tracers)
|
||||
if not uk]
|
||||
cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk)
|
||||
body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk)
|
||||
num_known_outs = len(carry_uk) - sum(carry_uk)
|
||||
# TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs
|
||||
body_jaxpr_known.jaxpr.outvars = body_jaxpr_known.jaxpr.outvars[:num_known_outs]
|
||||
out_known = while_p.bind(
|
||||
*in_consts,
|
||||
cond_nconsts=cond_nconsts,
|
||||
cond_jaxpr=cond_jaxpr_known,
|
||||
body_nconsts=body_nconsts,
|
||||
body_jaxpr=body_jaxpr_known)
|
||||
*in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known,
|
||||
body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known)
|
||||
del body_jaxpr_known
|
||||
|
||||
# Run the whole while_loop to get all the outputs, then merge with known ones
|
||||
out_all: Sequence[pe.Tracer] = trace.default_process_primitive(while_p, tracers, params)
|
||||
out_tracers: Sequence[pe.Tracer] = [
|
||||
out_unknown if uk
|
||||
else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
|
||||
for uk, out_unknown, known in zip(carry_uk, out_all, out_known)]
|
||||
|
||||
return out_tracers
|
||||
out_tracers_ = trace.default_process_primitive(while_p, tracers, params)
|
||||
out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk]
|
||||
return util.merge_lists(carry_uk, out_known, out_tracers)
|
||||
|
||||
def _while_transpose_error(*_, **kwargs):
|
||||
raise ValueError("Reverse-mode differentiation does not work for "
|
||||
|
@ -1930,6 +1930,35 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
jtu.check_grads(loop, (x,), order=2, modes=["fwd"])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}".format(
|
||||
jit_loop, jit_body, jit_cond),
|
||||
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
|
||||
for jit_loop in [False, True]
|
||||
for jit_body in [False, True]
|
||||
for jit_cond in [False, True])
|
||||
def testWhileLinearize(self, jit_loop=True, jit_body=False, jit_cond=True):
|
||||
cond = lambda x: x[0, 2] <= 8
|
||||
body = lambda x: x * x
|
||||
|
||||
if jit_cond:
|
||||
cond = jax.jit(cond)
|
||||
if jit_body:
|
||||
body = jax.jit(body)
|
||||
|
||||
loop = partial(lax.while_loop, cond, body)
|
||||
if jit_loop:
|
||||
loop = jax.jit(loop)
|
||||
|
||||
loop_ref = partial(while_loop_reference, cond, body)
|
||||
|
||||
x = jnp.arange(9.).reshape((3, 3))
|
||||
y, f_lin = jax.linearize(loop, x)
|
||||
ydot = f_lin(x)
|
||||
y_expected, ydot_expected = jax.jvp(loop_ref, (x,), (x,))
|
||||
self.assertAllClose(y, y_expected, check_dtypes=False)
|
||||
self.assertAllClose(ydot, ydot_expected, check_dtypes=False)
|
||||
|
||||
def testWhileJVPViaForiLoop(self):
|
||||
f = lambda x: lax.fori_loop(0, 3, lambda i, x: x * 2, x)
|
||||
self.assertAllClose(f(2.), 16., check_dtypes=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user