From 0a86e9a92910ba55cd1b120082fe7edf678d03da Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 13 Jun 2024 13:14:27 -0700 Subject: [PATCH] Deprecate hashing of tracers --- CHANGELOG.md | 3 +++ jax/BUILD | 1 + jax/__init__.py | 1 + jax/_src/core.py | 14 +++++++++++++- jax/_src/scipy/optimize/line_search.py | 2 +- tests/array_test.py | 17 +++++++++++++++++ 6 files changed, 36 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cc510647..9a885b0af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed in a future release. + * Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX + release. This previously was the case, but there was an inadvertent regression in + the last several JAX releases. ## jaxlib 0.4.30 diff --git a/jax/BUILD b/jax/BUILD index fad9845f8..3105256eb 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -464,6 +464,7 @@ pytype_strict_library( deps = [ ":compute_on", ":config", + ":deprecations", ":dtypes", ":effects", ":pretty_printer", diff --git a/jax/__init__.py b/jax/__init__.py index c016c3afb..9dd7ebc85 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -181,6 +181,7 @@ del _ccache from jax._src.deprecations import register as _register_deprecation _register_deprecation("jax-experimental-maps-module") _register_deprecation('jax-scipy-beta-args') +_register_deprecation('tracer-hash') del _register_deprecation _deprecations = { diff --git a/jax/_src/core.py b/jax/_src/core.py index 53c2e71f8..78146cc2b 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import collections # noqa: F401 from collections import Counter, defaultdict, deque, namedtuple from collections.abc import (Collection, Generator, Hashable, Iterable, Iterator, Set, Sequence, MutableSet, @@ -36,6 +35,7 @@ from weakref import ref import numpy as np +from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects @@ -668,6 +668,18 @@ class Tracer(typing.Array, metaclass=StrictABCMeta): size = _aval_property('size') shape = _aval_property('shape') + def __hash__(self): + # TODO(jakevdp) finalize this deprecation and set __hash__ = None + # Warning added 2024-06-13 + if deprecations.is_accelerated('tracer-hash'): + raise TypeError(f"unhashable type: {type(self)}") + # Use FutureWarning rather than DeprecationWarning because hash is likely + # not called directly by the user, so we want to warn at all stacklevels. + warnings.warn( + f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an" + " error in a future JAX release.", category=FutureWarning) + return super().__hash__() + def __init__(self, trace: Trace): self._trace = trace diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 078d23d97..189009693 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -118,7 +118,7 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo, # This will cause the line search to stop, and since the Wolfe conditions # are not satisfied the minimization should stop too. - threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10) + threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10) state = state._replace(failed=state.failed | (dalpha <= threshold)) # Cubmin is sometimes nan, though in this case the bounds check will fail. diff --git a/tests/array_test.py b/tests/array_test.py index 3efe67306..88ade2305 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -24,6 +24,7 @@ import numpy as np import jax import jax.numpy as jnp from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import op_shardings from jax._src import test_util as jtu @@ -597,6 +598,22 @@ class JaxArrayTest(jtu.JaxTestCase): x = jnp.array([1, 2, 3]) self.assertIsInstance(x.addressable_data(0), array.ArrayImpl) + def test_array_not_hashable(self): + x = jnp.arange(4) + with self.assertRaisesRegex(TypeError, "unhashable type"): + hash(x) + + @jax.jit + def check_tracer_hash(x): + self.assertIsInstance(hash(x), int) + + if deprecations.is_accelerated('tracer-hash'): + with self.assertRaisesRegex(TypeError, "unhashable type"): + check_tracer_hash(x) + else: + with self.assertWarnsRegex(FutureWarning, "unhashable type"): + check_tracer_hash(x) + def test_shape_dtype_struct_sharding_jit(self): mesh = jtu.create_global_mesh((8,), ('x')) s = jax.sharding.NamedSharding(mesh, P('x'))