mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #8554 from mattjj:scan-dce-rule
PiperOrigin-RevId: 445059933
This commit is contained in:
commit
611759d0ce
@ -24,7 +24,7 @@ import inspect
|
||||
import itertools
|
||||
import operator
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1980,6 +1980,32 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
|
||||
padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
|
||||
return scan_p.bind(*args, jaxpr=padded_jaxpr, **params)
|
||||
|
||||
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
|
||||
) -> Tuple[List[bool], core.JaxprEqn]:
|
||||
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
|
||||
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
|
||||
for i in range(1 + num_carry):
|
||||
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr,
|
||||
used_carry_out + used_extensive_out)
|
||||
used_consts, used_carry_in, used_extensive_in = \
|
||||
split_list(used_inputs, [num_consts, num_carry])
|
||||
if used_carry_in == used_carry_out:
|
||||
break
|
||||
else:
|
||||
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
|
||||
else:
|
||||
assert False, "Fixpoint not reached"
|
||||
|
||||
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
|
||||
new_params = dict(eqn.params, num_consts=sum(used_consts),
|
||||
num_carry=sum(used_carry_in), linear=tuple(new_linear),
|
||||
jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
|
||||
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, eqn.effects,
|
||||
eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
|
||||
def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
|
||||
jaxpr, linear, unroll):
|
||||
tc = partial(_typecheck_param, 'scan')
|
||||
@ -2049,7 +2075,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
|
||||
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
|
||||
pe.padding_rules[scan_p] = _scan_padding_rule
|
||||
|
||||
pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
|
||||
|
||||
@api_boundary
|
||||
|
@ -1185,7 +1185,6 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
|
||||
for v in jaxpr.outvars]
|
||||
|
||||
|
||||
# TODO(mattjj): unify with dce code below
|
||||
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
|
||||
) -> Tuple[Jaxpr, List[bool]]:
|
||||
return _dce_jaxpr(jaxpr, tuple(used_outputs))
|
||||
@ -1193,7 +1192,6 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
|
||||
) -> Tuple[Jaxpr, List[bool]]:
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
|
||||
env: Dict[Var, bool] = {}
|
||||
|
||||
def read(v: Var) -> bool:
|
||||
@ -1224,7 +1222,7 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
|
||||
map(write, eqn.invars, used_ins)
|
||||
used_inputs = map(read, jaxpr.invars)
|
||||
|
||||
new_jaxpr = Jaxpr((),
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars,
|
||||
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
|
||||
[v for v, b in zip(jaxpr.outvars, used_outputs) if b],
|
||||
new_eqns[::-1], jaxpr.effects)
|
||||
|
@ -50,6 +50,7 @@ from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax._src import device_array
|
||||
import jax._src.lib
|
||||
@ -4492,6 +4493,76 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
self.assertIn('in (*,)', str(jaxpr))
|
||||
self.assertNotIn('in (a,)', str(jaxpr))
|
||||
|
||||
def test_dce_jaxpr_scan(self):
|
||||
@api.remat
|
||||
def scanned_f(c, x):
|
||||
out = jnp.tanh(c * x)
|
||||
return out, out
|
||||
|
||||
def f(xs):
|
||||
return lax.scan(scanned_f, 1., xs)
|
||||
|
||||
jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(jnp.arange(10.)).jaxpr
|
||||
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
|
||||
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[-1].params['jaxpr'].jaxpr.eqns, 2)
|
||||
|
||||
def test_dce_jaxpr_scan_nontrivial_fixedpoint(self):
|
||||
def f(lst):
|
||||
def body(c, _):
|
||||
return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
|
||||
out, _ = jax.lax.scan(body, lst, None, length=len(lst))
|
||||
return out
|
||||
jaxpr = api.make_jaxpr(f)([1, 2, 3, 4]).jaxpr
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
|
||||
|
||||
# If we use all but the last element, only one eqn is pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, True, False])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 2)
|
||||
# And all but the first input is used.
|
||||
self.assertEqual(used_inputs, [True, True, True, False])
|
||||
|
||||
# If we use all but the last two elements, two eqns can be pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, False, False])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 1)
|
||||
# And the last two inputs are not used.
|
||||
self.assertEqual(used_inputs, [True, True, False, False])
|
||||
|
||||
# If we only use the last element, no eqns can be pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [False, False, False, True])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
|
||||
# And all inputs are used.
|
||||
self.assertEqual(used_inputs, [True, True, True, True])
|
||||
|
||||
def test_dce_jaxpr_scan_const_in_jvp(self):
|
||||
@api.custom_jvp
|
||||
def f(x):
|
||||
return x * np.arange(3.)
|
||||
@f.defjvp
|
||||
def f_jvp(primals, tangents):
|
||||
(x,), (xdot,) = primals, tangents
|
||||
return f(x), xdot * np.arange(3.)
|
||||
|
||||
def g(x):
|
||||
def body(c, _):
|
||||
return f(c), None
|
||||
y, _ = jax.lax.scan(body, x, None, length=1)
|
||||
return y
|
||||
|
||||
jvp_jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,)))(
|
||||
np.arange(3.), np.arange(3.)).jaxpr
|
||||
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, True])
|
||||
self.assertTrue(all(used_inputs))
|
||||
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, False])
|
||||
self.assertEqual(used_inputs, [True, False])
|
||||
|
||||
|
||||
class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user