mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix #9567
This commit is contained in:
parent
4cdc25f1f7
commit
b1847bc41e
@ -25,6 +25,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class)
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
|
||||
from jax.core import raise_to_shaped
|
||||
@ -943,13 +944,25 @@ def closure_convert(fun, *example_args):
|
||||
else:
|
||||
return _closure_convert_for_avals(fun, in_tree, in_avals)
|
||||
|
||||
def _is_perturbed(x: Any) -> bool:
|
||||
if isinstance(x, ad.JVPTracer):
|
||||
return True
|
||||
elif isinstance(x, core.Tracer):
|
||||
return any(_is_perturbed(attr) for name, attr in x._contents())
|
||||
else:
|
||||
def _maybe_perturbed(x: Any) -> bool:
|
||||
# False if x can't represent an AD-perturbed value (i.e. a value
|
||||
# with a nontrivial tangent attached), up to heuristics, and True otherwise.
|
||||
# See https://github.com/google/jax/issues/6415 for motivation.
|
||||
x = core.full_lower(x)
|
||||
if not isinstance(x, core.Tracer):
|
||||
# If x is not a Tracer, it can't be perturbed.
|
||||
return False
|
||||
elif isinstance(x, pe.DynamicJaxprTracer):
|
||||
# If x is a DynamicJaxprTracer then we're staging out; differentiation could
|
||||
# happen later, but some types always have trivial tangents.
|
||||
vspace = x.aval.at_least_vspace()
|
||||
return not (vspace is core.abstract_unit or vspace is core.abstract_token or
|
||||
vspace is dtypes.float0)
|
||||
elif not isinstance(x, ad.JVPTracer):
|
||||
# If x is not a JVPTracer, recursively check its contents.
|
||||
return any(_maybe_perturbed(attr) for name, attr in x._contents())
|
||||
else:
|
||||
return True # We can't be sure!
|
||||
|
||||
@cache()
|
||||
def _closure_convert_for_avals(fun, in_tree, in_avals):
|
||||
@ -957,7 +970,7 @@ def _closure_convert_for_avals(fun, in_tree, in_avals):
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
out_tree = out_tree()
|
||||
|
||||
(closure_consts, hoisted_consts), merge = partition_list(_is_perturbed, consts)
|
||||
(closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
|
||||
num_consts = len(hoisted_consts)
|
||||
|
||||
def converted_fun(*args_hconsts):
|
||||
|
@ -5420,6 +5420,21 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape
|
||||
self.assertEqual(shape, ())
|
||||
|
||||
def test_maybe_perturbed_internal_helper_function(self):
|
||||
# This is a unit test for an internal API. We include it so as not to
|
||||
# regress https://github.com/google/jax/issues/9567. For an explanation of
|
||||
# this helper function, see https://github.com/google/jax/issues/6415.
|
||||
from jax._src.custom_derivatives import _maybe_perturbed
|
||||
def f(x):
|
||||
def g(y, _):
|
||||
z = y * x
|
||||
self.assertTrue(_maybe_perturbed(z))
|
||||
return y, None
|
||||
g(1, None)
|
||||
return lax.scan(g, 1, xs=None, length=1)[0]
|
||||
|
||||
jax.jvp(f, (1.0,), (1.0,)) # assertions inside f
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user