Deprecate hashing of tracers

This commit is contained in:
Jake VanderPlas 2024-06-13 13:14:27 -07:00
parent 0dc706d79f
commit 0a86e9a929
6 changed files with 36 additions and 2 deletions

View File

@ -19,6 +19,9 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
in a future release.
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
release. This previously was the case, but there was an inadvertent regression in
the last several JAX releases.
## jaxlib 0.4.30

View File

@ -464,6 +464,7 @@ pytype_strict_library(
deps = [
":compute_on",
":config",
":deprecations",
":dtypes",
":effects",
":pretty_printer",

View File

@ -181,6 +181,7 @@ del _ccache
from jax._src.deprecations import register as _register_deprecation
_register_deprecation("jax-experimental-maps-module")
_register_deprecation('jax-scipy-beta-args')
_register_deprecation('tracer-hash')
del _register_deprecation
_deprecations = {

View File

@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import collections # noqa: F401
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Collection, Generator, Hashable, Iterable,
Iterator, Set, Sequence, MutableSet,
@ -36,6 +35,7 @@ from weakref import ref
import numpy as np
from jax._src import deprecations
from jax._src import dtypes
from jax._src import config
from jax._src import effects
@ -668,6 +668,18 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
size = _aval_property('size')
shape = _aval_property('shape')
def __hash__(self):
# TODO(jakevdp) finalize this deprecation and set __hash__ = None
# Warning added 2024-06-13
if deprecations.is_accelerated('tracer-hash'):
raise TypeError(f"unhashable type: {type(self)}")
# Use FutureWarning rather than DeprecationWarning because hash is likely
# not called directly by the user, so we want to warn at all stacklevels.
warnings.warn(
f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an"
" error in a future JAX release.", category=FutureWarning)
return super().__hash__()
def __init__(self, trace: Trace):
self._trace = trace

View File

@ -118,7 +118,7 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
# This will cause the line search to stop, and since the Wolfe conditions
# are not satisfied the minimization should stop too.
threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10)
threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10)
state = state._replace(failed=state.failed | (dalpha <= threshold))
# Cubmin is sometimes nan, though in this case the bounds check will fail.

View File

@ -24,6 +24,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import op_shardings
from jax._src import test_util as jtu
@ -597,6 +598,22 @@ class JaxArrayTest(jtu.JaxTestCase):
x = jnp.array([1, 2, 3])
self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)
def test_array_not_hashable(self):
x = jnp.arange(4)
with self.assertRaisesRegex(TypeError, "unhashable type"):
hash(x)
@jax.jit
def check_tracer_hash(x):
self.assertIsInstance(hash(x), int)
if deprecations.is_accelerated('tracer-hash'):
with self.assertRaisesRegex(TypeError, "unhashable type"):
check_tracer_hash(x)
else:
with self.assertWarnsRegex(FutureWarning, "unhashable type"):
check_tracer_hash(x)
def test_shape_dtype_struct_sharding_jit(self):
mesh = jtu.create_global_mesh((8,), ('x'))
s = jax.sharding.NamedSharding(mesh, P('x'))