# Mirror of https://github.com/ROCm/jax.git
# (synced 2025-04-23 23:26:05 +00:00; 306 lines, 11 KiB, Python)
# Copyright 2018 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import operator
|
|
from functools import partial
|
|
from typing import Any, Dict, Iterable, Tuple, Union, Optional
|
|
|
|
import numpy as np
|
|
|
|
from . import core
|
|
from ._src import dtypes
|
|
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap,
|
|
tree_structure, treedef_children, treedef_is_leaf)
|
|
from ._src.tree_util import _replace_nones
|
|
from . import linear_util as lu
|
|
from ._src.util import safe_map, WrapHashably, WrapKwArgs, Hashable
|
|
from .core import unit
|
|
|
|
from ._src import traceback_util
|
|
# Register this module with traceback_util. NOTE(review): presumably this hides
# this file's frames from user-facing tracebacks — confirm in _src/traceback_util.
traceback_util.register_exclusion(__file__)


# Shadow the builtin ``map`` for the rest of this module. NOTE(review):
# ``safe_map`` comes from ``._src.util``; presumably it returns a list eagerly
# and checks that all input iterables have equal length — confirm there.
map = safe_map
|
|
|
|
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
|
|
"""Ensure x is either an index or a tuple of indices."""
|
|
try:
|
|
return operator.index(x)
|
|
except TypeError:
|
|
return tuple(map(operator.index, x))
|
|
|
|
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
|
|
"""Convert x to a tuple of indices."""
|
|
try:
|
|
return (operator.index(x),)
|
|
except TypeError:
|
|
return tuple(map(operator.index, x))
|
|
|
|
def _ensure_str(x: str) -> str:
|
|
if not isinstance(x, str):
|
|
raise TypeError(f"argument is not a string: {x}")
|
|
return x
|
|
|
|
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
|
|
"""Convert x to a tuple of strings."""
|
|
if isinstance(x, str):
|
|
return (x,)
|
|
else:
|
|
return tuple(map(_ensure_str, x))
|
|
|
|
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Adapt a function over ``(args, kwargs)`` pytrees into one over leaves.

  ``in_tree`` is the treedef of an ``(args, kwargs)`` pair, and ``args_flat``
  are its leaves. They are rebuilt into positional and keyword arguments for
  the wrapped function. That function's result is then flattened. The final
  yield pairs the flat outputs with their treedef; under
  ``lu.transformation_with_aux`` the treedef is presumably exposed as the
  auxiliary output.
  """
  unflattened = tree_unflatten(in_tree, args_flat)
  py_args, py_kwargs = unflattened
  result = yield py_args, py_kwargs
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
|
|
|
|
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a flat function on pytree positional args and rebuild its output.

  Args:
    fun: callable that takes the leaves of ``(py_args, {})`` and returns a flat
      sequence of outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of treedefs.
    *py_args: positional pytree arguments.

  Returns:
    ``fun``'s flat output, unflattened with ``out_tree``.

  Raises:
    TypeError: if ``(py_args, {})`` does not have structure
      ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, in_tree = tree_flatten((py_args, {}))
  if in_tree != in_tree_expected:
    raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
  flat_out = fun(*leaves)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Positional-only counterpart of ``flatten_fun``.

  ``in_tree`` is the treedef of the positional-argument pytree, and
  ``args_flat`` are its leaves. The wrapped function is called with an empty
  kwargs dict, and its result is flattened. The final yield pairs the flat
  outputs with their treedef, presumably the auxiliary output under
  ``lu.transformation_with_aux``.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  result_leaves, result_tree = tree_flatten(result)
  yield result_leaves, result_tree
|
|
|
|
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a flat function on a pytree of positional args, rebuilding output.

  Args:
    fun: callable that takes the leaves of ``py_args`` and returns a flat
      sequence of outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of treedefs.
    py_args: a single pytree holding the positional arguments.

  Returns:
    ``fun``'s flat output, unflattened with ``out_tree``.

  Raises:
    TypeError: if ``py_args`` does not have structure ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, in_tree = tree_flatten(py_args)
  if in_tree != in_tree_expected:
    raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
  flat_out = fun(*leaves)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
# Loose alias used only in annotations below.
PyTreeDef = Any

def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  """Recover the input treedef recorded by flatten_fun / flatten_fun_nokwargs.

  Scans ``fn.transforms`` for applications of ``flatten_fun`` or
  ``flatten_fun_nokwargs``.

  Returns:
    ``(in_tree, has_kwargs)`` when exactly one such transform is found.
    ``in_tree`` is that transform's first stored argument, and ``has_kwargs``
    is True iff the match was ``flatten_fun``, i.e. the tree describes an
    ``(args, kwargs)`` pair. Returns None when there are zero matches or more
    than one.
  """
  # This implementation relies on internal details of linear_util.py's
  # WrappedFun, but it's for the worthy cause of better user error messages.
  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
  # core.eval_jaxpr encounters a call primitive (though at that point we're just
  # round-tripping jaxprs and the user errors in question are impossible).
  # The decorated transforms are expected to be functools.partial objects that
  # bind a single argument: the raw generator function recorded in
  # fn.transforms. Fail loudly if linear_util's representation changes.
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
  try:
    # Single-element unpacking of the generator raises ValueError when there
    # are zero matches or more than one.
    (in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
                              for f, args in fn.transforms if f in flat_xforms)
  except ValueError:
    return None
  else:
    return in_tree, has_kwargs
|
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Positional-only flattening for functions that return ``(out, aux)``.

  The wrapped function is called with the unflattened positional args and an
  empty kwargs dict. It must return a two-element list or tuple. Both the main
  output and the aux output are flattened separately. The final yield is
  ``((out_leaves, aux_leaves), (out_tree, aux_tree))``.

  Raises:
    TypeError: if the wrapped function does not return a two-element list or
      tuple.
  """
  py_args = tree_unflatten(in_tree, args_flat)
  pair = yield py_args, {}
  well_formed = isinstance(pair, (list, tuple)) and len(pair) == 2
  if not well_formed:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(pair)} with value {repr(pair)}")
  main_out, aux_out = pair
  main_leaves, main_tree = tree_flatten(main_out)
  aux_leaves, aux_tree = tree_flatten(aux_out)
  yield (main_leaves, aux_leaves), (main_tree, aux_tree)
|
|
|
|
|
|
def argnums_partial(f, dyn_argnums, args):
  """Bind every positional argument not listed in ``dyn_argnums`` into ``f``.

  Args:
    f: the function to specialize, passed on to ``_argnums_partial``.
    dyn_argnums: an int or iterable of ints giving the positions that stay
      dynamic.
    args: the full tuple of positional arguments.

  Returns:
    ``(f_partial, dyn_args)``. ``dyn_args`` holds ``args[i]`` for each ``i`` in
    ``dyn_argnums``, in that order. Every other argument is wrapped with
    ``wrap_hashably`` and baked into ``f_partial``. Dynamic positions hold the
    ``unit`` placeholder.
  """
  dyn_argnums = _ensure_index_tuple(dyn_argnums)
  fixed_args = []
  for position, value in enumerate(args):
    if position in dyn_argnums:
      fixed_args.append(unit)
    else:
      fixed_args.append(wrap_hashably(value))
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
|
|
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any, ...], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Every position in ``static_argnums`` is baked into the returned function as
  a ``Hashable`` wrapper. All remaining positions stay dynamic.

  Args:
    f: the wrapped function to specialize.
    static_argnums: positions of arguments to treat as static.
    args: the full (variable-length) tuple of positional arguments.
    allow_invalid: if True, static argnums ``>= len(args)`` are silently
      skipped; otherwise indexing ``args`` with them raises IndexError.

  Returns:
    ``(f_partial, dyn_args)``. When ``static_argnums`` is empty, this is
    ``(f, args)`` unchanged.

  Raises:
    ValueError: if a static argument is not hashable.
  """
  if not static_argnums:
    return f, args
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  # Dynamic positions keep the ``unit`` placeholder; static ones are filled in
  # below.
  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    else:
      fixed_args[i] = Hashable(static_arg)  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
|
|
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Rebuild the full positional-argument list from fixed and dynamic parts.

  Each entry of ``fixed_args`` is either the ``unit`` placeholder, which
  becomes None here, or a wrapper whose ``.val`` is the original argument.
  Then ``dyn_args`` are written into the positions listed by ``dyn_argnums``.
  The wrapped function's result is yielded back unchanged.
  """
  full_args = [None if fixed is unit else fixed.val for fixed in fixed_args]
  for position, value in zip(dyn_argnums, dyn_args):
    full_args[position] = value
  result = yield full_args, kwargs
  yield result
|
|
|
|
|
|
def argnames_partial(f, dyn_argnames, kwargs):
  """Bind every keyword argument not named in ``dyn_argnames`` into ``f``.

  Args:
    f: the function to specialize, passed on to ``_argnames_partial``.
    dyn_argnames: a name or iterable of names of keyword arguments that stay
      dynamic.
    kwargs: dict of all keyword arguments.

  Returns:
    ``(f_partial, dyn_kwargs)``. ``dyn_kwargs`` maps each dynamic name to its
    value in ``kwargs``. The remaining keyword arguments are wrapped with
    ``wrap_hashably`` and baked into ``f_partial``.

  Raises:
    TypeError: if ``dyn_argnames`` contains a non-string.
    KeyError: if a name in ``dyn_argnames`` is missing from ``kwargs``.
  """
  dyn_argnames = _ensure_str_tuple(dyn_argnames)
  # ``_argnames_partial`` reads ``fixed_kwargs.val.items()``, so the value
  # wrapped in WrapKwArgs must be a mapping (as argnames_partial_except also
  # passes), not a tuple of (name, value) pairs.
  fixed_kwargs = {k: unit if k in dyn_argnames else wrap_hashably(v)
                  for k, v in kwargs.items()}
  dyn_kwargs = {k: kwargs[k] for k in dyn_argnames}
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
|
|
|
|
|
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  """Keyword-argument analogue of ``argnums_partial_except``.

  Keyword arguments named in ``static_argnames`` are checked for hashability
  and baked into the returned function as ``Hashable`` wrappers. All others
  stay dynamic and get the ``unit`` placeholder in the fixed mapping.

  Returns:
    ``(f_partial, dyn_kwargs)``. When ``static_argnames`` is empty, this is
    ``(f, kwargs)`` unchanged.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

  fixed_kwargs: Dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      fixed_kwargs[k] = unit
      continue
    try:
      hash(arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {k}) of type "
          f"{type(arg)} for function {f.__name__} is non-hashable.")
    fixed_kwargs[k] = Hashable(arg)  # type: ignore

  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
|
|
|
|
|
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Merge fixed and dynamic keyword arguments before calling the function.

  ``fixed_kwargs.val`` maps names either to the ``unit`` placeholder, which
  becomes None here, or to a wrapper whose ``.val`` is the original value.
  ``dyn_kwargs`` then override those entries. The wrapped function's result
  is yielded back unchanged.
  """
  merged = {}
  for name, fixed in fixed_kwargs.val.items():
    merged[name] = None if fixed is unit else fixed.val
  merged.update(dyn_kwargs)
  result = yield args, merged
  yield result
|
|
|
|
|
|
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Returns a tuple with a boolean value for each leaf in args.

  Every leaf of ``args[i]`` is flagged ``i in donate_argnums``. The leaves of
  ``kwargs`` follow, and they are never donated (always False).
  """
  flags = []
  for position, arg in enumerate(args):
    leaf_count = tree_structure(arg).num_leaves
    flags += [position in donate_argnums] * leaf_count
  flags += [False] * tree_structure(kwargs).num_leaves
  return tuple(flags)
|
|
|
|
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Shifts donate to account for static.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Each donated position is reduced by the number of static positions that
  precede it, giving its index among the non-static arguments.

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if a position appears in both inputs.
  """
  if not static_argnums and not donate_argnums:
    return tuple(sorted(donate_argnums))

  static_argnums = sorted(set(static_argnums))
  donate_argnums = sorted(set(donate_argnums))
  static_lookup = set(static_argnums)
  rebased = []
  for d in donate_argnums:
    if d in static_lookup:
      raise ValueError(f"`static_argnums` {static_argnums} and "
                       f"`donate_argnums` {donate_argnums} cannot intersect.")
    num_static_before = sum(1 for s in static_argnums if s < d)
    rebased.append(d - num_static_before)
  return tuple(rebased)
|
|
|
|
def wrap_hashably(arg):
  """Wrap ``arg`` so that it can be used as part of a cache key.

  Values for which ``hash()`` succeeds are wrapped in ``Hashable``. Values
  that raise TypeError when hashed (e.g. ndarrays, DeviceArrays) are wrapped
  in ``WrapHashably`` instead.
  """
  try:
    hash(arg)
  except TypeError:
    return WrapHashably(arg)  # e.g. ndarrays, DeviceArrays
  return Hashable(arg)
|
|
|
|
def flatten_axes(name, treedef, axis_tree, *, kws=False):
  """Broadcast an axis-spec prefix tree against ``treedef`` and flatten it.

  Args:
    name: label interpolated into the error message.
    treedef: treedef of the full value tree.
    axis_tree: a tree prefix of ``treedef`` whose leaves are axis specs. None
      entries are treated as leaves, not as empty subtrees.
    kws: if True, ``treedef`` must have exactly two children: the positional
      part and a leaf for keyword arguments. ``axis_tree`` must be a matching
      pair. Only the positional parts are reported in the error message.

  Returns:
    A list with one entry per leaf of ``treedef``. Each entry is the axis spec
    from the ``axis_tree`` leaf that covers that position, or None where the
    spec was None.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation
  # Sentinel substituted for None in axis_tree so Nones act as leaves during
  # the traversal; it is mapped back to None at the end.
  proxy = object()
  # A value tree with treedef's structure; only its shape is used.
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  # For each axis_tree leaf i matched against a dummy subtree x, record i once
  # per leaf of x.
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
|
|
|
|
def _dtype(x):
  """Return the result dtype of ``x``, falling back to its ``dtype`` attribute.

  ``dtypes.result_type`` is tried on ``x`` itself first. If that raises
  ValueError, it is applied to ``x.dtype`` instead.
  """
  try:
    return dtypes.result_type(x)
  except ValueError:
    declared = x.dtype
    return dtypes.result_type(declared)
|
|
|
|
def shaped_abstractify(x):
  """Return a shaped abstract value for ``x``.

  The first attempt is ``core.raise_to_shaped(core.get_aval(x))``. If that
  raises TypeError, a ``core.ShapedArray`` is built by duck typing instead. It
  uses ``np.shape(x)``, ``_dtype(x)``, and the optional ``weak_type``
  (default False) and ``named_shape`` (default ``{}``) attributes.
  """
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass  # Not directly abstractable; fall through to duck typing.

  is_weak = getattr(x, 'weak_type', False)
  axis_names = getattr(x, 'named_shape', {})
  shape = np.shape(x)
  dtype = _dtype(x)
  return core.ShapedArray(shape, dtype, weak_type=is_weak,
                          named_shape=axis_names)
|
|
|
|
# This hook exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  """Return ``fun`` unchanged; ``tag`` is ignored by this default hook."""
  del tag  # Unused by the default implementation.
  return fun
|