From 4608d36340af6f7bdfd21a445d2fea8b3659bd47 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Mon, 15 Nov 2021 22:36:39 -0800
Subject: [PATCH] add scan dce rule

---
 jax/_src/lax/control_flow.py     | 36 +++++++++++++++-
 jax/interpreters/partial_eval.py |  4 +-
 tests/api_test.py                | 73 ++++++++++++++++++++++++++++++++
 3 files changed, 108 insertions(+), 5 deletions(-)

diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index 09ee16026..0997a68e0 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -24,7 +24,7 @@ import inspect
 import itertools
 import operator
 import os
-from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
+from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List
 
 import numpy as np
 
@@ -1980,6 +1980,38 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
   padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
   return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
 
+def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
+                   ) -> Tuple[List[bool], core.JaxprEqn]:
+  """DCE rule for scan: prune unused body outputs and the inputs they need.
+
+  A carry output that the scan's consumers ignore can still feed the next
+  iteration, so carry usage is grown until it reaches a fixpoint.
+  """
+  num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
+  used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
+  for i in range(1 + num_carry):
+    # Rebind used_outputs so the eqn's outvars below match the pruned jaxpr.
+    used_outputs = used_carry_out + used_extensive_out
+    jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr, used_outputs)
+    used_consts, used_carry_in, used_extensive_in = \
+        split_list(used_inputs, [num_consts, num_carry])
+    if used_carry_in == used_carry_out:
+      break
+    else:
+      used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
+  else:
+    assert False, "Fixpoint not reached"
+
+  new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
+  new_params = dict(eqn.params, num_consts=sum(used_consts),
+                    num_carry=sum(used_carry_in), linear=tuple(new_linear),
+                    jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
+  new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
+                             [v for v, used in zip(eqn.outvars, used_outputs) if used],
+                             eqn.primitive, new_params, eqn.effects,
+                             eqn.source_info)
+  return used_inputs, new_eqn
+
 def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
                     jaxpr, linear, unroll):
   tc = partial(_typecheck_param, 'scan')
@@ -2049,7 +2081,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
 pe.partial_eval_jaxpr_custom_rules[scan_p] = \
     partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
 pe.padding_rules[scan_p] = _scan_padding_rule
-
+pe.dce_rules[scan_p] = _scan_dce_rule
 
 
 @api_boundary
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 40ca91f87..8978bdb4f 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -1185,7 +1185,6 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
           for v in jaxpr.outvars]
 
 
-# TODO(mattjj): unify with dce code below
 def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
               ) -> Tuple[Jaxpr, List[bool]]:
   return _dce_jaxpr(jaxpr, tuple(used_outputs))
@@ -1193,7 +1192,6 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
 @weakref_lru_cache
 def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
                ) -> Tuple[Jaxpr, List[bool]]:
-  if jaxpr.constvars: raise NotImplementedError  # TODO(mattjj)
   env: Dict[Var, bool] = {}
 
   def read(v: Var) -> bool:
@@ -1224,7 +1222,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
     map(write, eqn.invars, used_ins)
   used_inputs = map(read, jaxpr.invars)
 
-  new_jaxpr = Jaxpr((),
+  new_jaxpr = Jaxpr(jaxpr.constvars,
                     [v for v, b in zip(jaxpr.invars, used_inputs) if b],
                     [v for v, b in zip(jaxpr.outvars, used_outputs) if b],
                     new_eqns[::-1], jaxpr.effects)
diff --git a/tests/api_test.py b/tests/api_test.py
index 805a55001..747e78c5a 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -50,6 +50,7 @@ from jax.interpreters import ad
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters import pxla
+from jax.interpreters import partial_eval as pe
 from jax.interpreters.pxla import PartitionSpec as P
 from jax._src import device_array
 import jax._src.lib
@@ -4492,6 +4493,78 @@ class JaxprTest(jtu.JaxTestCase):
     self.assertIn('in (*,)', str(jaxpr))
     self.assertNotIn('in (a,)', str(jaxpr))
 
+  def test_dce_jaxpr_scan(self):
+    @api.remat
+    def scanned_f(c, x):
+      out = jnp.tanh(c * x)
+      return out, out
+
+    def f(xs):
+      return lax.scan(scanned_f, 1., xs)
+
+    jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(jnp.arange(10.)).jaxpr
+    jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
+
+    self.assertLen(jaxpr.eqns, 1)
+    self.assertLen(jaxpr.eqns[-1].params['jaxpr'].jaxpr.eqns, 2)
+
+  def test_dce_jaxpr_scan_nontrivial_fixedpoint(self):
+    def f(lst):
+      def body(c, _):
+        return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
+      out, _ = jax.lax.scan(body, lst, None, length=len(lst))
+      return out
+    jaxpr = api.make_jaxpr(f)([1, 2, 3, 4]).jaxpr
+    self.assertLen(jaxpr.eqns, 1)
+    self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
+
+    # If we use all but the last element, only one eqn is pruned.
+    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, True, False])
+    self.assertLen(jaxpr_pruned.eqns, 1)
+    self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 2)
+    # And all but the last input is used.
+    self.assertEqual(used_inputs, [True, True, True, False])
+
+    # If we use all but the last two elements, two eqns can be pruned.
+    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, False, False])
+    self.assertLen(jaxpr_pruned.eqns, 1)
+    self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 1)
+    # And the last two inputs are not used.
+    self.assertEqual(used_inputs, [True, True, False, False])
+
+    # If we only use the last element, no eqns can be pruned.
+    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [False, False, False, True])
+    self.assertLen(jaxpr_pruned.eqns, 1)
+    self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
+    # And all inputs are used.
+    self.assertEqual(used_inputs, [True, True, True, True])
+    # The eqn keeps every carry output, matching its pruned body jaxpr.
+    self.assertLen(jaxpr_pruned.eqns[0].outvars, 4)
+
+  def test_dce_jaxpr_scan_const_in_jvp(self):
+    @api.custom_jvp
+    def f(x):
+      return x * np.arange(3.)
+    @f.defjvp
+    def f_jvp(primals, tangents):
+      (x,), (xdot,) = primals, tangents
+      return f(x), xdot * np.arange(3.)
+
+    def g(x):
+      def body(c, _):
+        return f(c), None
+      y, _ = jax.lax.scan(body, x, None, length=1)
+      return y
+
+    jvp_jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,)))(
+        np.arange(3.), np.arange(3.)).jaxpr
+
+    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, True])
+    self.assertTrue(all(used_inputs))
+
+    jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, False])
+    self.assertEqual(used_inputs, [True, False])
+
 
 class CustomJVPTest(jtu.JaxTestCase):
 