mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
add scan dce rule tests, fix bugs
This commit is contained in:
parent
d57e36416f
commit
d0863a1258
@ -2037,30 +2037,36 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
|
||||
|
||||
def _scan_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn
|
||||
) -> Tuple[List[bool], core.JaxprEqn]:
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
|
||||
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
|
||||
used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry])
|
||||
for i in range(1 + num_carry):
|
||||
used_outputs = used_carry_out + used_extensive_out
|
||||
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'].jaxpr, used_outputs)
|
||||
jaxpr_dce, used_inputs = pe.dce_jaxpr(
|
||||
jaxpr.jaxpr, used_outputs,
|
||||
instantiate=[False] * num_consts + used_carry_out + [False] * num_xs)
|
||||
used_consts, used_carry_in, used_extensive_in = \
|
||||
split_list(used_inputs, [num_consts, num_carry])
|
||||
if used_carry_in == used_carry_out:
|
||||
if list(used_carry_in) == list(used_carry_out):
|
||||
break
|
||||
else:
|
||||
used_carry_out = _map(operator.or_, used_carry_out, used_carry_in)
|
||||
else:
|
||||
assert False, "Fixpoint not reached"
|
||||
core.check_jaxpr(jaxpr.jaxpr)
|
||||
|
||||
new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u]
|
||||
new_params = dict(eqn.params, num_consts=sum(used_consts),
|
||||
num_carry=sum(used_carry_in), linear=tuple(new_linear),
|
||||
jaxpr=core.ClosedJaxpr(jaxpr, eqn.params['jaxpr'].consts))
|
||||
jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts))
|
||||
new_eqn = pe.new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs)
|
||||
if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs)
|
||||
if used],
|
||||
eqn.primitive, new_params, eqn.effects,
|
||||
eqn.source_info)
|
||||
assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals )
|
||||
assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals)
|
||||
return used_inputs, new_eqn
|
||||
|
||||
@ -2133,8 +2139,7 @@ core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
|
||||
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
|
||||
pe.padding_rules[scan_p] = _scan_padding_rule
|
||||
# TODO(mattjj): re-enable
|
||||
# pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
|
||||
|
||||
@api_boundary
|
||||
|
@ -1161,12 +1161,16 @@ def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:
|
||||
for v in jaxpr.outvars]
|
||||
|
||||
|
||||
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool]
|
||||
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool],
|
||||
instantiate: Union[bool, Sequence[bool]] = False,
|
||||
) -> Tuple[Jaxpr, List[bool]]:
|
||||
return _dce_jaxpr(jaxpr, tuple(used_outputs))
|
||||
if type(instantiate) is bool:
|
||||
instantiate = (instantiate,) * len(jaxpr.invars)
|
||||
return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate))
|
||||
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
|
||||
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...],
|
||||
instantiate: Tuple[bool, ...]
|
||||
) -> Tuple[Jaxpr, List[bool]]:
|
||||
env: Dict[Var, bool] = {}
|
||||
|
||||
@ -1177,26 +1181,23 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
|
||||
if type(x) is Var:
|
||||
env[x] = read(x) or b
|
||||
|
||||
def has_effects(e: JaxprEqn) -> bool:
|
||||
return bool(e.effects) or core.primitive_uses_outfeed(e.primitive, e.params)
|
||||
|
||||
new_eqns = []
|
||||
map(write, jaxpr.outvars, used_outputs)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
used_outs = map(read, eqn.outvars)
|
||||
# If any outputs are used, then we need to keep a version of the eqn and
|
||||
# potentially mark some inputs as used. Otherwise mark all inputs as unused.
|
||||
if any(used_outs) or core.primitive_uses_outfeed(eqn.primitive, eqn.params):
|
||||
# If there's a rule for modifying the eqn and computing used inputs, apply
|
||||
# it. Otherwise, keep the eqn unmodified and mark all inputs as used.
|
||||
rule = dce_rules.get(eqn.primitive)
|
||||
if rule:
|
||||
used_ins, new_eqn = rule(used_outs, eqn)
|
||||
else:
|
||||
used_ins = [True] * len(eqn.invars)
|
||||
new_eqn = eqn
|
||||
new_eqns.append(new_eqn)
|
||||
else:
|
||||
if not any(used_outs) and not has_effects(eqn):
|
||||
used_ins = [False] * len(eqn.invars)
|
||||
else:
|
||||
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
|
||||
used_ins, new_eqn = rule(used_outs, eqn)
|
||||
if new_eqn is not None:
|
||||
new_eqns.append(new_eqn)
|
||||
map(write, eqn.invars, used_ins)
|
||||
used_inputs = map(read, jaxpr.invars)
|
||||
used_inputs = map(op.or_, instantiate, used_inputs)
|
||||
|
||||
new_jaxpr = Jaxpr(jaxpr.constvars,
|
||||
[v for v, b in zip(jaxpr.invars, used_inputs) if b],
|
||||
@ -1206,7 +1207,13 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: Tuple[bool, ...]
|
||||
|
||||
return new_jaxpr, used_inputs
|
||||
|
||||
DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], JaxprEqn]]
|
||||
DCERule = Callable[[List[bool], JaxprEqn], Tuple[List[bool], Optional[JaxprEqn]]]
|
||||
|
||||
def _default_dce_rule(
|
||||
used_outs: List[bool], eqn: JaxprEqn
|
||||
) -> Tuple[List[bool], JaxprEqn]:
|
||||
return [True] * len(eqn.invars), eqn
|
||||
|
||||
dce_rules: Dict[Primitive, DCERule] = {}
|
||||
|
||||
|
||||
@ -1217,9 +1224,10 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
|
||||
update_params = call_param_updaters.get(eqn.primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, used_inputs, 0)
|
||||
new_eqn = new_jaxpr_eqn([v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
||||
new_eqn = new_jaxpr_eqn(
|
||||
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
dce_rules[core.call_p] = dce_jaxpr_call_rule
|
||||
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
|
||||
|
@ -24,7 +24,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
from typing import Callable
|
||||
from typing import Callable, List, Optional
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
@ -4518,8 +4518,185 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
jaxpr = api.make_jaxpr(lambda: cet(3.))()
|
||||
self.assertLen(jaxpr.eqns, 0)
|
||||
|
||||
def test_dce_jaxpr_scan(self):
|
||||
raise unittest.SkipTest() # TODO(mattjj)
|
||||
|
||||
class DCETest(jtu.JaxTestCase):
|
||||
|
||||
def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: List[bool],
|
||||
expected_used_inputs: List[bool],
|
||||
expected_num_eqns: Optional[int] = None,
|
||||
check_diff: bool = True):
|
||||
jaxpr_dce, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs)
|
||||
core.check_jaxpr(jaxpr_dce)
|
||||
self.assertEqual(used_inputs, expected_used_inputs)
|
||||
if expected_num_eqns is not None:
|
||||
all_jaxprs = it.chain([jaxpr_dce], core.subjaxprs(jaxpr_dce))
|
||||
num_eqns = sum(len(subjaxpr.eqns) for subjaxpr in all_jaxprs)
|
||||
self.assertEqual(num_eqns, expected_num_eqns, msg=str(jaxpr_dce))
|
||||
|
||||
rand_ = jtu.rand_small(np.random.RandomState(0))
|
||||
rand = lambda v: rand_(v.aval.shape, v.aval.dtype)
|
||||
consts = [rand(v) for v in jaxpr.constvars]
|
||||
inputs = [rand(v) for v in jaxpr.invars ]
|
||||
inputs_dce = [x for x, used in zip(inputs, used_inputs) if used]
|
||||
full_outs = core.eval_jaxpr(jaxpr , consts, *inputs)
|
||||
expected_outs_dce = [y for y, used in zip(full_outs, used_outputs) if used]
|
||||
outs = core.eval_jaxpr(jaxpr_dce, consts, *inputs_dce)
|
||||
self.assertAllClose(outs, expected_outs_dce)
|
||||
|
||||
if check_diff and expected_num_eqns != 0:
|
||||
f = lambda *args: core.eval_jaxpr(jaxpr_dce, consts, *args)
|
||||
jtu.check_grads(f, inputs_dce, order=2, modes=['rev'])
|
||||
|
||||
def test_dce_jaxpr_scan_nontrivial_fixedpoint_carry(self):
|
||||
# The idea is that each element of the output carry tuple depends on the
|
||||
# corresponding carried input as well as the one to the left. The extensive
|
||||
# inputs and outputs aren't used here; just the carry depending on itself.
|
||||
def f(lst):
|
||||
def body(c, _):
|
||||
return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
|
||||
out, _ = jax.lax.scan(body, lst, None, length=len(lst))
|
||||
return out
|
||||
jaxpr = api.make_jaxpr(f)([1., 2., 3., 4.]).jaxpr
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
|
||||
|
||||
# If we use all but the last element, all but the first input is used, and
|
||||
# only one eqn is pruned.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True, True, True, False],
|
||||
expected_used_inputs=[True, True, True, False],
|
||||
expected_num_eqns=1 + 2) # one outer scan eqn, two adds in the body
|
||||
|
||||
# Same as above if we just pull on the third element.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[False, False, True, False],
|
||||
expected_used_inputs=[True, True, True, False],
|
||||
expected_num_eqns=1 + 2) # one outer scan eqn, two adds in the body
|
||||
|
||||
# If we use all but the last two elements, the last two inputs are not used,
|
||||
# and two eqns can be pruned.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True, True, False, False],
|
||||
expected_used_inputs=[True, True, False, False],
|
||||
expected_num_eqns=1 + 1) # one outer scan eqn, one add in body
|
||||
|
||||
# If we only use the last element, no eqns can be pruned.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[False, False, False, True],
|
||||
expected_used_inputs=[True, True, True, True],
|
||||
expected_num_eqns=1 + 3) # one outer scan eqn, three adds in body
|
||||
|
||||
def test_dce_jaxpr_scan_nontrivial_fixedpoint_carry_2(self):
|
||||
# This is much like the above test, except with a more interesting
|
||||
# dependence structure among the carry elements. Also add a const and
|
||||
# extensive input.
|
||||
hidden_sequence = [1, 2, 3, 5, 8]
|
||||
def f(lst):
|
||||
def body(c, _):
|
||||
_ = jnp.sin(np.array([3., 1., 4.]))
|
||||
sub_c = [c[i] for i in hidden_sequence]
|
||||
sub_c = [sub_c[0]] + [c1 * c2 for c1, c2 in zip(sub_c[:-1], sub_c[1:])]
|
||||
new_c = list(c)
|
||||
for i, elt in zip(hidden_sequence, sub_c):
|
||||
new_c[i] = elt
|
||||
return new_c, None
|
||||
out, _ = jax.lax.scan(body, lst, np.arange(len(lst), dtype='float32'))
|
||||
return out
|
||||
jaxpr = api.make_jaxpr(f)([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]).jaxpr
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 5)
|
||||
|
||||
# If we use the value at index 8 only, all the hidden sequence must be kept
|
||||
# and no eqns can be pruned.
|
||||
used_outputs = [False] * 10
|
||||
used_outputs[8] = True
|
||||
expected_used_inputs = [False] * 10
|
||||
for i in hidden_sequence:
|
||||
expected_used_inputs[i] = True
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=used_outputs,
|
||||
expected_used_inputs=expected_used_inputs,
|
||||
expected_num_eqns=1 + 4)
|
||||
|
||||
# If we use the value at any indices not in the hidden sequence, none of the
|
||||
# hidden sequence must be kept and we can prune all body eqns.
|
||||
used_outputs = [False] * 10
|
||||
expected_used_inputs = [False] * 10
|
||||
used_outputs[9] = expected_used_inputs[9] = True
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=used_outputs,
|
||||
expected_used_inputs=expected_used_inputs,
|
||||
expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule
|
||||
used_outputs[7] = expected_used_inputs[7] = True
|
||||
used_outputs[6] = expected_used_inputs[6] = True
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=used_outputs,
|
||||
expected_used_inputs=expected_used_inputs,
|
||||
expected_num_eqns=1)
|
||||
|
||||
# If we use the value at index 3 only, some of the hidden sequence must be
|
||||
# kept but the rest pruned.
|
||||
used_outputs = [False] * 10
|
||||
used_outputs[3] = True
|
||||
expected_used_inputs = [False] * 10
|
||||
expected_used_inputs[1] = expected_used_inputs[2] = \
|
||||
expected_used_inputs[3] = True
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=used_outputs,
|
||||
expected_used_inputs=expected_used_inputs,
|
||||
expected_num_eqns=1 + 2)
|
||||
|
||||
def test_dce_jaxpr_scan_nontrivial_fixedpoint_extensive_output(self):
|
||||
# Here we test how using the extensive output affects the carry.
|
||||
def f(lst):
|
||||
def body(c, _):
|
||||
return [c[-1], *c[:-1]], c[-1]
|
||||
_, ys = jax.lax.scan(body, lst, None, length=len(lst))
|
||||
return ys
|
||||
jaxpr = api.make_jaxpr(f)([1., 2., 3., 4.]).jaxpr
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
|
||||
# If we only use the extensive output, all carry elements are needed, and we
|
||||
# need to keep the scan itself.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True],
|
||||
expected_used_inputs=[True, True, True, True],
|
||||
expected_num_eqns=1)
|
||||
|
||||
# If we don't use the extensive output, no carry elements are needed, and we
|
||||
# don't need to keep the scan.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[False],
|
||||
expected_used_inputs=[False, False, False, False],
|
||||
expected_num_eqns=0)
|
||||
|
||||
def test_dce_jaxpr_scan_extensive_input(self):
|
||||
# Here we test an extensive input affecting the carry.
|
||||
def cumprod(xs):
|
||||
def body(c, x):
|
||||
return c * x, c
|
||||
c, ys = jax.lax.scan(body, jnp.float32(1.), xs)
|
||||
return c, ys
|
||||
jaxpr = api.make_jaxpr(cumprod)(jnp.arange(1., 5., dtype='float32')).jaxpr
|
||||
|
||||
# If we only use the carry output or extensive output, we need the input.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True, False],
|
||||
expected_used_inputs=[True],
|
||||
expected_num_eqns=2)
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[False, True],
|
||||
expected_used_inputs=[True],
|
||||
expected_num_eqns=2)
|
||||
|
||||
# If we don't use either output, the scan is eliminated.
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[False, False],
|
||||
expected_used_inputs=[False],
|
||||
expected_num_eqns=0)
|
||||
|
||||
def test_dce_jaxpr_scan_overpruning(self):
|
||||
# This is a regression test for a specific issue.
|
||||
@api.remat
|
||||
def scanned_f(c, x):
|
||||
out = jnp.tanh(c * x)
|
||||
@ -4528,46 +4705,15 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
def f(xs):
|
||||
return lax.scan(scanned_f, 1., xs)
|
||||
|
||||
jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(jnp.arange(10.)).jaxpr
|
||||
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
|
||||
xs = jnp.arange(10.)
|
||||
jaxpr = api.make_jaxpr(lambda xs: api.linearize(f, xs)[1])(xs).jaxpr
|
||||
|
||||
jaxpr, used_inputs = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[-1].params['jaxpr'].jaxpr.eqns, 2)
|
||||
|
||||
def test_dce_jaxpr_scan_nontrivial_fixedpoint(self):
|
||||
raise unittest.SkipTest() # TODO(mattjj)
|
||||
def f(lst):
|
||||
def body(c, _):
|
||||
return [c[0]] + [c1 + c2 for c1, c2 in zip(c[:-1], c[1:])], None
|
||||
out, _ = jax.lax.scan(body, lst, None, length=len(lst))
|
||||
return out
|
||||
jaxpr = api.make_jaxpr(f)([1, 2, 3, 4]).jaxpr
|
||||
self.assertLen(jaxpr.eqns, 1)
|
||||
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
|
||||
|
||||
# If we use all but the last element, only one eqn is pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, True, False])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 2)
|
||||
# And all but the first input is used.
|
||||
self.assertEqual(used_inputs, [True, True, True, False])
|
||||
|
||||
# If we use all but the last two elements, two eqns can be pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, True, False, False])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 1)
|
||||
# And the last two inputs are not used.
|
||||
self.assertEqual(used_inputs, [True, True, False, False])
|
||||
|
||||
# If we only use the last element, no eqns can be pruned.
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [False, False, False, True])
|
||||
self.assertLen(jaxpr_pruned.eqns, 1)
|
||||
self.assertLen(jaxpr_pruned.eqns[0].params['jaxpr'].jaxpr.eqns, 3)
|
||||
# And all inputs are used.
|
||||
self.assertEqual(used_inputs, [True, True, True, True])
|
||||
|
||||
def test_dce_jaxpr_scan_const_in_jvp(self):
|
||||
raise unittest.SkipTest() # TODO(mattjj)
|
||||
# The main point of this test is to check for a crash.
|
||||
@api.custom_jvp
|
||||
def f(x):
|
||||
return x * np.arange(3.)
|
||||
@ -4582,14 +4728,38 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
y, _ = jax.lax.scan(body, x, None, length=1)
|
||||
return y
|
||||
|
||||
jvp_jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,)))(
|
||||
np.arange(3.), np.arange(3.)).jaxpr
|
||||
jaxpr = api.make_jaxpr(lambda x, xdot: api.jvp(g, (x,), (xdot,))
|
||||
)(np.arange(3.), np.arange(3.)).jaxpr
|
||||
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, True])
|
||||
self.assertTrue(all(used_inputs))
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True, True],
|
||||
expected_used_inputs=[True, True])
|
||||
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jvp_jaxpr, [True, False])
|
||||
self.assertEqual(used_inputs, [True, False])
|
||||
self.assert_dce_result(
|
||||
jaxpr, used_outputs=[True, False],
|
||||
expected_used_inputs=[True, False])
|
||||
|
||||
def test_dce_jaxpr_scan_results(self):
|
||||
# This doesn't test whether DCE is doing nontrivial work; instead it tests
|
||||
# whether the result after applying DCE computes different values. If
|
||||
# dce_jaxpr were an identity function, it'd pass this test!
|
||||
def f(cs, xs):
|
||||
def body(c, x):
|
||||
return (c[0], c[0] + c[1], jnp.arange(3.)), x
|
||||
cs, xs = jax.lax.scan(body, cs, xs)
|
||||
return cs[::2], xs[::2]
|
||||
|
||||
cs = 1., 2., jnp.arange(3.)
|
||||
xs = jnp.arange(3.), jnp.arange(3.) + 5
|
||||
jaxpr_ = jax.make_jaxpr(f)(cs, xs)
|
||||
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
|
||||
jaxpr_pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars))
|
||||
|
||||
args = (*cs, *xs)
|
||||
result1 = core.eval_jaxpr(jaxpr , consts, *cs, *xs)
|
||||
pruned_args = [x for x, used in zip(args, used_inputs) if used]
|
||||
result2 = core.eval_jaxpr(jaxpr_pruned, consts, *pruned_args)
|
||||
self.assertAllClose(result1, result2)
|
||||
|
||||
|
||||
class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user