# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Small general-purpose helpers: length-checked zip/map, unzipping,
partial application, topological sorting and memoization."""

from __future__ import absolute_import

import functools
import itertools as it
import weakref

# When True, memoize() falls back to calling the wrapped function uncached if
# its arguments are unhashable, instead of re-raising the TypeError.
allow_memoize_hash_failures = False


def safe_zip(*args):
    """Zip sequences together, asserting that they all have the same length.

    Args:
      *args: sized sequences (each must support ``len``); at least one.

    Returns:
      A list of tuples. A list rather than a lazy iterator, so the result is
      the same under Python 2 and Python 3.

    Raises:
      AssertionError: if the lengths differ (message lists every length).
    """
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
    return list(zip(*args))


def safe_map(f, *args):
    """Apply ``f`` elementwise across iterables, asserting equal lengths.

    Each argument is first materialized into a list, so plain iterators are
    accepted.

    Args:
      f: callable taking one argument per iterable.
      *args: iterables of equal length; at least one.

    Returns:
      A list of ``f`` results.

    Raises:
      AssertionError: if the lengths differ (message lists every length).
    """
    # list(...) is required on Python 3, where map() returns an iterator that
    # cannot be indexed or measured.
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
    return list(map(f, *args))


def unzip2(xys):
    """Split an iterable of pairs into a tuple of firsts and a tuple of seconds."""
    xs = []
    ys = []
    for x, y in xys:
        xs.append(x)
        ys.append(y)
    return tuple(xs), tuple(ys)


def unzip3(xyzs):
    """Split an iterable of triples into three tuples, one per position."""
    xs = []
    ys = []
    zs = []
    for x, y, z in xyzs:
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return tuple(xs), tuple(ys), tuple(zs)


def concatenate(xs):
    """Flatten one level of nesting: an iterable of iterables into a list."""
    return list(it.chain.from_iterable(xs))


def partial(fun, *args, **kwargs):
    """``functools.partial`` that also copies ``fun``'s metadata.

    The returned object gets ``fun``'s ``__name__``/``__doc__``/etc. via
    ``functools.update_wrapper``. It also records the bound positional
    arguments in ``_bound_args``.
    """
    wrapped = functools.partial(fun, *args, **kwargs)
    functools.update_wrapper(wrapped, fun)
    wrapped._bound_args = args
    return wrapped


class partialmethod(functools.partial):
    """A ``functools.partial`` that binds like a method.

    Accessed on the class it returns itself. Accessed on an instance it
    returns a ``partial`` with the instance inserted as the first positional
    argument, ahead of the stored positional arguments.
    """

    def __get__(self, instance, owner):
        if instance is None:
            return self
        else:
            return partial(self.func, instance,
                           *(self.args or ()), **(self.keywords or {}))


def curry(f):
    """Return a function such that ``curry(f)(*a, **k)`` is ``partial(f, *a, **k)``."""
    return partial(partial, f)


def toposort(end_node):
    """Return ``end_node`` and all of its ancestors in topological order.

    Nodes must be hashable and expose a ``parents`` iterable. In the result
    every node appears after all of its parents, and ``end_node`` is last.
    """
    # First pass: count, for each reachable node, how many times it is
    # reached (once per child edge, plus once for end_node itself).
    child_counts = {}
    stack = [end_node]
    while stack:
        node = stack.pop()
        if node in child_counts:
            child_counts[node] += 1
        else:
            child_counts[node] = 1
            stack.extend(node.parents)

    # Second pass (Kahn's algorithm from the sink): a parent becomes ready
    # once its last remaining child has been emitted.
    sorted_nodes = []
    childless_nodes = [end_node]
    while childless_nodes:
        node = childless_nodes.pop()
        sorted_nodes.append(node)
        for parent in node.parents:
            if child_counts[parent] == 1:
                childless_nodes.append(parent)
            else:
                child_counts[parent] -= 1

    # Built children-first; reverse so parents come first.
    return sorted_nodes[::-1]


def split_merge(predicate, xs):
    """Partition ``xs`` by ``predicate`` and return a function undoing the split.

    Returns:
      ``(lhs, rhs, merge)``:
        - ``lhs`` lists the elements for which ``predicate`` is truthy.
        - ``rhs`` lists the rest, each in original order.
        - ``merge(new_lhs, new_rhs)`` interleaves two replacement sequences
          back into the original positions. Each replacement must have the
          same length as its counterpart (too short raises ``IndexError``,
          too long raises ``AssertionError``).
    """
    # Materialize both: xs is traversed several times, and on Python 3 a lazy
    # map() would be exhausted by the first comprehension, leaving rhs and
    # merge() with nothing.
    xs = list(xs)
    sides = list(map(predicate, xs))
    lhs = [x for x, s in zip(xs, sides) if s]
    rhs = [x for x, s in zip(xs, sides) if not s]

    def merge(new_lhs, new_rhs):
        out = []
        i = j = 0
        for s in sides:
            if s:
                out.append(new_lhs[i])
                i += 1
            else:
                out.append(new_rhs[j])
                j += 1
        # Every replacement element must have been consumed.
        assert j == len(new_rhs)
        assert i == len(new_lhs)
        return out

    return lhs, rhs, merge


def memoize(fun):
    """Cache ``fun``'s results keyed on its positional and keyword arguments.

    The cache is an unbounded dict held by the returned wrapper.

    If the arguments are unhashable, the ``TypeError`` propagates. The
    exception is when the module-level ``allow_memoize_hash_failures`` flag
    is set; then ``fun`` is called without caching.
    """
    cache = {}

    @functools.wraps(fun)
    def memoized_fun(*args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        try:
            return cache[key]
        except KeyError:
            pass
        except TypeError:
            if allow_memoize_hash_failures:
                return fun(*args, **kwargs)
            raise
        # Called outside the except clause so exceptions raised by fun are
        # not chained onto the cache-miss KeyError.
        ans = cache[key] = fun(*args, **kwargs)
        return ans

    return memoized_fun