fix custom_vjp issue?

This commit is contained in:
Matthew Johnson 2021-08-27 17:42:42 -07:00
parent a56a1679ab
commit 9f919c69e4
5 changed files with 64 additions and 19 deletions

View File

@@ -681,6 +681,9 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
pe.partial_eval_jaxpr_custom_rules[custom_vjp_call_jaxpr_p] = \
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore
def custom_gradient(fun):
"""Convenience function for defining custom VJP rules (aka custom gradients).
@@ -706,7 +709,9 @@ def custom_gradient(fun):
over intermediate values computed when evaluating the function to be
differentiated. That is, use lexical closure to share work between the forward
pass and the backward pass of reverse-mode automatic differentiation. However,
it cannot support Python control flow.
it cannot perform Python control flow which depends on the values of the
closed-over intermediate values or its cotangent arguments; if the function
includes such control flow, an error is raised.
Args:
fun: a Python callable specifying both the mathematical function to be

View File

@@ -23,7 +23,6 @@ from jax._src import util
_exclude_paths = [__file__, util.__file__]
def register_exclusion(path):
return
_exclude_paths.append(path)
_jax_message_append = (

View File

@@ -878,7 +878,7 @@ def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
def _partial_eval_jaxpr_custom(
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
if jaxpr.constvars: raise NotImplementedError # TODO
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, Tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
@@ -985,15 +985,17 @@ def call_partial_eval_custom_rule(
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, lambda _, __, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, lambda _, __, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
partial(call_partial_eval_custom_rule,
lambda _, __, p1, p2: (p1, dict(p2, differentiated=True)))
# TODO unify with dce code below
# TODO(mattjj): unify with dce code below
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
) -> Tuple[Jaxpr, List[bool]]:
if jaxpr.constvars: raise NotImplementedError # TODO
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
env: Dict[Var, bool] = {}
def read(v: Var) -> bool:
@@ -1041,6 +1043,8 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
eqn.primitive, new_params, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.call_p] = dce_jaxpr_call_rule
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
dce_rules[remat_call_p] = dce_jaxpr_call_rule
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
@@ -1051,7 +1055,7 @@ def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
# This dead-code elimination is pretty rudimentary, and in particular doesn't
# nontrivially DCE through scan, call, or other higher-order primitives.
# TODO(mattjj): better DCE
# TODO(mattjj): better DCE (i.e. use above dce_jaxpr)
if drop_outputs:
new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
else:

View File

@@ -169,6 +169,7 @@ def join_tolerance(tol1, tol2):
return out
def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
a, b = np.asarray(a), np.asarray(b)
assert a.shape == b.shape
atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))

View File

@@ -3503,11 +3503,11 @@ class RematTest(jtu.JaxTestCase):
def test_remat_checkpoint_dots(self):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
return x
@@ -3521,11 +3521,11 @@ class RematTest(jtu.JaxTestCase):
@api.jit
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x * 1e-3)
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x * 1e-3)
x = jnp.dot(x, x)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x * 1e-3)
return x
@@ -3541,9 +3541,9 @@ class RematTest(jtu.JaxTestCase):
def f(W):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.sin(jnp.dot(x, W))
x = jnp.sin(jnp.dot(x, W))
x = jnp.sin(jnp.dot(x, W))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
return x
def body(x, _): return f(x), None
@@ -3565,8 +3565,6 @@ class RematTest(jtu.JaxTestCase):
modes=['fwd', 'rev'])
def test_remat_custom_jvp_policy(self):
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
@api.custom_jvp
def sin(x):
return jnp.sin(x)
@@ -3576,9 +3574,15 @@ class RematTest(jtu.JaxTestCase):
return sin(x), jnp.cos(x) * g
sin.defjvp(sin_jvp)
@partial(api.remat, policy=save_sin)
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
return sin(sin(x))
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = sin(x * 1e-3)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = sin(x * 1e-3)
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = sin(x * 1e-3)
return x
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
@@ -3586,6 +3590,38 @@ class RematTest(jtu.JaxTestCase):
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
jtu.check_grads(g, (3.,), order=2, modes=['fwd', 'rev'])
def test_remat_custom_vjp_policy(self):
@api.custom_vjp
def sin(x):
return jnp.sin(x)
def sin_fwd(x):
return sin(x), x
def sin_bwd(x, y_bar):
return (jnp.cos(x) * y_bar,)
sin.defvjp(sin_fwd, sin_bwd)
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
@partial(api.named_call, name="dot")
def dot2(y, z):
return jnp.dot(x, jnp.dot(y, z, precision=lax.Precision.HIGHEST),
precision=lax.Precision.HIGHEST)
x = dot2(x, x)
x = sin(x * 1e-3)
x = dot2(x, x)
x = sin(x * 1e-3)
x = dot2(x, x)
x = sin(x * 1e-3)
return x
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
def g(x):
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
jtu.check_grads(g, (3.,), order=2, modes=['rev'])
class JaxprTest(jtu.JaxTestCase):
def test_scalar_literals(self):