mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Emit a warning on non-hashble static arguments in jax.jit.
The message looks like, e.g.: Static argument (index 1) of type <class 'numpy.ndarray'> for function f is non-hashable. As this can lead to unexpected cache-misses, it will raise an error in a near future.
This commit is contained in:
parent
9ba28d2634
commit
98ee69ba1c
25
jax/api.py
25
jax/api.py
@ -41,8 +41,9 @@ from . import ad_util
|
||||
from . import dtypes
|
||||
from .core import eval_jaxpr
|
||||
from .api_util import (wraps, flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
|
||||
flatten_fun_nokwargs2, argnums_partial, flatten_axes,
|
||||
donation_vector, rebase_donate_argnums)
|
||||
flatten_fun_nokwargs2, argnums_partial,
|
||||
argnums_partial_except, flatten_axes, donation_vector,
|
||||
rebase_donate_argnums)
|
||||
from .traceback_util import api_boundary
|
||||
from .tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
tree_transpose, tree_leaves, tree_multimap,
|
||||
@ -130,11 +131,12 @@ def jit(fun: Callable[..., T],
|
||||
arguments to treat as static (compile-time constant). Operations that only
|
||||
depend on static arguments will be constant-folded in Python (during
|
||||
tracing), and so the corresponding argument values can be any Python
|
||||
object. Calling the jitted function with different values for these
|
||||
constants will trigger recompilation. If the jitted function is called
|
||||
with fewer positional arguments than indicated by ``static_argnums`` then
|
||||
an error is raised. Arguments that are not arrays or containers thereof
|
||||
must be marked as static. Defaults to ().
|
||||
object. Static arguments should be hashable, meaning both ``__hash__``
|
||||
and ``__eq__` are implemented. Calling the jitted function with different
|
||||
values for these constants will trigger recompilation. If the jitted
|
||||
function is called with fewer positional arguments than indicated by
|
||||
``static_argnums`` then an error is raised. Arguments that are not arrays
|
||||
or containers thereof must be marked as static. Defaults to ().
|
||||
device: This is an experimental feature and the API is likely to change.
|
||||
Optional, the Device the jitted function will run on. (Available devices
|
||||
can be retrieved via :py:func:`jax.devices`.) The default is inherited from
|
||||
@ -198,8 +200,7 @@ def _python_jit(
|
||||
raise ValueError(msg.format(static_argnums, donate_argnums, len(args)))
|
||||
f = lu.wrap_init(fun)
|
||||
if static_argnums:
|
||||
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
||||
f, dyn_args = argnums_partial(f, dyn_argnums, args)
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args)
|
||||
else:
|
||||
dyn_args = args
|
||||
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
@ -258,8 +259,7 @@ def _cpp_jit(
|
||||
# work/code that is redundant between C++ and Python. We can try that later.
|
||||
f = lu.wrap_init(fun)
|
||||
if static_argnums:
|
||||
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
||||
f, dyn_args = argnums_partial(f, dyn_argnums, args)
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args)
|
||||
else:
|
||||
dyn_args = args
|
||||
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
@ -623,8 +623,7 @@ def _xla_computation(
|
||||
|
||||
f = lu.wrap_init(fun)
|
||||
if static_argnums:
|
||||
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
||||
f, dyn_args = argnums_partial(f, dyn_argnums, args)
|
||||
f, dyn_args = argnums_partial_except(f, static_argnums, args)
|
||||
else:
|
||||
dyn_args = args
|
||||
args_flat, in_tree = tree_flatten((dyn_args, kwargs))
|
||||
|
@ -12,6 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from absl import logging
|
||||
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap, _replace_nones,
|
||||
tree_structure)
|
||||
@ -19,8 +22,6 @@ from . import linear_util as lu
|
||||
from .util import safe_map, curry, WrapHashably, Hashable
|
||||
from .core import unit
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
map = safe_map
|
||||
|
||||
|
||||
@ -48,7 +49,7 @@ def apply_flat_fun(fun, io_tree, *py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
args, in_tree = tree_flatten((py_args, {}))
|
||||
if in_tree != in_tree_expected:
|
||||
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
|
||||
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
|
||||
ans = fun(*args)
|
||||
return tree_unflatten(out_tree, ans)
|
||||
|
||||
@ -62,7 +63,7 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
|
||||
in_tree_expected, out_tree = io_tree
|
||||
args, in_tree = tree_flatten(py_args)
|
||||
if in_tree != in_tree_expected:
|
||||
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
|
||||
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
|
||||
ans = fun(*args)
|
||||
return tree_unflatten(out_tree, ans)
|
||||
|
||||
@ -74,6 +75,7 @@ def flatten_fun_nokwargs2(in_tree, *args_flat):
|
||||
aux_flat, aux_tree = tree_flatten(aux)
|
||||
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
|
||||
|
||||
|
||||
def argnums_partial(f, dyn_argnums, args):
|
||||
if isinstance(dyn_argnums, int):
|
||||
dyn_argnums = (dyn_argnums,)
|
||||
@ -84,6 +86,32 @@ def argnums_partial(f, dyn_argnums, args):
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
|
||||
|
||||
|
||||
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
|
||||
args: Tuple[Any]):
|
||||
"""Version of ``argnums_partial`` that checks hashability of static_argnums."""
|
||||
dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
|
||||
dyn_args = tuple(args[i] for i in dyn_argnums)
|
||||
|
||||
fixed_args = [unit] * len(args) # type: ignore
|
||||
for i in static_argnums:
|
||||
static_arg = args[i]
|
||||
try:
|
||||
hash(static_arg)
|
||||
except TypeError:
|
||||
logging.warning(
|
||||
"Static argument (index %s) of type %s for function %s is "
|
||||
"non-hashable. As this can lead to unexpected cache-misses, it "
|
||||
"will raise an error in a near future.", i, type(static_arg),
|
||||
f.__name__)
|
||||
# e.g. ndarrays, DeviceArrays
|
||||
fixed_args[i] = WrapHashably(static_arg) # type: ignore
|
||||
else:
|
||||
fixed_args[i] = Hashable(static_arg) # type: ignore
|
||||
|
||||
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
||||
|
||||
|
||||
@lu.transformation
|
||||
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
args = [None if arg is unit else arg.val for arg in fixed_args]
|
||||
|
Loading…
x
Reference in New Issue
Block a user