Disallow non-hashable static arguments in pmap().

* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
This commit is contained in:
Peter Hawkins 2021-07-19 13:11:38 -04:00
parent 372839863d
commit a11d957e61
6 changed files with 72 additions and 36 deletions

View File

@ -11,6 +11,23 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.22 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.21...main).
* Breaking Changes
* Static arguments to `jax.pmap` must now be hashable.
Unhashable static arguments have long been disallowed on `jax.jit`, but they
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
arguments using object identity.
This behavior is a footgun, since comparing arguments using
object identity leads to recompilation each time the object identity
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
wants to compare static arguments by object identity, they can define
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
objects in an object that has those operations with object identity
semantics. Another option is to use `functools.partial` to encapsulate the
unhashable static arguments into the function object.
* `jax.util.partial` was an accidental export that has now been removed. Use
`functools.partial` from the Python standard library instead.
## jax 0.2.21 (Sept 23, 2021)
* [GitHub
@ -19,9 +36,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.partial`, `jax.lax.partial`, and `jax.util.partial` were accidental
exports that have now been removed. Use `functools.partial` from the Python
standard library instead.
* `jax.partial`, and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`#7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error

View File

@ -870,7 +870,8 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
f"but got only {len(args)} positional arguments.")
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
f_partial, dyn_args = argnums_partial(f, argnums, args,
require_static_args_hashable=False)
for leaf in tree_leaves(dyn_args):
_check_input_dtype_grad(holomorphic, allow_int, leaf)
if not has_aux:
@ -966,7 +967,8 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
f_partial, dyn_args = argnums_partial(f, argnums, args,
require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
pushfwd = partial(_jvp, f_partial, dyn_args)
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
@ -1034,7 +1036,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
def jacfun(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = argnums_partial(f, argnums, args)
f_partial, dyn_args = argnums_partial(f, argnums, args,
require_static_args_hashable=False)
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
y, pullback = _vjp(f_partial, *dyn_args)
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
@ -1417,6 +1420,10 @@ def pmap(
raised. Each of the static arguments will be broadcasted to all devices.
Arguments that are not arrays or containers thereof must be marked as
static. Defaults to ().
Static arguments must be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and should be immutable.
devices: This is an experimental feature and the API is likely to change.
Optional, a sequence of Devices to map over. (Available devices can be
retrieved via jax.devices()). Must be given identically for each process

View File

@ -24,7 +24,7 @@ from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
tree_structure, treedef_children, treedef_is_leaf)
from .tree_util import _replace_nones
from .. import linear_util as lu
from .util import safe_map, WrapHashably, WrapKwArgs, Hashable
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
from ..core import unit
from . import traceback_util
@ -118,13 +118,23 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
def argnums_partial(f, dyn_argnums, args):
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
dyn_argnums = _ensure_index_tuple(dyn_argnums)
fixed_args = tuple(unit if i in dyn_argnums else wrap_hashably(arg)
for i, arg in enumerate(args))
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
fixed_args = [unit] * len(args)
for i, arg in enumerate(args):
if i in dyn_argnums: continue
if require_static_args_hashable:
if not is_hashable(arg):
raise ValueError(
"Non-hashable static arguments are not supported, as this can lead "
f"to unexpected cache-misses. Static argument (index {i}) of type "
f"{type(arg)} for function {f.__name__} is non-hashable.")
fixed_args[i] = Hashable(arg)
else:
fixed_args[i] = Unhashable(arg)
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
args: Tuple[Any], *, allow_invalid: bool):
@ -163,14 +173,6 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
yield ans
def argnames_partial(f, dyn_argnames, kwargs):
dyn_argnames = _ensure_str_tuple(dyn_argnames)
fixed_kwargs = tuple((k, unit if k in dyn_argnames else wrap_hashably(v))
for k, v in kwargs.items())
dyn_kwargs = {k: kwargs[k] for k in dyn_argnames}
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
kwargs: Dict[str, Any]):
if not static_argnames:
@ -247,13 +249,14 @@ def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
j += 1
return tuple(out)
def wrap_hashably(arg):
def is_hashable(arg):
try:
hash(arg)
return True
except TypeError:
return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
else:
return Hashable(arg)
return False
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
# given an axis spec tree axis_tree (a pytree with integers and Nones at the

View File

@ -24,9 +24,8 @@ from jax import linear_util as lu
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
register_pytree_node_class)
from jax._src.util import cache, safe_zip, safe_map, split_list
from jax._src.api_util import (flatten_fun_nokwargs, argnums_partial,
wrap_hashably)
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
from jax.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
@ -206,7 +205,8 @@ class custom_jvp(Generic[ReturnValue]):
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
for i, x in enumerate(args))
diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
jvp = _add_args(lu.wrap_init(self.jvp), static_args)
else:
@ -220,7 +220,7 @@ class custom_jvp(Generic[ReturnValue]):
return tree_unflatten(out_tree, out_flat)
def _add_args(f, extra_args):
return _add_args_(f, tuple(map(wrap_hashably, extra_args)))
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))
@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
@ -505,9 +505,11 @@ class custom_vjp(Generic[ReturnValue]):
for i in self.nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums = set(self.nondiff_argnums)
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args,
require_static_args_hashable=False)
static_args = [args[i] for i in self.nondiff_argnums]
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
require_static_args_hashable=False)
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
else:
f_, dyn_args = lu.wrap_init(self.fun), args

View File

@ -218,17 +218,14 @@ def prod(xs):
out *= x
return out
class WrapHashably:
class Unhashable:
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __hash__(self):
return id(self.val)
def __eq__(self, other):
return self.val is other.val
return self.val == other.val
class Hashable:
__slots__ = ["val"]

View File

@ -1759,6 +1759,16 @@ class PythonPmapTest(jtu.JaxTestCase):
A().my_func_pmap(jnp.asarray([3] * jax.device_count()))
def test_pmap_error_on_non_hashable_static_argument(self):
f = lambda x, y: x + 3
pmapped_f = self.pmap(f, static_broadcasted_argnums=(1,))
inputs = np.asarray([1] * jax.device_count())
with self.assertRaisesRegex(
ValueError, "Non-hashable static arguments are not supported.*"):
pmapped_f(inputs, np.asarray(1))
class CppPmapTest(PythonPmapTest):