Emit a warning on non-hashable static arguments in jax.jit.

The message looks like, e.g.:

Static argument (index 1) of type <class 'numpy.ndarray'> for function f is non-hashable. As this can lead to unexpected cache-misses, it will raise an error in a near future.
This commit is contained in:
Jean-Baptiste Lespiau 2020-10-20 22:22:33 +02:00 committed by Jean-Baptiste
parent 9ba28d2634
commit 98ee69ba1c
2 changed files with 44 additions and 17 deletions

View File

@ -41,8 +41,9 @@ from . import ad_util
from . import dtypes
from .core import eval_jaxpr
from .api_util import (wraps, flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial, flatten_axes,
donation_vector, rebase_donate_argnums)
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums)
from .traceback_util import api_boundary
from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose, tree_leaves, tree_multimap,
@ -130,11 +131,12 @@ def jit(fun: Callable[..., T],
arguments to treat as static (compile-time constant). Operations that only
depend on static arguments will be constant-folded in Python (during
tracing), and so the corresponding argument values can be any Python
object. Calling the jitted function with different values for these
constants will trigger recompilation. If the jitted function is called
with fewer positional arguments than indicated by ``static_argnums`` then
an error is raised. Arguments that are not arrays or containers thereof
must be marked as static. Defaults to ().
object. Static arguments should be hashable, meaning both ``__hash__``
and ``__eq__`` are implemented. Calling the jitted function with different
values for these constants will trigger recompilation. If the jitted
function is called with fewer positional arguments than indicated by
``static_argnums`` then an error is raised. Arguments that are not arrays
or containers thereof must be marked as static. Defaults to ().
device: This is an experimental feature and the API is likely to change.
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited from
@ -198,8 +200,7 @@ def _python_jit(
raise ValueError(msg.format(static_argnums, donate_argnums, len(args)))
f = lu.wrap_init(fun)
if static_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f, dyn_args = argnums_partial(f, dyn_argnums, args)
f, dyn_args = argnums_partial_except(f, static_argnums, args)
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
@ -258,8 +259,7 @@ def _cpp_jit(
# work/code that is redundant between C++ and Python. We can try that later.
f = lu.wrap_init(fun)
if static_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f, dyn_args = argnums_partial(f, dyn_argnums, args)
f, dyn_args = argnums_partial_except(f, static_argnums, args)
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
@ -623,8 +623,7 @@ def _xla_computation(
f = lu.wrap_init(fun)
if static_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f, dyn_args = argnums_partial(f, dyn_argnums, args)
f, dyn_args = argnums_partial_except(f, static_argnums, args)
else:
dyn_args = args
args_flat, in_tree = tree_flatten((dyn_args, kwargs))

View File

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple
from absl import logging
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap, _replace_nones,
tree_structure)
@ -19,8 +22,6 @@ from . import linear_util as lu
from .util import safe_map, curry, WrapHashably, Hashable
from .core import unit
from typing import Tuple
map = safe_map
@ -48,7 +49,7 @@ def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@ -62,7 +63,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@ -74,6 +75,7 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
def argnums_partial(f, dyn_argnums, args):
if isinstance(dyn_argnums, int):
dyn_argnums = (dyn_argnums,)
@ -84,6 +86,32 @@ def argnums_partial(f, dyn_argnums, args):
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any, ...]):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Every positional index of ``args`` that is not listed in ``static_argnums``
  is treated as dynamic. Each static argument is probed with ``hash()``:
  hashable values are wrapped in ``Hashable``; values whose ``hash()`` raises
  ``TypeError`` are wrapped in ``WrapHashably`` and a warning is logged.

  Args:
    f: the wrapped function to partially apply.
    static_argnums: positional indices of ``args`` to hold fixed.
    args: all positional arguments of the call (any length).

  Returns:
    A pair ``(partial_f, dyn_args)`` where ``partial_f`` is the result of
    ``_argnums_partial`` applied to ``f`` with the static arguments fixed, and
    ``dyn_args`` is the tuple of remaining (dynamic) arguments, in their
    original order.
  """
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  # Dynamic slots keep the ``unit`` placeholder; only static slots are filled.
  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      # Only warn for now; the message announces that this will become an
      # error later.
      logging.warning(
          "Static argument (index %s) of type %s for function %s is "
          "non-hashable. As this can lead to unexpected cache-misses, it "
          "will raise an error in a near future.", i, type(static_arg),
          f.__name__)
      # e.g. ndarrays, DeviceArrays
      fixed_args[i] = WrapHashably(static_arg)  # type: ignore
    else:
      fixed_args[i] = Hashable(static_arg)  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
args = [None if arg is unit else arg.val for arg in fixed_args]