mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
fix typo in #9923
This commit is contained in:
parent
e9f59aed84
commit
d60d5d7737
@ -957,7 +957,7 @@ def _maybe_perturbed(x: Any) -> bool:
|
||||
# happen later, but some types always have trivial tangents.
|
||||
vspace = x.aval.at_least_vspace()
|
||||
return not (vspace is core.abstract_unit or vspace is core.abstract_token or
|
||||
vspace is dtypes.float0)
|
||||
getattr(vspace, 'dtype', None) is dtypes.float0)
|
||||
elif not isinstance(x, ad.JVPTracer):
|
||||
# If x is not a JVPTracer, recursively check its contents.
|
||||
return any(_maybe_perturbed(attr) for name, attr in x._contents())
|
||||
|
@ -5435,6 +5435,17 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
jax.jvp(f, (1.0,), (1.0,)) # assertions inside f
|
||||
|
||||
def test_maybe_perturbed_int_regression(self):
|
||||
# see https://github.com/google/jax/discussions/9951
|
||||
from jax._src.custom_derivatives import closure_convert
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
x = jnp.array(1)
|
||||
_, aux_args = closure_convert(lambda: x)
|
||||
self.assertEmpty(aux_args)
|
||||
f()
|
||||
|
||||
|
||||
class CustomVJPTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user