This commit is contained in:
Matthew Johnson 2022-03-18 11:09:32 -07:00
parent e9f59aed84
commit d60d5d7737
2 changed files with 12 additions and 1 deletions

View File

@ -957,7 +957,7 @@ def _maybe_perturbed(x: Any) -> bool:
# happen later, but some types always have trivial tangents.
vspace = x.aval.at_least_vspace()
return not (vspace is core.abstract_unit or vspace is core.abstract_token or
vspace is dtypes.float0)
getattr(vspace, 'dtype', None) is dtypes.float0)
elif not isinstance(x, ad.JVPTracer):
# If x is not a JVPTracer, recursively check its contents.
return any(_maybe_perturbed(attr) for name, attr in x._contents())

View File

@ -5435,6 +5435,17 @@ class CustomJVPTest(jtu.JaxTestCase):
jax.jvp(f, (1.0,), (1.0,)) # assertions inside f
def test_maybe_perturbed_int_regression(self):
# see https://github.com/google/jax/discussions/9951
from jax._src.custom_derivatives import closure_convert
@jax.jit
def f():
x = jnp.array(1)
_, aux_args = closure_convert(lambda: x)
self.assertEmpty(aux_args)
f()
class CustomVJPTest(jtu.JaxTestCase):