mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00

Add partial_eval_custom rule for for_loop

parent 0869183107
commit b2a5d2c3bb
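
This change registers a partial-eval rule for the `for` primitive so that
partial evaluation (and hence `jax.remat`/`jax.checkpoint` and `jax.grad`)
can split a `for_loop` into a "known" loop that emits residuals and a
"staged" loop that consumes them. A rough usage sketch of what this enables,
mirroring the new `remat_of_for_loop` helper in the tests below (the
`for_loop` import path is an assumption based on the test file):

    import jax
    import jax.numpy as jnp
    from jax._src.lax.control_flow import for_loop

    def body(i, x_ref):
      # read-modify-write through the loop-carried Ref
      x_ref[()] = jnp.sin(x_ref[()])

    # Differentiating a rematerialized for_loop exercises the new rule:
    # the known pass saves residuals, the staged pass replays the body.
    f = jax.remat(lambda x: for_loop.for_loop(3, body, x))
    print(jax.grad(f)(1.0))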
@@ -33,7 +33,7 @@ from jax._src import dtypes
 from jax._src import source_info_util
 from jax._src import state
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
-                           split_list)
+                           split_list, split_dict)
 import jax.numpy as jnp
 import numpy as np
 
@@ -319,6 +319,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
                       jaxpr: core.Jaxpr, nsteps: int, reverse: bool,
                       which_linear: Tuple[bool, ...]) -> List[pe.JaxprTracer]:
   num_inputs = len(tracers)
+  assert num_inputs == len(jaxpr.invars) - 1
   in_unknowns = [not t.pval.is_known() for t in tracers]
   # We first need to run a fixpoint to determine which of the `Ref`s are unknown
   # after running the for loop. We want to use the jaxpr to determine which
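
Aside: the body jaxpr's first invar is the loop index, so the carried `Ref`s
correspond to `jaxpr.invars[1:]`; that is where the `- 1` in the new assertion
(and the `[False, *in_unknowns]` / `invars) - 1` bookkeeping below) comes from.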
@@ -446,6 +447,135 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
   return merge_lists(in_unknowns, known_outputs, unknown_outputs)
 pe.custom_partial_eval_rules[for_p] = _for_partial_eval
 
+def _for_partial_eval_custom(saveable, in_unknowns, in_inst, eqn):
+  jaxpr, nsteps, reverse, which_linear = split_dict(
+      eqn.params, ["jaxpr", "nsteps", "reverse", "which_linear"])
+  num_inputs = len(eqn.invars)
+  # We first need to run a fixpoint to determine which of the `Ref`s are unknown
+  # after running the for loop. However, the jaxpr has no outputs. Instead, we
+  # discharge the body and run the fixpoint with the discharged jaxpr. We can do
+  # this because the outputs of the discharged jaxpr are one-to-one with the
+  # inputs.
+  discharged_jaxpr, discharged_consts = discharge_state(jaxpr, ())
+  discharged_jaxpr = discharged_jaxpr.replace(
+      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
+      constvars=[])
+  in_unknowns, in_inst = list(in_unknowns), list(in_inst)
+  for _ in range(num_inputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + [False, *in_unknowns]
+    _, _, out_unknowns, inst_out, _ = pe.partial_eval_jaxpr_custom(
+        discharged_jaxpr, jaxpr_in_unknowns, True,
+        ensure_out_unknowns=in_unknowns, ensure_out_inst=True,
+        saveable=saveable)
+    out_unknowns = list(out_unknowns)
+    if out_unknowns == in_unknowns:
+      break
+    in_unknowns = map(operator.or_, in_unknowns, out_unknowns)
+  else:
+    raise Exception("Invalid fixpoint")
+  del out_unknowns  # Redundant since it's the same as `in_unknowns`
+  new_inst = [x for x, inst in zip(eqn.invars, in_inst)
+              if type(x) is core.Var and not inst]
+  in_inst = [True] * len(eqn.invars)
+
+  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_staged_resin_, _, _, num_res = \
+      pe.partial_eval_jaxpr_custom(jaxpr, [False, *in_unknowns],
+                                   [True, *in_inst], [], [], saveable)
+
+  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
+  # non-`Ref` inputs/outputs. However, we'd like to bind these jaxprs to a
+  # `for`, which expects only `Ref` inputs and no outputs. We need to convert
+  # both of these jaxprs into ones that are compatible with `for`.
+  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
+  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
+  # passing the loop index as a residual.
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Ref`s
+  # to output residual values (none of them should be `Ref`s). We'll need to
+  # convert the output residual values into initially-empty `Ref`s that are
+  # written to at the end of the jaxpr.
+
+  # # Loop-invariant residual optimization
+  # Here we are interested in finding out which of the residuals are *not*
+  # dependent on the loop index. If a residual is not dependent on the loop
+  # index, we don't need to add an extra loop dimension to read from when we
+  # convert it from an output into a write.
+
+  # In order to detect which residuals are loop-invariant, we need to run a
+  # fixpoint. This is because a residual could depend on a `Ref` that changes
+  # on each iteration of the loop, so we first need to detect which `Ref`s are
+  # loop-varying. We can do this by discharging the state from the jaxpr and
+  # running partial_eval with only the loop index initially marked as
+  # loop-varying. The fixpoint will eventually propagate the loop-varying-ness
+  # to the inputs/outputs and converge.
+  loop_var_res = [False] * len(jaxpr_known_resout.outvars)
+  loop_var_refs = [False] * (len(jaxpr_known_resout.invars) - 1)
+  discharged_jaxpr_known_resout = core.ClosedJaxpr(
+      *discharge_state(jaxpr_known_resout, ()))
+  for _ in range(len(discharged_jaxpr_known_resout.jaxpr.invars)):
+    _, _, loop_var_outputs, _ = pe.partial_eval_jaxpr_nounits(
+        discharged_jaxpr_known_resout, [True] + loop_var_refs, False)
+    loop_var_res, loop_var_refs_ = split_list(
+        loop_var_outputs, [len(loop_var_res)])
+    if loop_var_refs == loop_var_refs_:
+      break
+    loop_var_refs = map(operator.or_, loop_var_refs, loop_var_refs_)
+  # Now that the fixpoint is complete, we know which residuals are
+  # loop-invariant.
+  loop_invar_res = map(operator.not_, loop_var_res)
+
+  jaxpr_known, res_avals = _convert_outputs_to_writes(nsteps,
+                                                      jaxpr_known_resout,
+                                                      loop_invar_res)
+
+  known_invars, _ = partition_list(in_unknowns, eqn.invars)
+  known_outvars, _ = partition_list(in_unknowns, eqn.outvars)
+  newvar = core.gensym()
+  resvars = map(newvar, res_avals)
+
+  @lu.wrap_init
+  def known(*known_vals):
+    empty_res = map(ad_util.zeros_like_aval, res_avals)
+    jaxpr_known_args = [*known_vals, *empty_res]
+    jaxpr_known_which_linear = (False,) * len(jaxpr_known_args)
+    return for_p.bind(*jaxpr_known_args, jaxpr=jaxpr_known, nsteps=nsteps,
+                      reverse=reverse, which_linear=jaxpr_known_which_linear)
+  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
+      known, [v.aval for v in known_invars])
+  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
+  eqn_known = pe.new_jaxpr_eqn(known_invars, [*known_outvars, *resvars],
+                               core.closed_call_p, dict(call_jaxpr=call_jaxpr),
+                               call_jaxpr.effects, eqn.source_info)
+
+  jaxpr_staged = _convert_inputs_to_reads(nsteps, len(res_avals),
+                                          jaxpr_staged_resin_,
+                                          loop_invar_res)
+  which_linear_unknown = (False,) * num_res + tuple(which_linear)
+  params_staged = dict(eqn.params, jaxpr=jaxpr_staged, reverse=reverse,
+                       nsteps=nsteps, which_linear=which_linear_unknown)
+
+  @lu.wrap_init
+  def staged(*res_and_refs):
+    out_flat = for_p.bind(*res_and_refs, **params_staged)
+    _, ans = split_list(out_flat, [num_res])
+    _, ans = partition_list(inst_out, ans)
+    return ans
+  call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic(
+      staged, [v.aval for v in [*resvars, *eqn.invars]])
+  assert len(jaxpr_staged.invars) - 1 == len(call_jaxpr_.invars)
+  call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts)
+  _, outvars = partition_list(inst_out, eqn.outvars)
+  eqn_staged = pe.new_jaxpr_eqn([*resvars, *eqn.invars], outvars,
+                                core.closed_call_p, dict(call_jaxpr=call_jaxpr),
+                                call_jaxpr.effects, eqn.source_info)
+  new_vars = [*new_inst, *resvars]
+  return eqn_known, eqn_staged, in_unknowns, inst_out, new_vars
+pe.partial_eval_jaxpr_custom_rules[for_p] = _for_partial_eval_custom
+
 def _convert_outputs_to_writes(
     nsteps: int, jaxpr: core.Jaxpr, loop_invar_res: Sequence[bool]
 ) -> Tuple[core.Jaxpr, List[core.ShapedArray]]:
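
Aside: the unknown-propagation fixpoint above can be illustrated with a small
self-contained sketch. This is toy code, not the JAX internals: `deps` is a
made-up stand-in for the dependence structure that the real rule discovers by
partial-evaluating the discharged jaxpr, whose outputs are one-to-one with its
inputs.

    import operator

    def unknowns_fixpoint(deps, in_unknowns):
      # deps[i] lists which inputs output i depends on; outputs are
      # one-to-one with inputs, like the discharged for-loop body.
      in_unknowns = list(in_unknowns)
      # Propagation is monotone, so it converges within len(in_unknowns)
      # passes; otherwise something is wrong, as in the rule above.
      for _ in range(len(in_unknowns)):
        out_unknowns = [any(in_unknowns[d] for d in deps[i])
                        for i in range(len(in_unknowns))]
        if out_unknowns == in_unknowns:
          return in_unknowns
        in_unknowns = list(map(operator.or_, in_unknowns, out_unknowns))
      raise Exception("Invalid fixpoint")

    # Ref 0 feeds ref 1, ref 1 feeds ref 2; only ref 0 starts unknown, but
    # unknown-ness propagates until the flags stop changing.
    print(unknowns_fixpoint({0: [0], 1: [0, 1], 2: [1, 2]},
                            [True, False, False]))  # [True, True, True]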
@@ -328,10 +328,10 @@ def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
   if any(unks_in):
     res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
     return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
-  elif saveable(get_p, *[var.aval for var in eqn.invars], **eqn.params):
+  elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
     return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
-  return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), []
+  res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
+  return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res
 
 pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
                                                     get_p)
@@ -1240,9 +1240,45 @@ def call_partial_eval_custom_rule(
   new_inst = [x for x, inst in zip(eqn.invars, inst_in)
               if type(x) is Var and not inst]
   return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
 
+def closed_call_partial_eval_custom_rule(
+    jaxpr_param_name: str,
+    saveable: Callable[..., bool], unks_in: List[bool], inst_in: List[bool],
+    eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
+  ) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
+  # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
+  closed_jaxpr = eqn.params[jaxpr_param_name]
+  jaxpr = convert_constvars_jaxpr(closed_jaxpr.jaxpr)
+  unks_in = [False] * len(closed_jaxpr.consts) + list(unks_in)
+  inst_in = [False] * len(closed_jaxpr.consts) + list(inst_in)
+  jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
+      partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
+  ins_known, _ = partition_list(unks_in, eqn.invars)
+  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
+  _, ins_staged = partition_list(inst_in, eqn.invars)
+  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
+  newvar = core.gensym([jaxpr_known, jaxpr_staged])
+  params_known = {**eqn.params,
+                  jaxpr_param_name: core.ClosedJaxpr(jaxpr_known, ())}
+  params_staged = {**eqn.params,
+                   jaxpr_param_name: core.ClosedJaxpr(jaxpr_staged, ())}
+  residuals = [newvar(res_aval(params_known, var.aval))
+               for var in jaxpr_staged.invars[:num_res]]
+  eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
+                            eqn.primitive, params_known, jaxpr_known.effects,
+                            eqn.source_info)
+  eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
+                             eqn.primitive, params_staged,
+                             jaxpr_staged.effects, eqn.source_info)
+  assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
+  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
+              if type(x) is Var and not inst]
+  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
+
 partial_eval_jaxpr_custom_rules[core.call_p] = \
     partial(call_partial_eval_custom_rule, 'call_jaxpr',
             lambda _, __, ___, ____, _____, x, y: (x, y))
+partial_eval_jaxpr_custom_rules[core.closed_call_p] = \
+    partial(closed_call_partial_eval_custom_rule, 'call_jaxpr')
 partial_eval_jaxpr_custom_rules[core.named_call_p] = \
     partial(call_partial_eval_custom_rule, 'call_jaxpr',
             lambda _, __, ___, ____, _____, x, y: (x, y))
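
Aside: a toy illustration (hypothetical names, not JAX internals) of the
known/staged split that this rule and the for-loop rule above both produce:
the known equation computes what it can now and emits residuals, and the
staged equation later consumes [residuals, remaining inputs].

    import jax.numpy as jnp

    def f(x, y):                   # suppose x is known now, y only later
      return jnp.sin(x) * y

    known = lambda x: jnp.sin(x)   # "eqn_known": known inputs -> residuals
    staged = lambda r, y: r * y    # "eqn_staged": residuals + late inputs
    r = known(1.0)                 # residual, computed eagerly
    assert staged(r, 2.0) == f(1.0, 2.0)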
@@ -89,6 +89,9 @@ def scan_with_new_checkpoint2(f, *args, **kwargs):
 def scan_with_for(f, *args, **kwargs):
   return for_loop.scan(f, *args, **kwargs)
 
+def scan_with_remat_for(f, *args, **kwargs):
+  return jax.remat(lambda *args: for_loop.scan(f, *args, **kwargs))(*args)
+
 SCAN_IMPLS = [
     (lax.scan, 'unroll1'),
     (partial(lax.scan, unroll=2), 'unroll2'),
@@ -102,8 +105,26 @@ SCAN_IMPLS_WITH_FOR = [
     (scan_with_new_checkpoint , 'new_checkpoint'),
     (scan_with_new_checkpoint2, 'new_checkpoint2'),
     (scan_with_for, 'for_loop'),
+    (scan_with_remat_for, 'for_loop_remat'),
 ]
 
+def remat_of_for_loop(nsteps, body, state, **kwargs):
+  return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,
+                                                   **kwargs))(state)
+
+FOR_LOOP_IMPLS = [
+    (for_loop.for_loop, 'for_loop'),
+    (jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'),
+    (remat_of_for_loop, 'remat_for_loop'),
+]
+
+
+def _for_loop_impls(f):
+  return parameterized.named_parameters(
+      dict(testcase_name=impl_name, for_impl=for_impl)
+      for for_impl, impl_name in FOR_LOOP_IMPLS
+  )(f)
+
+
 def while_loop_new_checkpoint(cond_fun, body_fun, init_val):
   return new_checkpoint(partial(lax.while_loop, cond_fun, body_fun))(init_val)
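
Aside: each entry in FOR_LOOP_IMPLS is call-compatible with
for_loop.for_loop(nsteps, body, init_state), so a quick manual check (a
sketch assuming the definitions above) looks like:

    import jax.numpy as jnp

    def body(_, x_ref):
      x_ref[()] = x_ref[()] + jnp.float32(1.)

    for for_impl, name in FOR_LOOP_IMPLS:
      # plain, jit, and remat implementations should all print 5.0
      print(name, for_impl(5, body, jnp.float32(0.)))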
@@ -2571,82 +2592,89 @@ class LaxControlFlowTest(jtu.JaxTestCase):
 
     jax.grad(f)(1.)  # doesn't crash
 
 
 class ForLoopTest(jtu.JaxTestCase):
 
-  def test_for_loop_impl_trivial(self):
-    out = for_loop.for_loop(5, lambda i, _: None, None)
+  @_for_loop_impls
+  def test_for_loop_impl_trivial(self, for_impl):
+    out = for_impl(5, lambda i, _: None, None)
     self.assertEqual(out, None)
 
-  def test_for_loop_can_write_to_ref(self):
+  @_for_loop_impls
+  def test_for_loop_can_write_to_ref(self, for_impl):
     def body(_, x_ref):
       x_ref[()] = jnp.float32(1.)
-    out = for_loop.for_loop(1, body, jnp.float32(0.))
+    out = for_impl(1, body, jnp.float32(0.))
     self.assertEqual(out, 1.)
 
     def body2(i, x_ref):
       x_ref[()] = jnp.float32(i)
-    out = for_loop.for_loop(2, body2, jnp.float32(0.))
+    out = for_impl(2, body2, jnp.float32(0.))
     self.assertEqual(out, 1.)
 
     def body3(i, x_ref):
       x_ref[()] = jnp.float32(i) * 2.
-    out = for_loop.for_loop(2, body3, jnp.float32(0.))
+    out = for_impl(2, body3, jnp.float32(0.))
     self.assertEqual(out, 2.)
 
-  def test_for_loop_can_write_to_multiple_refs(self):
+  @_for_loop_impls
+  def test_for_loop_can_write_to_multiple_refs(self, for_impl):
     def body(_, refs):
       x_ref, y_ref = refs
       x_ref[()] = jnp.float32(1.)
       y_ref[()] = jnp.float32(2.)
-    x, y = for_loop.for_loop(1, body, (jnp.float32(0.), jnp.float32(0.)))
+    x, y = for_impl(1, body, (jnp.float32(0.), jnp.float32(0.)))
     self.assertEqual(x, 1.)
     self.assertEqual(y, 2.)
 
-  def test_for_loop_can_read_from_ref(self):
+  @_for_loop_impls
+  def test_for_loop_can_read_from_ref(self, for_impl):
     def body(_, x_ref):
       x_ref[()]
-    x = for_loop.for_loop(1, body, jnp.float32(0.))
+    x = for_impl(1, body, jnp.float32(0.))
     self.assertEqual(x, 0.)
 
-  def test_for_loop_can_read_from_and_write_to_ref(self):
+  @_for_loop_impls
+  def test_for_loop_can_read_from_and_write_to_ref(self, for_impl):
     def body(_, x_ref):
       x = x_ref[()]
       x_ref[()] = x + jnp.float32(1.)
-    x = for_loop.for_loop(5, body, jnp.float32(0.))
+    x = for_impl(5, body, jnp.float32(0.))
     self.assertEqual(x, 5.)
 
-  def test_for_loop_can_read_from_and_write_to_refs(self):
+  @_for_loop_impls
+  def test_for_loop_can_read_from_and_write_to_refs(self, for_impl):
     def body2(_, refs):
       x_ref, y_ref = refs
       x = x_ref[()]
       y_ref[()] = x + 1.
       x_ref[()] = x + 1.
-    x, y = for_loop.for_loop(5, body2, (0., 0.))
+    x, y = for_impl(5, body2, (0., 0.))
     self.assertEqual(x, 5.)
     self.assertEqual(y, 5.)
 
-  def test_for_loop_can_read_from_and_write_to_ref_slice(self):
+  @_for_loop_impls
+  def test_for_loop_can_read_from_and_write_to_ref_slice(self, for_impl):
     def body(i, x_ref):
       x = x_ref[i]
       x_ref[i] = x + jnp.float32(1.)
-    x = for_loop.for_loop(4, body, jnp.ones(4, jnp.float32))
+    x = for_impl(4, body, jnp.ones(4, jnp.float32))
     np.testing.assert_allclose(x, 2 * jnp.ones(4, jnp.float32))
 
     def body2(i, x_ref):
       x = x_ref[i, 0]
       x_ref[i, 1] = x + x_ref[i, 1]
-    x = for_loop.for_loop(4, body2, jnp.arange(8.).reshape((4, 2)))
+    x = for_impl(4, body2, jnp.arange(8.).reshape((4, 2)))
     np.testing.assert_allclose(
         x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]]))
 
-  def test_for_loop_can_implement_cumsum(self):
+  @_for_loop_impls
+  def test_for_loop_can_implement_cumsum(self, for_impl):
     def cumsum(x):
       def body(i, refs):
         x_ref, accum_ref = refs
         accum_ref[i + 1] = accum_ref[i] + x_ref[i]
       accum = jnp.zeros(x.shape[0] + 1, x.dtype)
-      _, accum_out = for_loop.for_loop(x.shape[0], body, (x, accum))
+      _, accum_out = for_impl(x.shape[0], body, (x, accum))
       return accum_out[1:]
 
     key = jax.random.PRNGKey(0)
@@ -2708,18 +2736,20 @@ def for_body_reverse(i, refs):
 
 reverse_ref = lambda x, y: (x, x[::-1])
 
-identity = lambda x, y: (x, y)
+def for_body_noop(i, refs):
+  pass
+noop_ref = lambda x, y: (x, y)
 for_reference = for_loop.discharged_for_loop
 
 
 class ForLoopTransformationTest(jtu.JaxTestCase):
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_for={}_f={}_nsteps={}".format(
-          jit_for, for_body_name, nsteps),
-       "jit_for": jit_for, "f": for_body, "body_shapes": body_shapes,
-       "ref": ref, "n": nsteps}
-      for jit_for in [False, True]
+      {"testcase_name": "_f={}_nsteps={}_impl={}".format(
+          for_body_name, nsteps, impl_name),
+       "f": for_body, "body_shapes": body_shapes,
+       "ref": ref, "n": nsteps, "for_impl": for_impl}
+      for for_impl, impl_name in FOR_LOOP_IMPLS
       for for_body_name, for_body, ref, body_shapes, nsteps in [
          ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
          ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
@@ -2729,14 +2759,12 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
       ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
       ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
   ])
-  def test_for_jvp(self, jit_for, f, ref, body_shapes, n):
-    for_ = for_loop.for_loop
+  def test_for_jvp(self, f, ref, body_shapes, n, for_impl):
+    for_ = for_impl
     rng = self.rng()
 
     args = [rng.randn(*s) for s in body_shapes]
 
-    if jit_for:
-      for_ = jax.jit(for_, static_argnums=(0, 1))
     tol = {np.float64: 1e-12, np.float32: 1e-4}
     ans = jax.jvp(lambda *args: for_(n, f, args), args, args)
     ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args)
@@ -2746,11 +2774,11 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
     jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_for={}_f={}_nsteps={}".format(
-          jit_for, for_body_name, nsteps),
-       "jit_for": jit_for, "f": for_body, "body_shapes": body_shapes,
-       "ref": ref, "n": nsteps}
-      for jit_for in [False, True]
+      {"testcase_name": "_f={}_nsteps={}_impl={}".format(
+          for_body_name, nsteps, impl_name),
+       "f": for_body, "body_shapes": body_shapes,
+       "ref": ref, "n": nsteps, "for_impl": for_impl}
+      for for_impl, impl_name in FOR_LOOP_IMPLS
       for for_body_name, for_body, ref, body_shapes, nsteps in [
          ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
          ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
|
||||
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
|
||||
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
|
||||
])
|
||||
def test_for_linearize(self, jit_for, f, ref, body_shapes, n):
|
||||
for_ = for_loop.for_loop
|
||||
def test_for_linearize(self, f, ref, body_shapes, n, for_impl):
|
||||
for_ = for_impl
|
||||
rng = self.rng()
|
||||
|
||||
args = [rng.randn(*s) for s in body_shapes]
|
||||
|
||||
if jit_for:
|
||||
for_ = jax.jit(for_, static_argnums=(0, 1))
|
||||
tol = {np.float64: 1e-12, np.float32: 1e-4}
|
||||
ans = jax.linearize(lambda *args: for_( n, f, args), *args)[1](*args)
|
||||
ans_discharged = jax.linearize(lambda *args: for_reference(n, f, args),
|
||||
@ -2804,7 +2830,9 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text())
|
||||
assert s.count("dynamic-update-slice(") < 2
|
||||
|
||||
def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals(self):
|
||||
@_for_loop_impls
|
||||
def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals(
|
||||
self, for_impl):
|
||||
|
||||
def body(i, refs):
|
||||
a_ref, b_ref, c_ref = refs
|
||||
@@ -2815,7 +2843,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
       c_ref[i] = x * b
     def f(a, b):
       c = jnp.zeros_like(a)
-      _, b, c = for_loop.for_loop(5, body, (a, b, c))
+      _, b, c = for_impl(5, body, (a, b, c))
       return b, c
     a = jnp.arange(5.) + 1.
     b = 1.
@@ -2826,12 +2854,13 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
     np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])
 
   @parameterized.named_parameters(
-      {"testcase_name": "_jit_for={}_f={}_nsteps={}".format(
-          jit_for, for_body_name, nsteps),
-       "jit_for": jit_for, "f": for_body, "body_shapes": body_shapes,
-       "ref": ref, "n": nsteps}
-      for jit_for in [False, True]
+      {"testcase_name": "_f={}_nsteps={}_impl={}".format(
+          for_body_name, nsteps, impl_name),
+       "f": for_body, "body_shapes": body_shapes,
+       "ref": ref, "n": nsteps, "for_impl": for_impl}
+      for for_impl, impl_name in FOR_LOOP_IMPLS
       for for_body_name, for_body, ref, body_shapes, nsteps in [
+         ("noop", for_body_noop, noop_ref, [(4,), (4,)], 4),
          ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
          ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
          ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
@@ -2840,14 +2869,12 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
       ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
       ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
   ])
-  def test_for_grad(self, jit_for, f, ref, body_shapes, n):
-    for_ = for_loop.for_loop
+  def test_for_grad(self, f, ref, body_shapes, n, for_impl):
+    for_ = for_impl
     rng = self.rng()
 
     args = [rng.randn(*s) for s in body_shapes]
 
-    if jit_for:
-      for_ = jax.jit(for_, static_argnums=(0, 1))
     tol = {np.float64: 1e-12, np.float32: 1e-4}
     ans = jax.grad(lambda args: for_(n, f, args)[1].sum())(args)
     ans_discharged = jax.grad(
@@ -2857,7 +2884,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
         atol=tol)
     self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
     jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3,
-                    rtol=5e-3)
+                    rtol=7e-3, atol=1e-2)
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())