mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Deprecate hashing of tracers
This commit is contained in:
parent
0dc706d79f
commit
0a86e9a929
@ -19,6 +19,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed
|
||||
in a future release.
|
||||
* Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX
|
||||
release. This previously was the case, but there was an inadvertent regression in
|
||||
the last several JAX releases.
|
||||
|
||||
## jaxlib 0.4.30
|
||||
|
||||
|
@ -464,6 +464,7 @@ pytype_strict_library(
|
||||
deps = [
|
||||
":compute_on",
|
||||
":config",
|
||||
":deprecations",
|
||||
":dtypes",
|
||||
":effects",
|
||||
":pretty_printer",
|
||||
|
@ -181,6 +181,7 @@ del _ccache
|
||||
from jax._src.deprecations import register as _register_deprecation
|
||||
_register_deprecation("jax-experimental-maps-module")
|
||||
_register_deprecation('jax-scipy-beta-args')
|
||||
_register_deprecation('tracer-hash')
|
||||
del _register_deprecation
|
||||
|
||||
_deprecations = {
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections # noqa: F401
|
||||
from collections import Counter, defaultdict, deque, namedtuple
|
||||
from collections.abc import (Collection, Generator, Hashable, Iterable,
|
||||
Iterator, Set, Sequence, MutableSet,
|
||||
@ -36,6 +35,7 @@ from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import config
|
||||
from jax._src import effects
|
||||
@ -668,6 +668,18 @@ class Tracer(typing.Array, metaclass=StrictABCMeta):
|
||||
size = _aval_property('size')
|
||||
shape = _aval_property('shape')
|
||||
|
||||
def __hash__(self):
|
||||
# TODO(jakevdp) finalize this deprecation and set __hash__ = None
|
||||
# Warning added 2024-06-13
|
||||
if deprecations.is_accelerated('tracer-hash'):
|
||||
raise TypeError(f"unhashable type: {type(self)}")
|
||||
# Use FutureWarning rather than DeprecationWarning because hash is likely
|
||||
# not called directly by the user, so we want to warn at all stacklevels.
|
||||
warnings.warn(
|
||||
f"unhashable type: {type(self)}. Attempting to hash a tracer will lead to an"
|
||||
" error in a future JAX release.", category=FutureWarning)
|
||||
return super().__hash__()
|
||||
|
||||
def __init__(self, trace: Trace):
|
||||
self._trace = trace
|
||||
|
||||
|
@ -118,7 +118,7 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
|
||||
|
||||
# This will cause the line search to stop, and since the Wolfe conditions
|
||||
# are not satisfied the minimization should stop too.
|
||||
threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10)
|
||||
threshold = jnp.where((jnp.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10)
|
||||
state = state._replace(failed=state.failed | (dalpha <= threshold))
|
||||
|
||||
# Cubmin is sometimes nan, though in this case the bounds check will fail.
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import op_shardings
|
||||
from jax._src import test_util as jtu
|
||||
@ -597,6 +598,22 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
x = jnp.array([1, 2, 3])
|
||||
self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)
|
||||
|
||||
def test_array_not_hashable(self):
|
||||
x = jnp.arange(4)
|
||||
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
||||
hash(x)
|
||||
|
||||
@jax.jit
|
||||
def check_tracer_hash(x):
|
||||
self.assertIsInstance(hash(x), int)
|
||||
|
||||
if deprecations.is_accelerated('tracer-hash'):
|
||||
with self.assertRaisesRegex(TypeError, "unhashable type"):
|
||||
check_tracer_hash(x)
|
||||
else:
|
||||
with self.assertWarnsRegex(FutureWarning, "unhashable type"):
|
||||
check_tracer_hash(x)
|
||||
|
||||
def test_shape_dtype_struct_sharding_jit(self):
|
||||
mesh = jtu.create_global_mesh((8,), ('x'))
|
||||
s = jax.sharding.NamedSharding(mesh, P('x'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user