From d60d5d773740b7d6814b4815bcbce6876516e2c1 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 18 Mar 2022 11:09:32 -0700 Subject: [PATCH] fix typo in #9923 --- jax/_src/custom_derivatives.py | 2 +- tests/api_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index c13ee7576..83a2828bb 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -957,7 +957,7 @@ def _maybe_perturbed(x: Any) -> bool: # happen later, but some types always have trivial tangents. vspace = x.aval.at_least_vspace() return not (vspace is core.abstract_unit or vspace is core.abstract_token or - vspace is dtypes.float0) + getattr(vspace, 'dtype', None) is dtypes.float0) elif not isinstance(x, ad.JVPTracer): # If x is not a JVPTracer, recursively check its contents. return any(_maybe_perturbed(attr) for name, attr in x._contents()) diff --git a/tests/api_test.py b/tests/api_test.py index 4448ae550..a54a96161 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5435,6 +5435,17 @@ class CustomJVPTest(jtu.JaxTestCase): jax.jvp(f, (1.0,), (1.0,)) # assertions inside f + def test_maybe_perturbed_int_regression(self): + # see https://github.com/google/jax/discussions/9951 + from jax._src.custom_derivatives import closure_convert + + @jax.jit + def f(): + x = jnp.array(1) + _, aux_args = closure_convert(lambda: x) + self.assertEmpty(aux_args) + f() + class CustomVJPTest(jtu.JaxTestCase):