rocm_jax/jax/api_util.py
Skye Wanderman-Milne a5da921f4c
Move _flatten_axes to api_util.py (#3041)
This is in preparation for using it in sharded_jit.py (since sharded_jit isn't included in api.py yet).
2020-05-11 11:04:57 -07:00

120 lines
4.3 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tree_util import (build_tree, tree_flatten, tree_unflatten,
treedef_is_leaf, tree_multimap, _replace_nones)
from . import linear_util as lu
from .util import safe_map, unzip2, partial, curry, WrapHashably, Hashable
from .core import unit
# Module-local alias: within this module, ``map`` refers to ``safe_map`` from
# ``.util`` rather than the builtin (presumably a stricter variant that checks
# argument lengths -- confirm in util.py).
map = safe_map
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copy naming metadata from ``wrapped`` onto ``fun`` and return ``fun``.

  In order, this sets:
    * ``fun.__name__`` to ``namestr`` formatted with ``fun=<wrapped's name>``;
    * ``fun.__module__`` to ``wrapped``'s module;
    * ``fun.__doc__`` to ``docstr`` formatted with ``fun=<wrapped's name>``,
      ``doc=<wrapped's docstring>`` and any extra ``kwargs``;
    * ``fun.__wrapped__`` to ``wrapped``.

  Because of ``@curry`` this is presumably usable as a decorator factory, e.g.
  ``@wraps(original, docstr="...")`` -- confirm against ``util.curry``.

  The update is best effort. If any step raises an ``Exception`` (e.g. a
  ``KeyError`` from a ``docstr`` placeholder missing from ``kwargs``, or an
  object whose attributes cannot be assigned), the remaining steps are skipped
  and ``fun`` is returned as-is. ``BaseException``s such as
  ``KeyboardInterrupt`` are no longer silently swallowed.

  Args:
    wrapped: the callable whose metadata is copied.
    fun: the callable that receives the metadata.
    namestr: format string for the new ``__name__``.
    docstr: format string for the new ``__doc__``.
    **kwargs: extra fields available to ``docstr``.

  Returns:
    ``fun`` (the same object, possibly mutated).
  """
  try:
    name = get_name(wrapped)
    fun.__name__ = namestr.format(fun=name)
    fun.__module__ = get_module(wrapped)
    fun.__doc__ = docstr.format(fun=name, doc=get_doc(wrapped), **kwargs)
    fun.__wrapped__ = wrapped
  except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: metadata copying must never prevent wrapping.
    # Previously this was `finally: return fun`, which also swallowed
    # KeyboardInterrupt/SystemExit and is a SyntaxWarning under PEP 765.
    pass
  return fun
def get_name(fun):
  """Return ``fun.__name__``, or ``"<unnamed function>"`` if it has none."""
  try:
    return fun.__name__
  except AttributeError:
    return "<unnamed function>"
def get_module(fun):
  """Return ``fun.__module__``, or ``"<unknown module>"`` if it has none."""
  fallback = "<unknown module>"
  module_name = getattr(fun, "__module__", fallback)
  return module_name
def get_doc(fun):
  """Return ``fun``'s docstring, or ``""`` if it has none.

  Callables defined without a docstring still have a ``__doc__`` attribute,
  whose value is ``None``. The previous ``getattr(fun, "__doc__", "")``
  returned that ``None``, so a ``"{doc}"`` template (as used by ``wraps``)
  rendered the literal text ``"None"``. Both a missing attribute and a
  ``None`` value are now normalized to ``""``.
  """
  return getattr(fun, "__doc__", None) or ""
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Adapt a function taking ``(args, kwargs)`` pytrees to flat leaves.

  Rebuilds ``(args, kwargs)`` from the flat leaves using ``in_tree``, yields
  them to the wrapped function, then flattens its result. The final yield is
  ``(flat_result, result_treedef)``; the treedef is the auxiliary output of
  this ``transformation_with_aux``.
  """
  args, kwargs = tree_unflatten(in_tree, args_flat)
  result = yield args, kwargs
  flat_result, result_tree = tree_flatten(result)
  yield flat_result, result_tree
def apply_flat_fun(fun, io_tree, *py_args):
  """Call flat-signature ``fun`` on structured positional arguments.

  Args:
    fun: callable taking flat leaves positionally and returning flat leaves.
    io_tree: pair ``(expected_in_tree, out_tree)``. ``expected_in_tree``
      describes ``(py_args, {})``; ``out_tree`` is used to rebuild the output.
    *py_args: structured positional arguments.

  Returns:
    ``fun``'s flat output unflattened with ``out_tree``.

  Raises:
    TypeError: if ``(py_args, {})`` does not flatten to ``expected_in_tree``.
  """
  expected_in_tree, out_tree = io_tree
  flat_args, actual_in_tree = tree_flatten((py_args, {}))
  if actual_in_tree != expected_in_tree:
    message = "Expected {}, got {}".format(expected_in_tree, actual_in_tree)
    raise TypeError(message)
  return tree_unflatten(out_tree, fun(*flat_args))
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Like ``flatten_fun``, but ``in_tree`` describes positional args only.

  The wrapped function is always given an empty kwargs dict. The final yield
  is ``tree_flatten`` of its result, i.e. ``(flat_result, result_treedef)``.
  """
  result = yield tree_unflatten(in_tree, args_flat), {}
  yield tree_flatten(result)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call flat-signature ``fun`` on a single structured ``py_args`` pytree.

  Args:
    fun: callable taking flat leaves positionally and returning flat leaves.
    io_tree: pair ``(in_tree_expected, out_tree)``.
    py_args: pytree that must flatten to ``in_tree_expected``.

  Returns:
    ``fun``'s flat output unflattened with ``out_tree``.

  Raises:
    TypeError: if ``py_args`` does not flatten to ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  leaves, in_tree = tree_flatten(py_args)
  if in_tree == in_tree_expected:
    flat_out = fun(*leaves)
    return tree_unflatten(out_tree, flat_out)
  raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Like ``flatten_fun_nokwargs`` for functions returning ``(ans, aux)``.

  Both parts of the wrapped function's result are flattened separately. The
  final yield is ``((ans_leaves, aux_leaves), (ans_treedef, aux_treedef))``.
  """
  positional = tree_unflatten(in_tree, args_flat)
  primary, extra = yield positional, {}
  primary_leaves, primary_tree = tree_flatten(primary)
  extra_leaves, extra_tree = tree_flatten(extra)
  flat_pair = (primary_leaves, extra_leaves)
  tree_pair = (primary_tree, extra_tree)
  yield flat_pair, tree_pair
def argnums_partial(f, dyn_argnums, args):
  """Close ``f`` over the arguments not listed in ``dyn_argnums``.

  Args:
    f: the wrapped function (passed to the ``_argnums_partial`` transformation).
    dyn_argnums: an int or an iterable of ints selecting the dynamic positions.
    args: the full positional argument sequence.

  Returns:
    ``(f_partial, dyn_args)``. ``f_partial`` expects only the dynamic
    arguments, in ``dyn_argnums`` order. ``dyn_args`` is the tuple of those
    arguments taken from ``args``. Every other argument is stored in
    ``fixed_args`` via ``wrap_hashably``; dynamic slots hold ``unit``.
  """
  # NOTE(review): ``i`` below comes from ``enumerate`` and is never negative,
  # so a negative entry in ``dyn_argnums`` is also wrapped as a fixed arg
  # (it is still overwritten later by ``_argnums_partial``).
  dyn_argnums = ((dyn_argnums,) if isinstance(dyn_argnums, int)
                 else tuple(dyn_argnums))
  fixed_args = tuple(unit if i in dyn_argnums else wrap_hashably(arg)
                     for i, arg in enumerate(args))
  dyn_args = tuple([args[i] for i in dyn_argnums])
  f_partial = _argnums_partial(f, dyn_argnums, fixed_args)
  return f_partial, dyn_args
def wrap_hashably(arg):
  """Wrap ``arg`` so it can take part in a hash-based cache key.

  If ``hash(arg)`` succeeds, ``arg`` is wrapped in ``Hashable``. If it raises
  ``TypeError``, ``arg`` is wrapped in ``WrapHashably`` instead (presumably
  hashed by identity -- confirm in util.py).
  """
  try:
    hash(arg)
  except TypeError:
    return WrapHashably(arg)  # e.g. ndarrays, DeviceArrays
  return Hashable(arg)
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Reassemble the full positional argument list and call the wrapped fun.

  ``fixed_args`` has one entry per original position. ``unit`` marks a dynamic
  slot; any other entry is a wrapper whose ``.val`` is the fixed value. The
  dynamic values in ``dyn_args`` are written back at the positions in
  ``dyn_argnums``. ``kwargs`` are passed through unchanged.
  """
  full_args = []
  for slot in fixed_args:
    full_args.append(None if slot is unit else slot.val)
  for position, value in zip(dyn_argnums, dyn_args):
    full_args[position] = value
  result = yield full_args, kwargs
  yield result
def flatten_axes(treedef, axis_tree):
  """Expand an axis-spec prefix tree into one entry per leaf of ``treedef``.

  Args:
    treedef: tree structure of the value the axis spec applies to.
    axis_tree: pytree with integers and ``None`` at its leaves (``None`` is
      treated as a leaf here) that is a tree prefix of ``treedef``.

  Returns:
    A list of length ``treedef.num_leaves``. In flattening order, each entry
    is the ``axis_tree`` leaf covering that leaf of ``treedef``.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # TODO(mattjj,phawkins): improve this implementation
  # Unique marker substituted for None entries of axis_tree, so they are
  # mapped over as leaves instead of as empty subtrees.
  none_marker = object()
  # A value with exactly the structure of treedef; its leaves are only
  # counted, never inspected.
  placeholder = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  flat_axes = []

  def record(axis, subtree):
    # Repeat this spec leaf once per leaf of the subtree it covers.
    flat_axes.extend([axis] * len(tree_flatten(subtree)[0]))

  try:
    tree_multimap(record, _replace_nones(none_marker, axis_tree), placeholder)
  except ValueError as e:
    msg = ("axes specification must be a tree prefix of the corresponding "
           "value, got specification {} for value {}.")
    raise ValueError(msg.format(axis_tree, treedef)) from e
  result = [None if axis is none_marker else axis for axis in flat_axes]
  assert len(result) == treedef.num_leaves
  return result