# rocm_jax/jax/util.py
# 2018-11-17 18:03:33 -08:00

# 140 lines
# 3.2 KiB
# Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import functools
import itertools as it
import weakref
# Controls memoize() when a call's argument key is unhashable. When False,
# the TypeError raised by the cache lookup is re-raised. When True, the
# wrapped function is called directly and its result is not cached.
# memoize() reads this global at call time, so rebinding it takes effect
# for functions that were already decorated.
allow_memoize_hash_failures = False
def safe_zip(*args):
  """Zip sized sequences together, asserting they all have the same length.

  Unlike the builtin ``zip``, which silently truncates to the shortest input,
  this checks lengths up front so that mismatches surface immediately.

  Args:
    *args: sized sequences (anything supporting ``len`` and iteration).

  Returns:
    A list of tuples, one per position. A list is returned rather than a lazy
    ``zip`` object so the result can be indexed and iterated more than once,
    matching the Python 2 behavior this helper was written against. With no
    arguments the result is ``[]``.

  Raises:
    AssertionError: if the arguments differ in length (when asserts are on).
  """
  if not args:
    return []
  n = len(args[0])
  for arg in args[1:]:
    # list(...) so the message shows the lengths, not "<map object ...>".
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(zip(*args))
def safe_map(f, *args):
  """Map ``f`` over several iterables, asserting they all have the same length.

  Unlike the builtin ``map``, which stops at the shortest input, this checks
  lengths first. Each input is materialized into a list so it can be measured,
  which means generators and other one-shot iterables are accepted.

  Args:
    f: callable taking one argument per input iterable.
    *args: one or more finite iterables (at least one is required).

  Returns:
    A list of ``f`` applied element-wise (a list rather than a lazy ``map``
    object, matching the Python 2 behavior this helper was written against).

  Raises:
    AssertionError: if the inputs differ in length (when asserts are on).
  """
  # Must be a real list: under Python 3 ``map`` returns an iterator, which
  # cannot be indexed (args[0]) or sliced (args[1:]).
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(map(f, *args))
def unzip2(xys):
  """Split an iterable of pairs into a pair of tuples.

  This is the inverse of zipping two sequences: ``[(a, 1), (b, 2)]`` becomes
  ``((a, b), (1, 2))``. An empty input yields ``((), ())``. Every element
  must unpack into exactly two values.
  """
  firsts, seconds = [], []
  for pair in xys:
    first, second = pair
    firsts.append(first)
    seconds.append(second)
  return tuple(firsts), tuple(seconds)
def unzip3(xyzs):
  """Split an iterable of triples into a triple of tuples.

  This is the inverse of zipping three sequences:
  ``[(a, 1, x), (b, 2, y)]`` becomes ``((a, b), (1, 2), (x, y))``. An empty
  input yields ``((), (), ())``. Every element must unpack into exactly
  three values.
  """
  firsts, seconds, thirds = [], [], []
  for triple in xyzs:
    first, second, third = triple
    firsts.append(first)
    seconds.append(second)
    thirds.append(third)
  return tuple(firsts), tuple(seconds), tuple(thirds)
def concatenate(xs):
  """Flatten one level of nesting: join an iterable of iterables into a list."""
  return [item for chunk in xs for item in chunk]
def partial(fun, *args, **kwargs):
  """Like ``functools.partial``, but the result carries ``fun``'s metadata.

  ``functools.update_wrapper`` copies attributes such as ``__name__`` and
  ``__doc__`` from ``fun`` onto the partial object. The bound positional
  arguments are also recorded on the result as ``_bound_args``.
  """
  bound = functools.update_wrapper(functools.partial(fun, *args, **kwargs), fun)
  bound._bound_args = args
  return bound
class partialmethod(functools.partial):
  """A ``functools.partial`` that binds the instance when used as a method.

  When accessed through the class (``instance is None``), it returns itself
  unchanged. When accessed through an instance, it returns a new partial,
  built with this module's ``partial`` helper. That partial passes the
  instance as the first positional argument, followed by the stored
  positional and keyword arguments.
  """

  def __get__(self, instance, owner):
    if instance is not None:
      bound_args = self.args or ()
      bound_kwargs = self.keywords or {}
      return partial(self.func, instance, *bound_args, **bound_kwargs)
    return self
def curry(f):
  """Turn ``f`` into a function that partially applies ``f`` to its arguments.

  ``curry(f)(*args, **kwargs)`` returns ``partial(f, *args, **kwargs)`` (this
  module's metadata-preserving ``partial``) rather than calling ``f``.
  """
  apply_partially = partial(partial, f)
  return apply_partially
def toposort(end_node):
  """Return the nodes reachable from ``end_node`` in topological order.

  Nodes must be hashable and expose a ``parents`` sequence. In the result,
  every node comes after all of its parents, and ``end_node`` comes last.
  Duplicate parent edges are handled. Assumes the parent graph is acyclic;
  nodes on a cycle would never become ready and would be left out.
  """
  # Pass 1: walk upward from end_node and count how many times each node is
  # referenced as a parent (end_node is seeded with one reference). A node's
  # own parents are expanded only on its first visit.
  pending_children = {}
  to_visit = [end_node]
  while to_visit:
    current = to_visit.pop()
    if current in pending_children:
      pending_children[current] += 1
      continue
    pending_children[current] = 1
    to_visit.extend(current.parents)

  # Pass 2 (Kahn-style): emit a node, then release each parent once its last
  # remaining child has been emitted. This yields children before parents.
  order = []
  ready = [end_node]
  while ready:
    current = ready.pop()
    order.append(current)
    for parent in current.parents:
      if pending_children[parent] == 1:
        ready.append(parent)
      else:
        pending_children[parent] -= 1

  # Flip to parents-before-children.
  order.reverse()
  return order
def split_merge(predicate, xs):
  """Partition ``xs`` by ``predicate`` and return a function that re-interleaves.

  Args:
    predicate: called once per element; its truthiness decides the side.
    xs: the elements. They are iterated more than once, so pass a list or
      tuple rather than a one-shot iterator.

  Returns:
    ``(lhs, rhs, merge)``:
      * ``lhs``: the elements for which ``predicate`` was truthy, in order.
      * ``rhs``: the remaining elements, in order.
      * ``merge(new_lhs, new_rhs)``: builds a list of ``len(xs)`` elements,
        taking the next item from ``new_lhs`` at each position where
        ``predicate`` was truthy and from ``new_rhs`` elsewhere. It raises
        IndexError if either input is too short and AssertionError (when
        asserts are on) if either has items left over.
  """
  # Materialize: under Python 3 ``map`` is a one-shot iterator, so the first
  # comprehension below would exhaust it, leaving ``rhs`` empty and ``merge``
  # with nothing to iterate.
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]

  def merge(new_lhs, new_rhs):
    # Walk each input with an index instead of re-slicing (which was O(n^2)).
    i = j = 0
    out = []
    for s in sides:
      if s:
        out.append(new_lhs[i])
        i += 1
      else:
        out.append(new_rhs[j])
        j += 1
    # Every supplied replacement must have been consumed.
    assert j == len(new_rhs)
    assert i == len(new_lhs)
    return out

  return lhs, rhs, merge
def memoize(fun):
  """Cache ``fun``'s results, keyed on its positional and keyword arguments.

  The cache is an unbounded dict owned by the returned function, so every
  distinct argument combination, and its result, is kept alive for that
  function's lifetime. Keyword arguments are keyed independently of the
  order in which they were passed.

  If the argument key is unhashable, the dict lookup raises TypeError. That
  error is re-raised unless the module-level flag
  ``allow_memoize_hash_failures`` is true, in which case ``fun`` is called
  without caching.
  """
  cache = {}

  @functools.wraps(fun)
  def memoized_fun(*args, **kwargs):
    # Sorting makes f(a=1, b=2) and f(b=2, a=1) share an entry. Keyword names
    # are unique strings, so the sort never has to compare the values.
    key = (args, tuple(sorted(kwargs.items())))
    try:
      return cache[key]
    except KeyError:
      # Cache miss: compute below, outside the handler, so errors raised by
      # ``fun`` are not reported as "during handling of KeyError".
      pass
    except TypeError:
      if allow_memoize_hash_failures:
        return fun(*args, **kwargs)
      raise
    ans = cache[key] = fun(*args, **kwargs)
    return ans

  return memoized_fun