rocm_jax/jax/api_util.py
Dougal Maclaurin 6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00

55 lines
1.8 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .tree_util import (build_tree, process_pytree, tree_flatten,
tree_unflatten, treedef_is_leaf)
from .linear_util import transformation_with_aux
from .util import safe_map, unzip2, partial, curry
# Module-level alias: from here on, `map` in this file refers to `safe_map`
# from .util rather than the builtin (presumably a length-checking variant
# of map -- confirm against .util).
map = safe_map
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copy naming/documentation metadata from `wrapped` onto `fun`.

  Decorated with `curry` (from .util), so it is presumably used as
  `wraps(wrapped)(fun)` or as a decorator `@wraps(wrapped)` -- confirm
  against .util.curry.

  Args:
    wrapped: the callable whose metadata is copied.
    fun: the callable that receives the metadata. It is mutated in place and
      returned.
    namestr: format string for `fun.__name__`. `{fun}` is replaced with the
      name of `wrapped`.
    docstr: format string for `fun.__doc__`. `{fun}` is replaced with the
      name of `wrapped`, `{doc}` with its docstring, and any extra `kwargs`
      are also available as format fields.
    **kwargs: extra fields passed to `docstr.format`.

  Returns:
    `fun`, always.

  Copying is best effort. If an assignment or format call raises an ordinary
  exception (for example, a callable with read-only attributes, or a bad
  format field), the remaining attributes are skipped and `fun` is returned
  as-is.
  """
  try:
    fun.__name__ = namestr.format(fun=get_name(wrapped))
    fun.__module__ = get_module(wrapped)
    fun.__doc__ = docstr.format(fun=get_name(wrapped), doc=get_doc(wrapped),
                                **kwargs)
    fun.__wrapped__ = wrapped
  except Exception:
    # Deliberately best effort: metadata is cosmetic. Catching Exception
    # (rather than `return` inside `finally`, which discards *any* in-flight
    # exception) lets KeyboardInterrupt/SystemExit propagate.
    pass
  return fun
def get_name(fun):
  """Return `fun.__name__`, falling back to "<unnamed function>"."""
  try:
    return fun.__name__
  except AttributeError:
    return "<unnamed function>"
def get_module(fun):
  """Return `fun.__module__`, falling back to "<unknown module>"."""
  try:
    return fun.__module__
  except AttributeError:
    return "<unknown module>"
def get_doc(fun):
  """Return the docstring of `fun` as a string, or "" when it has none.

  A function defined without a docstring still has a `__doc__` attribute
  whose value is None, so a plain `getattr` default never applies. Without
  this normalization, formatting the result into a template (as `wraps`
  does with "{doc}") would produce the literal text "None".
  """
  return getattr(fun, "__doc__", None) or ""
@transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Generator bridging flat leaf arguments and a pytree-taking function.

  Rebuilds `(args, kwargs)` from the flat leaves `args_flat` using the tree
  structure `in_tree`, and yields that pair. The value sent back into the
  generator is presumably the wrapped function's output. Finally, it yields
  `tree_flatten` of that value.

  NOTE(review): how the two yields are consumed is defined by
  `linear_util.transformation_with_aux`. Presumably the final yield is split
  into (flat outputs, auxiliary output tree structure) -- confirm there.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  flat_result = tree_flatten(result)
  yield flat_result
def abstract_tuple_tree_leaves(aval):
  """Lazily yield the leaves of a possibly nested `AbstractTuple`.

  Values whose exact type is `AbstractTuple` are recursed into in iteration
  order, depth-first. Any other value is yielded as a single leaf. The check
  is `type(...) is`, so subclasses of `AbstractTuple` are also treated as
  leaves.

  Args:
    aval: the value to flatten.

  Yields:
    The non-`AbstractTuple` leaves, left to right.
  """
  # `AbstractTuple` is not imported anywhere in this file, so referencing the
  # bare name raised NameError on first use. Import it at call time instead.
  # A function-scope import also keeps `.core` out of this module's
  # import-time dependencies.
  # NOTE(review): assumes AbstractTuple is still defined in `.core` at this
  # revision -- confirm; if it has been removed, this function is dead code.
  from .core import AbstractTuple
  if type(aval) is AbstractTuple:
    for elt in aval:
      # TODO(mattjj,phawkins): use 'yield from' when PY2 is dropped
      for leaf in abstract_tuple_tree_leaves(elt):
        yield leaf
  else:
    yield aval