# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools as it
import types

import numpy as np


def safe_zip(*args):
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(zip(*args))


def safe_map(f, *args):
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(map(f, *args))
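
# Illustrative doctest-style examples (added, not from the original file):
# both helpers mirror the builtins but assert that all arguments have equal
# length instead of silently truncating.
#
#   >>> safe_zip([1, 2], ['a', 'b'])
#   [(1, 'a'), (2, 'b')]
#   >>> safe_map(lambda x, y: x + y, [1, 2], [10, 20])
#   [11, 22]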


def unzip2(xys):
  xs = []
  ys = []
  for x, y in xys:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)


def unzip3(xyzs):
  xs = []
  ys = []
  zs = []
  for x, y, z in xyzs:
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(xs), tuple(ys), tuple(zs)


def unzip4(wxyzs):
  ws = []
  xs = []
  ys = []
  zs = []
  for w, x, y, z in wxyzs:
    ws.append(w)
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(ws), tuple(xs), tuple(ys), tuple(zs)
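
# A quick sketch of the unzip helpers (added for illustration):
#
#   >>> unzip2([(1, 'a'), (2, 'b')])
#   ((1, 2), ('a', 'b'))
#
# unzip3 and unzip4 behave the same way for 3- and 4-tuples.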


def subvals(lst, replace):
  lst = list(lst)
  for i, v in replace:
    lst[i] = v
  return tuple(lst)
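
# For illustration (added): subvals returns a tuple with the selected
# positions replaced.
#
#   >>> subvals((1, 2, 3), [(0, 'a'), (2, 'c')])
#   ('a', 2, 'c')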


def split_list(args, ns):
  assert type(ns) is list
  args = list(args)
  lists = []
  for n in ns:
    lists.append(args[:n])
    args = args[n:]
  lists.append(args)
  return lists
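
# Illustrative example (added): the result has len(ns) + 1 pieces, with any
# remainder in the final piece.
#
#   >>> split_list([1, 2, 3, 4, 5], [2, 1])
#   [[1, 2], [3], [4, 5]]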


def split_dict(dct, names):
  dct = dict(dct)
  lst = [dct.pop(name) for name in names]
  assert not dct
  return lst
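
# Illustrative example (added): names must cover every key, in the desired
# output order.
#
#   >>> split_dict({'a': 1, 'b': 2}, ['b', 'a'])
#   [2, 1]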


def concatenate(xs):
  return list(it.chain.from_iterable(xs))


def partial(fun, *args, **kwargs):
  wrapped = functools.partial(fun, *args, **kwargs)
  functools.update_wrapper(wrapped, fun)
  wrapped._bound_args = args
  return wrapped
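
# Unlike a bare functools.partial, this preserves the wrapped function's
# metadata via update_wrapper. A small illustration (added, not from the
# original):
#
#   >>> def add(x, y): return x + y
#   >>> add2 = partial(add, 2)
#   >>> add2.__name__, add2(3)
#   ('add', 5)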


class partialmethod(functools.partial):
  def __get__(self, instance, owner):
    if instance is None:
      return self
    else:
      return partial(self.func, instance,
                     *(self.args or ()), **(self.keywords or {}))
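
# Usage sketch (added): like functools.partialmethod, but built on the
# partial defined above. The class A here is hypothetical, for illustration.
#
#   >>> class A:
#   ...   def f(self, x, y): return (x, y)
#   ...   g = partialmethod(f, 1)
#   >>> A().g(2)
#   (1, 2)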


def curry(f):
  """Curries arguments of f, returning a function on any remaining arguments.

  For example:
  >>> f = lambda x, y, z, w: x * y + z * w
  >>> f(2,3,4,5)
  26
  >>> curry(f)(2)(3, 4, 5)
  26
  >>> curry(f)(2, 3)(4, 5)
  26
  >>> curry(f)(2, 3, 4, 5)()
  26
  """
  return partial(partial, f)


def toposort(end_nodes):
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)

  # Count, for each node reachable from end_nodes, how many children it has.
  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  for node in end_nodes:
    child_counts[id(node)] -= 1

  # Kahn-style sort: repeatedly emit nodes all of whose children are emitted.
  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  check_toposort(sorted_nodes[::-1])
  return sorted_nodes[::-1]
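
# Illustrative sketch (added): nodes need only a .parents attribute; the Node
# class below is hypothetical, purely for demonstration. Parents come before
# children in the returned order.
#
#   >>> class Node:
#   ...   def __init__(self, *parents): self.parents = list(parents)
#   >>> a = Node(); b = Node(a); c = Node(a, b)
#   >>> order = toposort([c])
#   >>> [order.index(a), order.index(b), order.index(c)]
#   [0, 1, 2]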


def check_toposort(nodes):
  visited = set()
  for node in nodes:
    assert all(id(parent) in visited for parent in node.parents)
    visited.add(id(node))


def _remove_duplicates(node_list):
  seen = set()
  out = []
  for n in node_list:
    if id(n) not in seen:
      seen.add(id(n))
      out.append(n)
  return out


def split_merge(predicate, xs):
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]
  def merge(new_lhs, new_rhs):
    out = []
    for s in sides:
      if s:
        out.append(new_lhs[0])
        new_lhs = new_lhs[1:]
      else:
        out.append(new_rhs[0])
        new_rhs = new_rhs[1:]
    assert not new_rhs
    assert not new_lhs
    return out

  return lhs, rhs, merge
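
# Illustration (added): merge reassembles replacement values in the original
# order.
#
#   >>> lhs, rhs, merge = split_merge(lambda x: x % 2 == 0, [1, 2, 3, 4])
#   >>> (lhs, rhs)
#   ([2, 4], [1, 3])
#   >>> merge(['a', 'b'], ['c', 'd'])
#   ['c', 'a', 'd', 'b']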


def cache(max_size=4096):
  return functools.lru_cache(maxsize=max_size)


memoize = functools.lru_cache(maxsize=None)
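
# Usage sketch (added): cache(...) is an LRU decorator factory; memoize is an
# unbounded cache.
#
#   >>> @memoize
#   ... def fib(n): return n if n < 2 else fib(n - 1) + fib(n - 2)
#   >>> fib(30)
#   832040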


def prod(xs):
  out = 1
  for x in xs:
    out *= x
  return out


class WrapHashably(object):
  # Hashes and compares the wrapped value by object identity.
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return id(self.val)

  def __eq__(self, other):
    return self.val is other.val


class Hashable(object):
  # Hashes and compares the wrapped value by value.
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    return self.val == other.val
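
# The difference, sketched (added): WrapHashably uses identity, so it works
# even for unhashable values like arrays; Hashable requires a hashable value.
#
#   >>> a = np.ones(3)
#   >>> WrapHashably(a) == WrapHashably(a)
#   True
#   >>> WrapHashably(a) == WrapHashably(np.ones(3))
#   False
#   >>> Hashable((1, 2)) == Hashable((1, 2))
#   True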


def get_module_functions(module):
  """Finds functions in module.

  Args:
    module: A Python module.

  Returns:
    module_fns: A dict mapping names to functions, builtins, or ufuncs
      in `module`.
  """
  module_fns = {}
  for key in dir(module):
    # Omit module-level __getattr__ and __dir__, which were added in
    # Python 3.7: https://www.python.org/dev/peps/pep-0562/
    if key in ('__getattr__', '__dir__'):
      continue
    attr = getattr(module, key)
    if isinstance(
        attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
      module_fns[key] = attr
  return module_fns
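
# Illustration (added):
#
#   >>> import math
#   >>> fns = get_module_functions(math)
#   >>> 'sqrt' in fns and 'pi' not in fns
#   True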


def wrap_name(name, transform_name):
  return transform_name + '(' + name + ')'


def extend_name_stack(stack, name=''):
  return stack + name + '/'
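
# Illustration (added): these build human-readable names for transformed
# computations.
#
#   >>> wrap_name('f', 'jit')
#   'jit(f)'
#   >>> extend_name_stack(extend_name_stack('', 'outer'), 'inner')
#   'outer/inner/'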