This commit is contained in:
Matthew Johnson 2022-03-16 15:47:00 -07:00
parent 4cdc25f1f7
commit b1847bc41e
2 changed files with 35 additions and 7 deletions

View File

@ -25,6 +25,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
register_pytree_node_class)
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
from jax.core import raise_to_shaped
@ -943,13 +944,25 @@ def closure_convert(fun, *example_args):
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)
def _is_perturbed(x: Any) -> bool:
if isinstance(x, ad.JVPTracer):
return True
elif isinstance(x, core.Tracer):
return any(_is_perturbed(attr) for name, attr in x._contents())
else:
def _maybe_perturbed(x: Any) -> bool:
# False if x can't represent an AD-perturbed value (i.e. a value
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
# See https://github.com/google/jax/issues/6415 for motivation.
x = core.full_lower(x)
if not isinstance(x, core.Tracer):
# If x is not a Tracer, it can't be perturbed.
return False
elif isinstance(x, pe.DynamicJaxprTracer):
# If x is a DynamicJaxprTracer then we're staging out; differentiation could
# happen later, but some types always have trivial tangents.
vspace = x.aval.at_least_vspace()
return not (vspace is core.abstract_unit or vspace is core.abstract_token or
vspace is dtypes.float0)
elif not isinstance(x, ad.JVPTracer):
# If x is not a JVPTracer, recursively check its contents.
return any(_maybe_perturbed(attr) for name, attr in x._contents())
else:
return True # We can't be sure!
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
@ -957,7 +970,7 @@ def _closure_convert_for_avals(fun, in_tree, in_avals):
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()
(closure_consts, hoisted_consts), merge = partition_list(_is_perturbed, consts)
(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
num_consts = len(hoisted_consts)
def converted_fun(*args_hconsts):

View File

@ -5420,6 +5420,21 @@ class CustomJVPTest(jtu.JaxTestCase):
shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
self.assertEqual(shape, ())
def test_maybe_perturbed_internal_helper_function(self):
# This is a unit test for an internal API. We include it so as not to
# regress https://github.com/google/jax/issues/9567. For an explanation of
# this helper function, see https://github.com/google/jax/issues/6415.
from jax._src.custom_derivatives import _maybe_perturbed
def f(x):
def g(y, _):
z = y * x
self.assertTrue(_maybe_perturbed(z))
return y, None
g(1, None)
return lax.scan(g, 1, xs=None, length=1)[0]
jax.jvp(f, (1.0,), (1.0,)) # assertions inside f
class CustomVJPTest(jtu.JaxTestCase):