mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix custom_vjp issue?
This commit is contained in:
parent
a56a1679ab
commit
9f919c69e4
@ -681,6 +681,9 @@ xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
||||
pe.partial_eval_jaxpr_custom_rules[custom_vjp_call_jaxpr_p] = \
|
||||
custom_jvp_jaxpr_custom_partial_eval_rule # type: ignore
|
||||
|
||||
|
||||
def custom_gradient(fun):
|
||||
"""Convenience function for defining custom VJP rules (aka custom gradients).
|
||||
@ -706,7 +709,9 @@ def custom_gradient(fun):
|
||||
over intermediate values computed when evaluating the function to be
|
||||
differentiated. That is, use lexical closure to share work between the forward
|
||||
pass and the backward pass of reverse-mode automatic differentiation. However,
|
||||
it cannot support Python control flow.
|
||||
it cannot perform Python control flow which depends on the values of the
|
||||
closed-over intermediate values or its cotangent arguments; if the function
|
||||
includes such control flow, an error is raised.
|
||||
|
||||
Args:
|
||||
fun: a Python callable specifying both the mathematical function to be
|
||||
|
@ -23,7 +23,6 @@ from jax._src import util
|
||||
_exclude_paths = [__file__, util.__file__]
|
||||
|
||||
def register_exclusion(path):
|
||||
return
|
||||
_exclude_paths.append(path)
|
||||
|
||||
_jax_message_append = (
|
||||
|
@ -878,7 +878,7 @@ def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
|
||||
def _partial_eval_jaxpr_custom(
|
||||
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
|
||||
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
|
||||
env: Dict[Var, Tuple[bool, bool]] = {}
|
||||
residuals: OrderedSet[Var] = OrderedSet()
|
||||
|
||||
@ -985,15 +985,17 @@ def call_partial_eval_custom_rule(
|
||||
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
|
||||
partial_eval_jaxpr_custom_rules[core.call_p] = \
|
||||
partial(call_partial_eval_custom_rule, lambda _, __, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
|
||||
partial(call_partial_eval_custom_rule, lambda _, __, x, y: (x, y))
|
||||
partial_eval_jaxpr_custom_rules[remat_call_p] = \
|
||||
partial(call_partial_eval_custom_rule,
|
||||
lambda _, __, p1, p2: (p1, dict(p2, differentiated=True)))
|
||||
|
||||
|
||||
# TODO unify with dce code below
|
||||
# TODO(mattjj): unify with dce code below
|
||||
def dce_jaxpr(jaxpr: Jaxpr, used_outputs: List[bool]
|
||||
) -> Tuple[Jaxpr, List[bool]]:
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO
|
||||
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
|
||||
env: Dict[Var, bool] = {}
|
||||
|
||||
def read(v: Var) -> bool:
|
||||
@ -1041,6 +1043,8 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
|
||||
eqn.primitive, new_params, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
dce_rules[core.call_p] = dce_jaxpr_call_rule
|
||||
dce_rules[core.named_call_p] = dce_jaxpr_call_rule
|
||||
dce_rules[remat_call_p] = dce_jaxpr_call_rule
|
||||
|
||||
|
||||
def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJaxpr:
|
||||
@ -1051,7 +1055,7 @@ def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=
|
||||
def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False) -> Jaxpr:
|
||||
# This dead-code elimination is pretty rudimentary, and in particular doesn't
|
||||
# nontrivially DCE through scan, call, or other higher-order primitives.
|
||||
# TODO(mattjj): better DCE
|
||||
# TODO(mattjj): better DCE (i.e. use above dce_jaxpr)
|
||||
if drop_outputs:
|
||||
new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
|
||||
else:
|
||||
|
@ -169,6 +169,7 @@ def join_tolerance(tol1, tol2):
|
||||
return out
|
||||
|
||||
def _assert_numpy_close(a, b, atol=None, rtol=None, err_msg=''):
|
||||
a, b = np.asarray(a), np.asarray(b)
|
||||
assert a.shape == b.shape
|
||||
atol = max(tolerance(a.dtype, atol), tolerance(b.dtype, atol))
|
||||
rtol = max(tolerance(a.dtype, rtol), tolerance(b.dtype, rtol))
|
||||
|
@ -3503,11 +3503,11 @@ class RematTest(jtu.JaxTestCase):
|
||||
def test_remat_checkpoint_dots(self):
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x)
|
||||
return x
|
||||
|
||||
@ -3521,11 +3521,11 @@ class RematTest(jtu.JaxTestCase):
|
||||
@api.jit
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x * 1e-3)
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x * 1e-3)
|
||||
x = jnp.dot(x, x)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = jnp.sin(x * 1e-3)
|
||||
return x
|
||||
|
||||
@ -3541,9 +3541,9 @@ class RematTest(jtu.JaxTestCase):
|
||||
def f(W):
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
x = jnp.sin(jnp.dot(x, W))
|
||||
x = jnp.sin(jnp.dot(x, W))
|
||||
x = jnp.sin(jnp.dot(x, W))
|
||||
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
|
||||
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
|
||||
x = jnp.sin(jnp.dot(x, W, precision=lax.Precision.HIGHEST))
|
||||
return x
|
||||
|
||||
def body(x, _): return f(x), None
|
||||
@ -3565,8 +3565,6 @@ class RematTest(jtu.JaxTestCase):
|
||||
modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_custom_jvp_policy(self):
|
||||
save_sin = lambda prim, *_, **__: str(prim) == 'sin'
|
||||
|
||||
@api.custom_jvp
|
||||
def sin(x):
|
||||
return jnp.sin(x)
|
||||
@ -3576,9 +3574,15 @@ class RematTest(jtu.JaxTestCase):
|
||||
return sin(x), jnp.cos(x) * g
|
||||
sin.defjvp(sin_jvp)
|
||||
|
||||
@partial(api.remat, policy=save_sin)
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
return sin(sin(x))
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = sin(x * 1e-3)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = sin(x * 1e-3)
|
||||
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
|
||||
x = sin(x * 1e-3)
|
||||
return x
|
||||
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
@ -3586,6 +3590,38 @@ class RematTest(jtu.JaxTestCase):
|
||||
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
|
||||
jtu.check_grads(g, (3.,), order=2, modes=['fwd', 'rev'])
|
||||
|
||||
def test_remat_custom_vjp_policy(self):
|
||||
@api.custom_vjp
|
||||
def sin(x):
|
||||
return jnp.sin(x)
|
||||
def sin_fwd(x):
|
||||
return sin(x), x
|
||||
def sin_bwd(x, y_bar):
|
||||
return (jnp.cos(x) * y_bar,)
|
||||
sin.defvjp(sin_fwd, sin_bwd)
|
||||
|
||||
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
|
||||
def f(x):
|
||||
@partial(api.named_call, name="dot")
|
||||
def dot2(y, z):
|
||||
return jnp.dot(x, jnp.dot(y, z, precision=lax.Precision.HIGHEST),
|
||||
precision=lax.Precision.HIGHEST)
|
||||
|
||||
x = dot2(x, x)
|
||||
x = sin(x * 1e-3)
|
||||
x = dot2(x, x)
|
||||
x = sin(x * 1e-3)
|
||||
x = dot2(x, x)
|
||||
x = sin(x * 1e-3)
|
||||
return x
|
||||
|
||||
jtu.check_grads(f, (3.,), order=2, modes=['rev'])
|
||||
|
||||
def g(x):
|
||||
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
|
||||
jtu.check_grads(g, (3.,), order=2, modes=['rev'])
|
||||
|
||||
|
||||
class JaxprTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scalar_literals(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user