diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c32b6604..42d705f34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,23 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.22 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.21...main). +* Breaking Changes + * Static arguments to `jax.pmap` must now be hashable. + + Unhashable static arguments have long been disallowed on `jax.jit`, but they + were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static + arguments using object identity. + + This behavior is a footgun, since comparing arguments using + object identity leads to recompilation each time the object identity + changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap` + wants to compare static arguments by object identity, they can define + `__hash__` and `__eq__` methods on their objects that do that, or wrap their + objects in an object that has those operations with object identity + semantics. Another option is to use `functools.partial` to encapsulate the + unhashable static arguments into the function object. + * `jax.util.partial` was an accidental export that has now been removed. Use + `functools.partial` from the Python standard library instead. ## jax 0.2.21 (Sept 23, 2021) * [GitHub @@ -19,9 +36,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in `jax.*` instead. - * `jax.partial`, `jax.lax.partial`, and `jax.util.partial` were accidental - exports that have now been removed. Use `functools.partial` from the Python - standard library instead. + * `jax.partial` and `jax.lax.partial` were accidental exports that have now + been removed. Use `functools.partial` from the Python standard library + instead. 
* Boolean scalar indices now raise a `TypeError`; previously this silently returned wrong results ({jax-issue}`#7925`). * Many more `jax.numpy` functions now require array-like inputs, and will error diff --git a/jax/_src/api.py b/jax/_src/api.py index 596bf9829..a18c5829f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -870,7 +870,8 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, f"but got only {len(args)} positional arguments.") f = lu.wrap_init(fun, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args) + f_partial, dyn_args = argnums_partial(f, argnums, args, + require_static_args_hashable=False) for leaf in tree_leaves(dyn_args): _check_input_dtype_grad(holomorphic, allow_int, leaf) if not has_aux: @@ -966,7 +967,8 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0, def jacfun(*args, **kwargs): f = lu.wrap_init(fun, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args) + f_partial, dyn_args = argnums_partial(f, argnums, args, + require_static_args_hashable=False) tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args) pushfwd = partial(_jvp, f_partial, dyn_args) y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args)) @@ -1034,7 +1036,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0, def jacfun(*args, **kwargs): f = lu.wrap_init(fun, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args) + f_partial, dyn_args = argnums_partial(f, argnums, args, + require_static_args_hashable=False) tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args) y, pullback = _vjp(f_partial, *dyn_args) tree_map(partial(_check_output_dtype_jacrev, holomorphic), y) @@ -1417,6 +1420,10 @@ def pmap( raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). 
+ + Static arguments must be hashable, meaning both ``__hash__`` and + ``__eq__`` are implemented, and should be immutable. + devices: This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). Must be given identically for each process diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 69b58fac6..166f80fcf 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -24,7 +24,7 @@ from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure, treedef_children, treedef_is_leaf) from .tree_util import _replace_nones from .. import linear_util as lu -from .util import safe_map, WrapHashably, WrapKwArgs, Hashable +from .util import safe_map, WrapKwArgs, Hashable, Unhashable from ..core import unit from . import traceback_util @@ -118,13 +118,23 @@ def flatten_fun_nokwargs2(in_tree, *args_flat): yield (ans_flat, aux_flat), (ans_tree, aux_tree) -def argnums_partial(f, dyn_argnums, args): +def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) - fixed_args = tuple(unit if i in dyn_argnums else wrap_hashably(arg) - for i, arg in enumerate(args)) - dyn_args = tuple(args[i] for i in dyn_argnums) - return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args + fixed_args = [unit] * len(args) + for i, arg in enumerate(args): + if i in dyn_argnums: continue + if require_static_args_hashable: + if not is_hashable(arg): + raise ValueError( + "Non-hashable static arguments are not supported, as this can lead " + f"to unexpected cache-misses. 
Static argument (index {i}) of type " + f"{type(arg)} for function {f.__name__} is non-hashable.") + fixed_args[i] = Hashable(arg) + else: + fixed_args[i] = Unhashable(arg) + dyn_args = tuple(args[i] for i in dyn_argnums) + return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...], args: Tuple[Any], *, allow_invalid: bool): @@ -163,14 +173,6 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): yield ans -def argnames_partial(f, dyn_argnames, kwargs): - dyn_argnames = _ensure_str_tuple(dyn_argnames) - fixed_kwargs = tuple((k, unit if k in dyn_argnames else wrap_hashably(v)) - for k, v in kwargs.items()) - dyn_kwargs = {k: kwargs[k] for k in dyn_argnames} - return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs - - def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...], kwargs: Dict[str, Any]): if not static_argnames: @@ -247,13 +249,14 @@ def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]: j += 1 return tuple(out) -def wrap_hashably(arg): + +def is_hashable(arg): try: hash(arg) + return True except TypeError: - return WrapHashably(arg) # e.g. 
ndarrays, DeviceArrays - else: - return Hashable(arg) + return False + def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # given an axis spec tree axis_tree (a pytree with integers and Nones at the diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 9f05881d8..71e19467d 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -24,9 +24,8 @@ from jax import linear_util as lu from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_multimap, treedef_is_leaf, treedef_tuple, register_pytree_node_class) -from jax._src.util import cache, safe_zip, safe_map, split_list -from jax._src.api_util import (flatten_fun_nokwargs, argnums_partial, - wrap_hashably) +from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable +from jax._src.api_util import flatten_fun_nokwargs, argnums_partial from jax.core import raise_to_shaped from jax.errors import UnexpectedTracerError from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p @@ -206,7 +205,8 @@ class custom_jvp(Generic[ReturnValue]): args = tuple(_stop_gradient(x) if i in nondiff_argnums else x for i, x in enumerate(args)) diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] - f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args) + f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args, + require_static_args_hashable=False) static_args = [args[i] for i in self.nondiff_argnums] jvp = _add_args(lu.wrap_init(self.jvp), static_args) else: @@ -220,7 +220,7 @@ class custom_jvp(Generic[ReturnValue]): return tree_unflatten(out_tree, out_flat) def _add_args(f, extra_args): - return _add_args_(f, tuple(map(wrap_hashably, extra_args))) + return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args)) @lu.transformation def _add_args_(extra_args, *args, **kwargs): @@ -505,9 +505,11 @@ class custom_vjp(Generic[ReturnValue]): for i in 
self.nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums = set(self.nondiff_argnums) dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] - f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args) + f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args, + require_static_args_hashable=False) static_args = [args[i] for i in self.nondiff_argnums] - fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args) + fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args, + require_static_args_hashable=False) bwd = _add_args(lu.wrap_init(self.bwd), static_args) else: f_, dyn_args = lu.wrap_init(self.fun), args diff --git a/jax/_src/util.py b/jax/_src/util.py index 6ac3c3933..4200275f5 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -218,17 +218,14 @@ def prod(xs): out *= x return out -class WrapHashably: +class Unhashable: __slots__ = ["val"] def __init__(self, val): self.val = val - def __hash__(self): - return id(self.val) - def __eq__(self, other): - return self.val is other.val + return self.val == other.val class Hashable: __slots__ = ["val"] diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 456a95bd6..3e12e8c66 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1759,6 +1759,16 @@ class PythonPmapTest(jtu.JaxTestCase): A().my_func_pmap(jnp.asarray([3] * jax.device_count())) + def test_pmap_error_on_non_hashable_static_argument(self): + f = lambda x, y: x + 3 + pmapped_f = self.pmap(f, static_broadcasted_argnums=(1,)) + + inputs = np.asarray([1] * jax.device_count()) + with self.assertRaisesRegex( + ValueError, "Non-hashable static arguments are not supported.*"): + pmapped_f(inputs, np.asarray(1)) + + class CppPmapTest(PythonPmapTest):