rocm_jax/jax/_src/api_util.py
Peter Hawkins a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended.
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00

322 lines
11 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from functools import partial
from typing import Any, Dict, Iterable, Tuple, Union, Optional
import numpy as np
from .. import core
from . import dtypes
from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
tree_structure, treedef_children, treedef_is_leaf)
from .tree_util import _replace_nones
from .. import linear_util as lu
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
from ..core import unit
from . import traceback_util
# Register this module with traceback_util's exclusion list (presumably so that
# frames from this file are filtered out of user-facing tracebacks -- confirm
# against traceback_util).
traceback_util.register_exclusion(__file__)
# Shadow the builtin `map` with `safe_map` for the remainder of this module.
map = safe_map
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
"""Ensure x is either an index or a tuple of indices."""
try:
return operator.index(x)
except TypeError:
return tuple(map(operator.index, x))
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
"""Convert x to a tuple of indices."""
try:
return (operator.index(x),)
except TypeError:
return tuple(map(operator.index, x))
def _ensure_str(x: str) -> str:
if not isinstance(x, str):
raise TypeError(f"argument is not a string: {x}")
return x
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
else:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Generator transformation bridging flat leaves and pytree args/kwargs.

  The flat leaves ``args_flat`` are rebuilt with ``in_tree`` into a
  ``(args, kwargs)`` pair, which is yielded to obtain the result of the wrapped
  call. That result is then flattened and the output of ``tree_flatten`` (the
  leaves together with the tree definition) is yielded.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  yield tree_flatten(result)
def apply_flat_fun(fun, io_tree, *py_args):
  """Call the flat function ``fun`` on pytree positional arguments.

  Args:
    fun: callable taking flat leaves as positional arguments and returning the
      flat leaves of the output.
    io_tree: pair ``(expected input tree, output tree)``.
    *py_args: positional arguments, flattened together with an empty kwargs
      dict.

  Returns:
    The output of ``fun`` unflattened with the output tree.

  Raises:
    TypeError: if the structure of ``(py_args, {})`` differs from the expected
      input tree.
  """
  expected_in_tree, out_tree = io_tree
  flat_args, actual_in_tree = tree_flatten((py_args, {}))
  if actual_in_tree != expected_in_tree:
    raise TypeError("Expected {}, got {}".format(expected_in_tree,
                                                 actual_in_tree))
  return tree_unflatten(out_tree, fun(*flat_args))
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Generator transformation bridging flat leaves and pytree positional args.

  The flat leaves ``args_flat`` are rebuilt with ``in_tree`` into the
  positional arguments, which are yielded alongside an empty kwargs dict to
  obtain the result of the wrapped call. That result is then flattened and the
  output of ``tree_flatten`` (leaves together with the tree definition) is
  yielded.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  yield tree_flatten(result)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call the flat function ``fun`` on a pytree of positional arguments.

  Args:
    fun: callable taking flat leaves as positional arguments and returning the
      flat leaves of the output.
    io_tree: pair ``(expected input tree, output tree)``.
    py_args: the pytree of arguments; it is flattened as-is (no kwargs dict is
      attached).

  Returns:
    The output of ``fun`` unflattened with the output tree.

  Raises:
    TypeError: if the structure of ``py_args`` differs from the expected input
      tree.
  """
  expected_in_tree, out_tree = io_tree
  flat_args, actual_in_tree = tree_flatten(py_args)
  if actual_in_tree != expected_in_tree:
    raise TypeError("Expected {}, got {}".format(expected_in_tree,
                                                 actual_in_tree))
  return tree_unflatten(out_tree, fun(*flat_args))
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  """Recover the input tree recorded by a flattening transform applied to ``fn``.

  Args:
    fn: the WrappedFun whose ``transforms`` are searched.

  Returns:
    ``(in_tree, has_kwargs)`` when exactly one entry of ``fn.transforms`` uses
    the generator underlying ``flatten_fun`` or ``flatten_fun_nokwargs``.
    ``in_tree`` is the first argument stored with that transform and
    ``has_kwargs`` is True only when the transform is ``flatten_fun``'s.
    Returns None when the unpacking below raises ``ValueError``, e.g. when
    there are zero or several matching transforms.
  """
  # This implementation relies on internal details of linear_util.py's
  # WrappedFun, but it's for the worthy cause of better user error messages.
  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
  # core.eval_jaxpr encounters a call primitive (though at that point we're just
  # round-tripping jaxprs and the user errors in question are impossible).
  # The decorated transformations are expected to be `partial` objects whose
  # single bound argument is the underlying generator function.
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
  try:
    # The trailing comma unpacks a one-element iterable: any other number of
    # matches raises ValueError, which is turned into None below.
    (in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
                              for f, args in fn.transforms if f in flat_xforms)
  except ValueError:
    return None
  else:
    return in_tree, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Flattening transformation for functions returning an ``(ans, aux)`` pair.

  The flat leaves ``args_flat`` are rebuilt with ``in_tree`` and yielded with
  an empty kwargs dict. The value sent back must be a two-element list or
  tuple; both halves are flattened separately, and the pair of leaf lists is
  yielded together with the pair of tree definitions.

  Raises:
    TypeError: if the wrapped call does not return a two-element list/tuple.
  """
  positional = tree_unflatten(in_tree, args_flat)
  pair = yield positional, {}
  is_two_element_seq = isinstance(pair, (list, tuple)) and len(pair) == 2
  if not is_two_element_seq:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(pair)} with value {repr(pair)}")
  main_out, aux_out = pair
  main_leaves, main_tree = tree_flatten(main_out)
  aux_leaves, aux_tree = tree_flatten(aux_out)
  yield (main_leaves, aux_leaves), (main_tree, aux_tree)
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
  """Split ``args`` into dynamic arguments and wrapped static arguments.

  Args:
    f: the function to partially apply; its ``__name__`` is used in the error
      message.
    dyn_argnums: an index or iterable of indices of the dynamic arguments.
    args: all positional arguments.
    require_static_args_hashable: when true, every static argument must be
      hashable and is wrapped in ``Hashable``; when false, static arguments
      are wrapped in ``Unhashable`` without any check.

  Returns:
    A pair ``(partially applied f, dyn_args)`` where ``dyn_args`` is the tuple
    of arguments selected by ``dyn_argnums``.

  Raises:
    ValueError: if hashability is required and a static argument is not
      hashable.
  """
  dyn_argnums = _ensure_index_tuple(dyn_argnums)

  def wrap_static(i, arg):
    # Static arguments are wrapped so that their (un)hashability is explicit.
    if not require_static_args_hashable:
      return Unhashable(arg)
    if is_hashable(arg):
      return Hashable(arg)
    raise ValueError(
        "Non-hashable static arguments are not supported, as this can lead "
        f"to unexpected cache-misses. Static argument (index {i}) of type "
        f"{type(arg)} for function {f.__name__} is non-hashable.")

  # Dynamic positions hold the `unit` placeholder until call time.
  fixed_args = tuple(unit if i in dyn_argnums else wrap_static(i, arg)
                     for i, arg in enumerate(args))
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any, ...], *, allow_invalid: bool):
  """Partially apply ``f`` to the static positional arguments in ``args``.

  Unlike ``argnums_partial``, the caller names the *static* positions; every
  other position of ``args`` is treated as dynamic.

  Args:
    f: the function to partially apply; its ``__name__`` is used in the error
      message.
    static_argnums: positions of ``args`` that are static.
    args: all positional arguments.
    allow_invalid: when true, entries of ``static_argnums`` that are
      ``>= len(args)`` are skipped instead of being looked up in ``args``.

  Returns:
    ``(f, args)`` unchanged when ``static_argnums`` is empty; otherwise a pair
    ``(partially applied f, dyn_args)`` where ``dyn_args`` holds the arguments
    at the non-static positions.

  Raises:
    ValueError: if a static argument is not hashable.
  """
  if not static_argnums:
    return f, args
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)
  # Dynamic positions keep the `unit` placeholder; static ones are filled below.
  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    # Use the shared helper (as argnums_partial does) so the ValueError is not
    # raised from inside an `except TypeError` block, which would attach the
    # irrelevant TypeError as chained context.
    if not is_hashable(static_arg):
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    fixed_args[i] = Hashable(static_arg)  # type: ignore
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Generator transformation that reassembles the full positional arguments.

  Every entry of ``fixed_args`` that is the ``unit`` placeholder becomes
  ``None``; any other entry is unwrapped through its ``val`` attribute. The
  dynamic arguments are then written at the positions given by
  ``dyn_argnums``, and the resulting list is yielded with ``kwargs``. The value
  sent back is yielded unchanged.
  """
  args = []
  for slot in fixed_args:
    args.append(None if slot is unit else slot.val)
  for position, value in zip(dyn_argnums, dyn_args):
    args[position] = value
  result = yield args, kwargs
  yield result
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  """Partially apply ``f`` to the static keyword arguments in ``kwargs``.

  Args:
    f: the function to partially apply; its ``__name__`` is used in the error
      message.
    static_argnames: names of the keyword arguments that are static.
    kwargs: all keyword arguments.

  Returns:
    ``(f, kwargs)`` unchanged when ``static_argnames`` is empty; otherwise a
    pair ``(partially applied f, dyn_kwargs)`` where ``dyn_kwargs`` holds the
    keyword arguments whose names are not in ``static_argnames``.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}
  fixed_kwargs: Dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      # Dynamic keywords keep the `unit` placeholder until call time.
      fixed_kwargs[k] = unit
      continue
    # Use the shared helper (as argnums_partial does) so the ValueError is not
    # raised from inside an `except TypeError` block, which would attach the
    # irrelevant TypeError as chained context.
    if not is_hashable(arg):
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {k}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
    fixed_kwargs[k] = Hashable(arg)  # type: ignore
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Generator transformation that reassembles the full keyword arguments.

  Each value in ``fixed_kwargs.val`` that is the ``unit`` placeholder becomes
  ``None``; any other value is unwrapped through its ``val`` attribute. The
  dynamic keyword arguments are then merged on top, and ``args`` is yielded
  with the merged dict. The value sent back is yielded unchanged.
  """
  kwargs = {}
  for name, wrapped in fixed_kwargs.val.items():
    kwargs[name] = None if wrapped is unit else wrapped.val
  kwargs.update(dyn_kwargs)
  result = yield args, kwargs
  yield result
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Build the per-leaf donation flags for ``args`` followed by ``kwargs``.

  Every leaf of the positional argument at index ``i`` is flagged with whether
  ``i`` is in ``donate_argnums``. Leaves of ``kwargs`` are always flagged
  ``False`` and come last.
  """
  flags = []
  for position, arg in enumerate(args):
    donated = bool(position in donate_argnums)
    flags += [donated] * tree_structure(arg).num_leaves
  flags += [False] * tree_structure(kwargs).num_leaves
  return tuple(flags)
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Re-index donated argument positions after static arguments are removed.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums, each
    reduced by the number of static_argnums smaller than it.

  Raises:
    ValueError: if an index appears in both donate_argnums and static_argnums.
  """
  if not static_argnums and not donate_argnums:
    return tuple(sorted(donate_argnums))

  static_argnums = sorted(set(static_argnums))
  donate_argnums = sorted(set(donate_argnums))

  num_static = len(static_argnums)
  skipped = 0  # number of static indices strictly below the current donation
  rebased = []
  for donated in donate_argnums:
    # Both lists are sorted, so the static cursor only ever moves forward.
    while skipped < num_static and static_argnums[skipped] < donated:
      skipped += 1
    if skipped < num_static and static_argnums[skipped] == donated:
      raise ValueError(f"`static_argnums` {static_argnums} and "
                       f"`donate_argnums` {donate_argnums} cannot intersect.")
    rebased.append(donated - skipped)
  return tuple(rebased)
def is_hashable(arg):
  """Return True if ``hash(arg)`` succeeds, False if it raises ``TypeError``."""
  try:
    hash(arg)
  except TypeError:
    return False
  else:
    return True
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expand an axis-spec tree prefix to one entry per leaf of ``treedef``.

  Args:
    name: label for the specification, used only in the error message.
    treedef: tree definition of the value the specification applies to.
    axis_tree: axis specification that must be a tree prefix of ``treedef``.
    kws: if true, the failure path splits ``treedef`` into a positional part
      and a leaf, and keeps only the first half of ``axis_tree``, so that the
      error message is about the positional arguments only.
    tupled_args: if true, the error message gets a hint about wrapping the
      specification in a tuple.

  Returns:
    A list with one entry per leaf of ``treedef``.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation
  # `proxy` stands in for None in axis_tree during the traversal and is mapped
  # back to None at the end.
  proxy = object()
  # A placeholder value with the structure of treedef, used only for its shape.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # Repeat each axis entry once per leaf of the matching subtree of `dummy`.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (f" Note that {name} that are non-trivial pytrees should always be "
               f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # Probe whether wrapping the specification in a singleton tuple would
        # have worked, and if so say so in the hint.
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (f" In particular, you're passing in a single argument which "
                   f"means that {name} might need to be wrapped in "
                   f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
def _dtype(x):
try:
return dtypes.result_type(x)
except ValueError:
return dtypes.result_type(getattr(x, 'dtype'))
def shaped_abstractify(x):
  """Return a shaped abstract value for ``x``.

  ``core.raise_to_shaped(core.get_aval(x))`` is tried first. If that raises
  ``TypeError``, a ``core.ShapedArray`` is built from ``np.shape(x)`` and
  ``_dtype(x)``, taking ``weak_type`` and ``named_shape`` from attributes of
  ``x`` when present (defaulting to ``False`` and ``{}``).
  """
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass  # Fall back to constructing the abstract value by hand.
  extras = dict(weak_type=getattr(x, 'weak_type', False),
                named_shape=getattr(x, 'named_shape', {}))
  return core.ShapedArray(np.shape(x), _dtype(x), **extras)
def api_hook(fun, tag: str):
  """Return ``fun`` unchanged; a seam for monkey-patching APIs in JAX.

  This decorator exists to make it easier to monkey-patch APIs in JAX. By
  default it does nothing (``tag`` is ignored), but it can be monkey-patched
  to do other things.
  """
  return fun