Merge pull request #8554 from mattjj:scan-dce-rule

PiperOrigin-RevId: 445059933
This commit is contained in:
jax authors 2022-04-27 22:27:06 -07:00
commit 611759d0ce
3 changed files with 100 additions and 5 deletions

View File

@ -24,7 +24,7 @@ import inspect
import itertools
import operator
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List
import numpy as np
@ -1980,6 +1980,32 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
) -> Tuple[List[bool], core.JaxprEqn]:
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
for i in range(1 + num_carry):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr,
used_carry_out + used_extensive_out)
used_consts, used_carry_in, used_extensive_in = \
split_list(used_inputs, [num_consts, num_carry])
if used_carry_in == used_carry_out:
break
else:
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
else:
assert False, "Fixpoint not reached"
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
new_params = dict(eqn.params, num_consts=sum(used_consts),
num_carry=sum(used_carry_in), linear=tuple(new_linear),
jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, eqn.effects,
eqn.source_info)
return used_inputs, new_eqn
def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
tc = partial(_typecheck_param, 'scan')
@ -2049,7 +2075,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
@api_boundary

View File

@ -1185,7 +1185,6 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
for v in jaxpr.outvars]
# TODO(mattjj): unify with dce code below
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
) -> Tuple[Jaxpr, List[bool]]:
return _dce_jaxpr(jaxpr, tuple(used_outputs))
@ -1193,7 +1192,6 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
) -> Tuple[Jaxpr, List[bool]]:
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, bool] = {}
def read(v: Var) -> bool:
@ -1224,7 +1222,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
new_jaxpr = Jaxpr((),
new_jaxpr = Jaxpr(jaxpr.constvars,
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
new_eqns[::-1], jaxpr.effects)

View File

@ -50,6 +50,7 @@ from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec as P
from jax._src import device_array
import jax._src.lib
@ -4492,6 +4493,76 @@ class JaxprTest(jtu.JaxTestCase):
self.assertIn('in (*,)', str(jaxpr))
self.assertNotIn('in (a,)', str(jaxpr))
def test_dce_jaxpr_scan(self):
@api.remat
def scanned_f(c, x):
out = jnp.tanh(c * x)
return out, out
def f(xs):
return lax.scan(scanned_f, 1., xs)
jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(jnp.arange(10.)).jaxpr
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
self.assertLen(jaxpr.eqns, 1)
self.assertLen(jaxpr.eqns[-1].params['jaxpr'].jaxpr.eqns, 2)
def test_dce_jaxpr_scan_nontrivial_fixedpoint(self):
def f(lst):
def body(c, _):
return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
out, _ = jax.lax.scan(body, lst, None, length=len(lst))
return out
jaxpr = api.make_jaxpr(f)([1, 2, 3, 4]).jaxpr
self.assertLen(jaxpr.eqns, 1)
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
# If we use all but the last element, only one eqn is pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, True, False])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 2)
# And all but the first input is used.
self.assertEqual(used_inputs, [True, True, True, False])
# If we use all but the last two elements, two eqns can be pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, False, False])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 1)
# And the last two inputs are not used.
self.assertEqual(used_inputs, [True, True, False, False])
# If we only use the last element, no eqns can be pruned.
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [False, False, False, True])
self.assertLen(jaxpr_pruned.eqns, 1)
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
# And all inputs are used.
self.assertEqual(used_inputs, [True, True, True, True])
def test_dce_jaxpr_scan_const_in_jvp(self):
@api.custom_jvp
def f(x):
return x * np.arange(3.)
@f.defjvp
def f_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return f(x), xdot * np.arange(3.)
def g(x):
def body(c, _):
return f(c), None
y, _ = jax.lax.scan(body, x, None, length=1)
return y
jvp_jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,)))(
np.arange(3.), np.arange(3.)).jaxpr
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, True])
self.assertTrue(all(used_inputs))
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, False])
self.assertEqual(used_inputs, [True, False])
class CustomJVPTest(jtu.JaxTestCase):