mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00

* Don't wrap static arguments in hashable wrappers in pmap. * Delete wrap_hashably(). * In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended. * Delete argnames_partial, which appears unused.
322 lines
11 KiB
Python
322 lines
11 KiB
Python
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import operator
|
|
from functools import partial
|
|
from typing import Any, Dict, Iterable, Tuple, Union, Optional
|
|
|
|
import numpy as np
|
|
|
|
from .. import core
|
|
from . import dtypes
|
|
from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
|
|
tree_structure, treedef_children, treedef_is_leaf)
|
|
from .tree_util import _replace_nones
|
|
from .. import linear_util as lu
|
|
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
|
|
from ..core import unit
|
|
|
|
from . import traceback_util
|
|
# Register this module's file with traceback_util.
# NOTE(review): presumably this hides frames from this file in the filtered
# tracebacks shown to users — confirm in traceback_util.register_exclusion.
traceback_util.register_exclusion(__file__)
|
|
|
|
# Shadow the builtin `map` with `safe_map` from .util for the rest of this
# module. NOTE(review): safe_map is assumed to check that all iterables have
# equal length and to return a list — confirm in util.py.
map = safe_map
|
|
|
|
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
|
|
"""Ensure x is either an index or a tuple of indices."""
|
|
try:
|
|
return operator.index(x)
|
|
except TypeError:
|
|
return tuple(map(operator.index, x))
|
|
|
|
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
|
|
"""Convert x to a tuple of indices."""
|
|
try:
|
|
return (operator.index(x),)
|
|
except TypeError:
|
|
return tuple(map(operator.index, x))
|
|
|
|
def _ensure_str(x: str) -> str:
|
|
if not isinstance(x, str):
|
|
raise TypeError(f"argument is not a string: {x}")
|
|
return x
|
|
|
|
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
|
|
"""Convert x to a tuple of strings."""
|
|
if isinstance(x, str):
|
|
return (x,)
|
|
else:
|
|
return tuple(map(_ensure_str, x))
|
|
|
|
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Unflatten ``(args, kwargs)`` from leaves and flatten the wrapped output.

  ``in_tree`` must describe a pair ``(positional_args, keyword_args)``. The
  first ``yield`` passes those to the wrapped function and receives its
  result. The second ``yield`` produces the result's leaves, with its treedef
  as the auxiliary output.
  """
  unflattened = tree_unflatten(in_tree, args_flat)
  py_args, py_kwargs = unflattened
  result = yield py_args, py_kwargs
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
|
|
|
|
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a flat function ``fun`` on pytree positional arguments.

  Args:
    fun: a callable taking flat leaves and returning a flat sequence of leaves.
    io_tree: a pair ``(in_tree_expected, out_tree)`` of treedefs.
    *py_args: positional pytree arguments. They are flattened as
      ``(py_args, {})`` (i.e. with no keyword arguments).

  Returns:
    ``fun``'s flat output unflattened with ``out_tree``.

  Raises:
    TypeError: if the flattened input structure differs from
      ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, actual_tree = tree_flatten((py_args, {}))
  if in_tree_expected != actual_tree:
    raise TypeError(f"Expected {in_tree_expected}, got {actual_tree}")
  flat_out = fun(*leaves)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Like ``flatten_fun``, but ``in_tree`` describes positional args only.

  The wrapped function is always called with an empty keyword-argument dict.
  Its output is flattened, and the treedef is returned as auxiliary data.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
|
|
|
|
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a flat function ``fun`` on the pytree ``py_args``.

  Args:
    fun: a callable taking flat leaves and returning a flat sequence of leaves.
    io_tree: a pair ``(in_tree_expected, out_tree)`` of treedefs.
    py_args: a pytree of positional arguments. It is flattened as-is, with no
      keyword-argument component.

  Returns:
    ``fun``'s flat output unflattened with ``out_tree``.

  Raises:
    TypeError: if ``py_args``'s structure differs from ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, actual_tree = tree_flatten(py_args)
  if in_tree_expected != actual_tree:
    raise TypeError(f"Expected {in_tree_expected}, got {actual_tree}")
  flat_out = fun(*leaves)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  """Recover the input treedef recorded by ``flatten_fun``-style transforms.

  Scans ``fn.transforms`` for applications of ``flatten_fun`` or
  ``flatten_fun_nokwargs``.

  Returns:
    ``(in_tree, has_kwargs)`` when exactly one such transform is found.
    ``has_kwargs`` is True when that transform was ``flatten_fun``. Returns
    None when there is no such transform or more than one.

  This relies on internal details of linear_util.py's WrappedFun, but it's for
  the worthy cause of better user error messages. It can return None e.g. when
  core.eval_jaxpr encounters a call primitive. At that point we're just
  round-tripping jaxprs, and the user errors in question are impossible.
  """
  # The decorated transformations are expected to be functools.partial objects
  # wrapping a single generator function; that generator is what appears in
  # fn.transforms.
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  with_kwargs_gen = flatten_fun.args[0]
  flat_xforms = {with_kwargs_gen, flatten_fun_nokwargs.args[0]}

  # Lazily yields one (in_tree, has_kwargs) candidate per matching transform.
  candidates = ((xform_args[0], xform is with_kwargs_gen)
                for xform, xform_args in fn.transforms if xform in flat_xforms)
  try:
    # Single-target unpacking raises ValueError on zero or multiple matches.
    (found,) = candidates
  except ValueError:
    return None
  return found
|
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Like ``flatten_fun_nokwargs`` for functions that return ``(ans, aux)``.

  The wrapped function must return a two-element list or tuple. Both parts are
  flattened separately: the leaves ``(ans_flat, aux_flat)`` are the main
  output, and the treedefs ``(ans_tree, aux_tree)`` are the auxiliary data.

  Raises:
    TypeError: if the wrapped function's result is not a 2-element list/tuple.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  primary, aux = result
  primary_leaves, primary_tree = tree_flatten(primary)
  aux_leaves, aux_tree = tree_flatten(aux)
  yield (primary_leaves, aux_leaves), (primary_tree, aux_tree)
|
|
|
|
|
|
def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
  """Partially apply ``f`` to every argument not selected by ``dyn_argnums``.

  Args:
    f: the ``lu.WrappedFun`` to transform.
    dyn_argnums: an int or an iterable of ints selecting the positional
      arguments that stay dynamic. Negative values count from the end of
      ``args``, as in ordinary Python indexing.
    args: the full tuple of positional arguments.
    require_static_args_hashable: if True, every static argument must be
      hashable and is wrapped in ``Hashable``. If False, static arguments are
      wrapped in ``Unhashable`` without any check.

  Returns:
    A pair ``(f_partial, dyn_args)``. ``f_partial`` takes only the dynamic
    arguments. ``dyn_args`` is the tuple of those arguments, in
    ``dyn_argnums`` order.

  Raises:
    ValueError: if ``require_static_args_hashable`` and a static argument is
      not hashable.
  """
  num_args = len(args)
  # Resolve in-range negative indices so that the membership test below agrees
  # with the positions actually read from ``args``. Otherwise an argument
  # selected as e.g. -1 would also be treated as static (and hash-checked).
  # Out-of-range indices are left untouched, so ``args[i]`` still raises
  # IndexError for them as before.
  dyn_argnums = tuple(i + num_args if -num_args <= i < 0 else i
                      for i in _ensure_index_tuple(dyn_argnums))
  fixed_args = [unit] * num_args
  for i, arg in enumerate(args):
    if i in dyn_argnums:
      continue
    if require_static_args_hashable:
      # Check hashability eagerly: static args become part of cache keys.
      if not is_hashable(arg):
        raise ValueError(
            "Non-hashable static arguments are not supported, as this can lead "
            f"to unexpected cache-misses. Static argument (index {i}) of type "
            f"{type(arg)} for function {f.__name__} is non-hashable.")
      fixed_args[i] = Hashable(arg)
    else:
      # Make it explicit that this value was never meant to be hashed.
      fixed_args[i] = Unhashable(arg)

  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Every position in ``range(len(args))`` not listed in ``static_argnums`` is
  treated as dynamic.

  Args:
    f: the ``lu.WrappedFun`` to transform.
    static_argnums: positions of the static arguments.
    args: the full tuple of positional arguments.
    allow_invalid: if True, static argnums ``>= len(args)`` are skipped
      instead of raising IndexError.

  Returns:
    ``(f, args)`` unchanged when ``static_argnums`` is empty. Otherwise a pair
    ``(f_partial, dyn_args)``.

  Raises:
    ValueError: if a static argument is not hashable.
  """
  if not static_argnums:
    return f, args
  num_args = len(args)
  dyn_argnums = tuple(pos for pos in range(num_args) if pos not in static_argnums)
  dyn_args = tuple(args[pos] for pos in dyn_argnums)

  fixed_args = [unit] * num_args  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= num_args:
      continue
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    fixed_args[i] = Hashable(static_arg)  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
|
|
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Rebuild the full positional argument list and call the wrapped function.

  ``fixed_args`` holds one entry per position: either ``unit`` (a dynamic
  slot, initially filled with None) or a wrapper whose ``.val`` is the static
  value. The dynamic values are then written into their ``dyn_argnums`` slots.
  """
  full_args = []
  for wrapped in fixed_args:
    full_args.append(None if wrapped is unit else wrapped.val)
  for position, value in zip(dyn_argnums, dyn_args):
    full_args[position] = value
  result = yield full_args, kwargs
  yield result
|
|
|
|
|
|
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  """Keyword-argument analogue of ``argnums_partial_except``.

  Keyword arguments whose names are in ``static_argnames`` are closed over
  (after a hashability check). All others remain dynamic.

  Returns:
    ``(f, kwargs)`` unchanged when ``static_argnames`` is empty. Otherwise a
    pair ``(f_partial, dyn_kwargs)``.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {name: value for name, value in kwargs.items()
                if name not in static_argnames}

  # One entry per keyword: `unit` marks a dynamic slot, Hashable wraps a
  # static value.
  fixed_kwargs: Dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      fixed_kwargs[k] = unit
      continue
    try:
      hash(arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {k}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
    fixed_kwargs[k] = Hashable(arg)  # type: ignore

  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
|
|
|
|
|
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Merge static and dynamic keyword arguments and call the wrapped function.

  ``fixed_kwargs.val`` maps each name to ``unit`` (dynamic slot, initially
  None) or to a wrapper whose ``.val`` is the static value. Dynamic keyword
  values then override those entries.
  """
  merged = {}
  for name, wrapped in fixed_kwargs.val.items():
    merged[name] = None if wrapped is unit else wrapped.val
  merged.update(dyn_kwargs)
  result = yield args, merged
  yield result
|
|
|
|
|
|
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Returns a tuple with a boolean value for each leaf in args.

  Every leaf of positional argument ``i`` gets ``i in donate_argnums``. The
  leaves of ``kwargs`` follow at the end and are never marked for donation.
  """
  flags = []
  for argnum, arg in enumerate(args):
    flags += [argnum in donate_argnums] * tree_structure(arg).num_leaves
  flags += [False] * tree_structure(kwargs).num_leaves
  return tuple(flags)
|
|
|
|
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Shifts donate to account for static.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Each donated index is reduced by the number of static indices smaller than
  it, since static arguments are removed from the positional argument list.

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if an index appears in both donate_argnums and static_argnums.
  """
  if not (static_argnums or donate_argnums):
    return tuple(sorted(donate_argnums))

  statics = sorted(set(static_argnums))
  donated = sorted(set(donate_argnums))
  if not set(statics).isdisjoint(donated):
    raise ValueError(f"`static_argnums` {statics} and "
                     f"`donate_argnums` {donated} cannot intersect.")

  rebased = []
  num_static_below = 0
  for argnum in donated:
    # Both lists are sorted, so the count of smaller static indices only grows.
    while (num_static_below < len(statics) and
           statics[num_static_below] < argnum):
      num_static_below += 1
    rebased.append(argnum - num_static_below)
  return tuple(rebased)
|
|
|
|
|
|
def is_hashable(arg):
  """Return True if ``hash(arg)`` succeeds, False if it raises TypeError."""
  try:
    hash(arg)
  except TypeError:
    return False
  return True
|
|
|
|
|
|
def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  """Expand a prefix axis-spec tree to one axis entry per leaf of ``treedef``.

  Args:
    name: user-facing name of the spec (e.g. the argument name), used only in
      error messages.
    treedef: the treedef of the value the spec applies to.
    axis_tree: a pytree with integers and Nones at the leaves (Nones count as
      leaves) that must be a tree prefix of ``treedef``.
    kws: if True, ``treedef``/``axis_tree`` are assumed to be
      ``(args, kwargs)`` pairs. On error, the message then refers to the
      first (positional) component only.
    tupled_args: if True, add hints to the error message about wrapping the
      spec in a tuple representing the argument list.

  Returns:
    A list of length ``treedef.num_leaves``. Each entry is the axis-spec leaf
    that covers the corresponding leaf of ``treedef`` (None stays None).

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation

  # Unique sentinel substituted for None in axis_tree (via _replace_nones) so
  # that None entries are mapped like leaves; it is turned back into None below.
  proxy = object()
  # A value with treedef's exact structure; its leaves are placeholders.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # For each axis-spec leaf i and its matching dummy subtree x, record i once
  # per leaf of x.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    # tree_multimap rejected the structures: axis_tree is not a prefix of
    # treedef. Build an informative error instead.
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (f" Note that {name} that are non-trivial pytrees should always be "
               f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        # If wrapping the spec in a singleton tuple would have succeeded,
        # suggest exactly that.
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (f" In particular, you're passing in a single argument which "
                   f"means that {name} might need to be wrapped in "
                   f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
|
|
|
|
def _dtype(x):
  """Return ``dtypes.result_type`` of ``x``, falling back to ``x.dtype``.

  The fallback is used when ``dtypes.result_type`` raises ValueError on ``x``
  itself. A missing ``dtype`` attribute then raises AttributeError.
  """
  try:
    return dtypes.result_type(x)
  except ValueError:
    dtype_attr = x.dtype
    return dtypes.result_type(dtype_attr)
|
|
|
|
def shaped_abstractify(x):
  """Return a shaped abstract value for ``x``.

  First tries ``core.raise_to_shaped(core.get_aval(x))``. If that raises
  TypeError, it builds a ``core.ShapedArray`` from ``np.shape(x)`` and
  ``_dtype(x)``. The optional ``weak_type`` (default False) and
  ``named_shape`` (default ``{}``) attributes of ``x`` are forwarded.
  """
  try:
    aval = core.get_aval(x)
    shaped = core.raise_to_shaped(aval)
  except TypeError:
    pass
  else:
    return shaped

  # Duck-typed fallback for objects core.get_aval does not recognize.
  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  shape = np.shape(x)
  dtype = _dtype(x)
  return core.ShapedArray(shape, dtype, weak_type=weak_type,
                          named_shape=named_shape)
|
|
|
|
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  """Default hook: return ``fun`` unchanged.

  Args:
    fun: the API function being decorated.
    tag: an identifying label. It is ignored by this default implementation.

  Returns:
    ``fun`` itself.
  """
  del tag  # Unused by the default implementation.
  return fun
|