2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
2019-10-31 14:09:12 -07:00
|
|
|
from .tree_util import (build_tree, tree_flatten, tree_unflatten,
|
|
|
|
treedef_is_leaf)
|
2020-01-05 04:35:34 +01:00
|
|
|
from . import linear_util as lu
|
2019-01-03 16:14:30 -08:00
|
|
|
from .util import safe_map, unzip2, partial, curry
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Rebind the module-level name `map` to `safe_map` (imported from .util), so
# every `map(...)` call in this module uses it instead of the builtin.
map = safe_map
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-03 16:14:30 -08:00
|
|
|
|
|
|
|
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
    """Copy identifying metadata from ``wrapped`` onto ``fun`` and return ``fun``.

    Sets ``fun.__name__``, ``fun.__module__``, ``fun.__doc__`` and
    ``fun.__wrapped__``. The copy is best effort: if any step fails, the
    remaining attributes are left untouched and ``fun`` is still returned.

    NOTE(review): the function is decorated with ``curry`` from ``.util``;
    presumably this allows the ``@wraps(wrapped)`` decorator spelling --
    confirm against ``curry``'s definition.

    Args:
      wrapped: object whose name, module and docstring are read.
      fun: object that receives the metadata.
      namestr: format template for the new name; ``{fun}`` is replaced by the
        name of ``wrapped``.
      docstr: format template for the new docstring; ``{fun}`` is the name of
        ``wrapped`` and ``{doc}`` its docstring (empty when it has none).
      **kwargs: extra fields made available to ``docstr.format``.

    Returns:
      ``fun`` itself (mutated in place when the copy succeeds).
    """
    try:
        name = get_name(wrapped)
        fun.__name__ = namestr.format(fun=name)
        fun.__module__ = get_module(wrapped)
        # A missing docstring is reported as None; substitute "" so the
        # template does not render the literal text "None".
        fun.__doc__ = docstr.format(fun=name, doc=get_doc(wrapped) or "", **kwargs)
        fun.__wrapped__ = wrapped
    except Exception:
        # Metadata copying is best effort (e.g. `fun` may reject attribute
        # assignment or a template may reference an unknown field). Unlike a
        # `return` inside `finally`, this does not swallow KeyboardInterrupt
        # or SystemExit.
        pass
    return fun
|
|
|
|
|
|
|
|
def get_name(fun):
    """Return ``fun.__name__``, or ``"<unnamed function>"`` if it has none."""
    try:
        return fun.__name__
    except AttributeError:
        return "<unnamed function>"
|
|
|
|
def get_module(fun):
    """Return ``fun.__module__``, or ``"<unknown module>"`` if it has none."""
    try:
        return fun.__module__
    except AttributeError:
        return "<unknown module>"
|
|
|
|
def get_doc(fun):
    """Return the docstring of ``fun`` as a string.

    Returns ``""`` both when ``fun`` has no ``__doc__`` attribute and when the
    attribute is ``None`` (the usual case for an undocumented function), so
    callers can always format or concatenate the result as text.
    """
    return getattr(fun, "__doc__", None) or ""
|
2019-01-03 16:14:30 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
    """Generator that adapts a flat-argument call to a structured one.

    Rebuilds an ``(args, kwargs)`` pair from ``args_flat`` using ``in_tree``,
    yields that pair, and receives the call's result back through the
    generator protocol. The result is then flattened and the pair returned by
    ``tree_flatten`` is yielded as the final value.

    NOTE(review): how the two yields are consumed is defined by
    ``lu.transformation_with_aux`` -- confirm there.
    """
    positional, keyword = tree_unflatten(in_tree, args_flat)
    result = yield positional, keyword
    flat_result, result_tree = tree_flatten(result)
    yield flat_result, result_tree
|
2019-07-25 12:41:11 -07:00
|
|
|
|
2019-07-26 23:17:21 -04:00
|
|
|
def apply_flat_fun(fun, io_tree, *py_args):
    """Call the flat function ``fun`` with structured positional arguments.

    Args:
      fun: callable taking the flattened leaves as positional arguments.
      io_tree: pair ``(expected input tree, output tree)``.
      *py_args: structured arguments; they are flattened together with an
        empty kwargs dict, i.e. as ``(py_args, {})``.

    Returns:
      The result of ``fun`` rebuilt with the output tree.

    Raises:
      TypeError: if the tree of the supplied arguments differs from the
        expected input tree.
    """
    expected_tree, result_tree = io_tree
    leaves, actual_tree = tree_flatten((py_args, {}))
    if actual_tree != expected_tree:
        message = "Expected {}, got {}".format(expected_tree, actual_tree)
        raise TypeError(message)
    return tree_unflatten(result_tree, fun(*leaves))
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
    """Generator variant of the flattening adapter without keyword arguments.

    Rebuilds the positional arguments from ``args_flat`` using ``in_tree``,
    yields them paired with an empty kwargs dict, receives the call's result
    back, and finally yields the pair returned by ``tree_flatten`` for that
    result.

    NOTE(review): how the two yields are consumed is defined by
    ``lu.transformation_with_aux`` -- confirm there.
    """
    positional = tree_unflatten(in_tree, args_flat)
    result = yield positional, {}
    flat_result, result_tree = tree_flatten(result)
    yield flat_result, result_tree
|
|
|
|
|
|
|
|
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
    """Call the flat function ``fun`` with one structured argument object.

    Args:
      fun: callable taking the flattened leaves as positional arguments.
      io_tree: pair ``(expected input tree, output tree)``.
      py_args: structured arguments, flattened as-is (no kwargs dict is
        attached).

    Returns:
      The result of ``fun`` rebuilt with the output tree.

    Raises:
      TypeError: if the tree of ``py_args`` differs from the expected input
        tree.
    """
    expected_tree, result_tree = io_tree
    leaves, actual_tree = tree_flatten(py_args)
    if actual_tree != expected_tree:
        message = "Expected {}, got {}".format(expected_tree, actual_tree)
        raise TypeError(message)
    return tree_unflatten(result_tree, fun(*leaves))
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
    """Generator adapter for calls that return a ``(result, aux)`` pair.

    Rebuilds the positional arguments from ``args_flat`` using ``in_tree`` and
    yields them with an empty kwargs dict. The value received back is unpacked
    into two parts, each of which is flattened separately (result first, then
    aux). The final yield is ``((result leaves, aux leaves),
    (result tree, aux tree))``.

    NOTE(review): how the two yields are consumed is defined by
    ``lu.transformation_with_aux`` -- confirm there.
    """
    positional = tree_unflatten(in_tree, args_flat)
    primary, auxiliary = yield positional, {}
    primary_leaves, primary_tree = tree_flatten(primary)
    auxiliary_leaves, auxiliary_tree = tree_flatten(auxiliary)
    all_leaves = (primary_leaves, auxiliary_leaves)
    all_trees = (primary_tree, auxiliary_tree)
    yield all_leaves, all_trees
|