mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap. * Delete wrap_hashably(). * In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that hashing is not something we intended. * Delete argnames_partial, which appears unused.
This commit is contained in:
parent
372839863d
commit
a11d957e61
23
CHANGELOG.md
23
CHANGELOG.md
@ -11,6 +11,23 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.2.22 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.21...main).
|
||||
* Breaking Changes
|
||||
* Static arguments to `jax.pmap` must now be hashable.
|
||||
|
||||
Unhashable static arguments have long been disallowed on `jax.jit`, but they
|
||||
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
|
||||
arguments using object identity.
|
||||
|
||||
This behavior is a footgun, since comparing arguments using
|
||||
object identity leads to recompilation each time the object identity
|
||||
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
|
||||
wants to compare static arguments by object identity, they can define
|
||||
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
|
||||
objects in an object that has those operations with object identity
|
||||
semantics. Another option is to use `functools.partial` to encapsulate the
|
||||
unhashable static arguments into the function object.
|
||||
* `jax.util.partial` was an accidental export that has now been removed. Use
|
||||
`functools.partial` from the Python standard library instead.
|
||||
|
||||
## jax 0.2.21 (Sept 23, 2021)
|
||||
* [GitHub
|
||||
@ -19,9 +36,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* `jax.api` has been removed. Functions that were available as `jax.api.*`
|
||||
were aliases for functions in `jax.*`; please use the functions in
|
||||
`jax.*` instead.
|
||||
* `jax.partial`, `jax.lax.partial`, and `jax.util.partial` were accidental
|
||||
exports that have now been removed. Use `functools.partial` from the Python
|
||||
standard library instead.
|
||||
* `jax.partial` and `jax.lax.partial` were accidental exports that have now
|
||||
been removed. Use `functools.partial` from the Python standard library
|
||||
instead.
|
||||
* Boolean scalar indices now raise a `TypeError`; previously this silently
|
||||
returned wrong results ({jax-issue}`#7925`).
|
||||
* Many more `jax.numpy` functions now require array-like inputs, and will error
|
||||
|
@ -870,7 +870,8 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
f"but got only {len(args)} positional arguments.")
|
||||
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
for leaf in tree_leaves(dyn_args):
|
||||
_check_input_dtype_grad(holomorphic, allow_int, leaf)
|
||||
if not has_aux:
|
||||
@ -966,7 +967,8 @@ def jacfwd(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
|
||||
pushfwd = partial(_jvp, f_partial, dyn_args)
|
||||
y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
|
||||
@ -1034,7 +1036,8 @@ def jacrev(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
|
||||
def jacfun(*args, **kwargs):
|
||||
f = lu.wrap_init(fun, kwargs)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args)
|
||||
f_partial, dyn_args = argnums_partial(f, argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
|
||||
y, pullback = _vjp(f_partial, *dyn_args)
|
||||
tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
|
||||
@ -1417,6 +1420,10 @@ def pmap(
|
||||
raised. Each of the static arguments will be broadcasted to all devices.
|
||||
Arguments that are not arrays or containers thereof must be marked as
|
||||
static. Defaults to ().
|
||||
|
||||
Static arguments must be hashable, meaning both ``__hash__`` and
|
||||
``__eq__`` are implemented, and should be immutable.
|
||||
|
||||
devices: This is an experimental feature and the API is likely to change.
|
||||
Optional, a sequence of Devices to map over. (Available devices can be
|
||||
retrieved via jax.devices()). Must be given identically for each process
|
||||
|
@ -24,7 +24,7 @@ from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
|
||||
tree_structure, treedef_children, treedef_is_leaf)
|
||||
from .tree_util import _replace_nones
|
||||
from .. import linear_util as lu
|
||||
from .util import safe_map, WrapHashably, WrapKwArgs, Hashable
|
||||
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
|
||||
from ..core import unit
|
||||
|
||||
from . import traceback_util
|
||||
@ -118,13 +118,23 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
|
||||
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
|
||||
|
||||
|
||||
def argnums_partial(f, dyn_argnums, args):
|
||||
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
|
||||
dyn_argnums = _ensure_index_tuple(dyn_argnums)
|
||||
fixed_args = tuple(unit if i in dyn_argnums else wrap_hashably(arg)
|
||||
for i, arg in enumerate(args))
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
|
||||
fixed_args = [unit] * len(args)
|
||||
for i, arg in enumerate(args):
|
||||
if i in dyn_argnums: continue
|
||||
if require_static_args_hashable:
|
||||
if not is_hashable(arg):
|
||||
raise ValueError(
|
||||
"Non-hashable static arguments are not supported, as this can lead "
|
||||
f"to unexpected cache-misses. Static argument (index {i}) of type "
|
||||
f"{type(arg)} for function {f.__name__} is non-hashable.")
|
||||
fixed_args[i] = Hashable(arg)
|
||||
else:
|
||||
fixed_args[i] = Unhashable(arg)
|
||||
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
||||
|
||||
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
|
||||
args: Tuple[Any], *, allow_invalid: bool):
|
||||
@ -163,14 +173,6 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
yield ans
|
||||
|
||||
|
||||
def argnames_partial(f, dyn_argnames, kwargs):
|
||||
dyn_argnames = _ensure_str_tuple(dyn_argnames)
|
||||
fixed_kwargs = tuple((k, unit if k in dyn_argnames else wrap_hashably(v))
|
||||
for k, v in kwargs.items())
|
||||
dyn_kwargs = {k: kwargs[k] for k in dyn_argnames}
|
||||
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
||||
|
||||
|
||||
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
|
||||
kwargs: Dict[str, Any]):
|
||||
if not static_argnames:
|
||||
@ -247,13 +249,14 @@ def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
|
||||
j += 1
|
||||
return tuple(out)
|
||||
|
||||
def wrap_hashably(arg):
|
||||
|
||||
def is_hashable(arg):
|
||||
try:
|
||||
hash(arg)
|
||||
return True
|
||||
except TypeError:
|
||||
return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
|
||||
else:
|
||||
return Hashable(arg)
|
||||
return False
|
||||
|
||||
|
||||
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
|
||||
# given an axis spec tree axis_tree (a pytree with integers and Nones at the
|
||||
|
@ -24,9 +24,8 @@ from jax import linear_util as lu
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class)
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list
|
||||
from jax._src.api_util import (flatten_fun_nokwargs, argnums_partial,
|
||||
wrap_hashably)
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
|
||||
from jax.core import raise_to_shaped
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
@ -206,7 +205,8 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
|
||||
for i, x in enumerate(args))
|
||||
diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
|
||||
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
|
||||
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
static_args = [args[i] for i in self.nondiff_argnums]
|
||||
jvp = _add_args(lu.wrap_init(self.jvp), static_args)
|
||||
else:
|
||||
@ -220,7 +220,7 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
def _add_args(f, extra_args):
|
||||
return _add_args_(f, tuple(map(wrap_hashably, extra_args)))
|
||||
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))
|
||||
|
||||
@lu.transformation
|
||||
def _add_args_(extra_args, *args, **kwargs):
|
||||
@ -505,9 +505,11 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
for i in self.nondiff_argnums: _check_for_tracers(args[i])
|
||||
nondiff_argnums = set(self.nondiff_argnums)
|
||||
dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
|
||||
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args)
|
||||
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
static_args = [args[i] for i in self.nondiff_argnums]
|
||||
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
|
||||
fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
|
||||
require_static_args_hashable=False)
|
||||
bwd = _add_args(lu.wrap_init(self.bwd), static_args)
|
||||
else:
|
||||
f_, dyn_args = lu.wrap_init(self.fun), args
|
||||
|
@ -218,17 +218,14 @@ def prod(xs):
|
||||
out *= x
|
||||
return out
|
||||
|
||||
class WrapHashably:
|
||||
class Unhashable:
|
||||
__slots__ = ["val"]
|
||||
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def __hash__(self):
|
||||
return id(self.val)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.val is other.val
|
||||
return self.val == other.val
|
||||
|
||||
class Hashable:
|
||||
__slots__ = ["val"]
|
||||
|
@ -1759,6 +1759,16 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
A().my_func_pmap(jnp.asarray([3] * jax.device_count()))
|
||||
|
||||
def test_pmap_error_on_non_hashable_static_argument(self):
|
||||
f = lambda x, y: x + 3
|
||||
pmapped_f = self.pmap(f, static_broadcasted_argnums=(1,))
|
||||
|
||||
inputs = np.asarray([1] * jax.device_count())
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Non-hashable static arguments are not supported.*"):
|
||||
pmapped_f(inputs, np.asarray(1))
|
||||
|
||||
|
||||
|
||||
class CppPmapTest(PythonPmapTest):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user