rocm_jax/jax/_src/util.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from functools import partial
import itertools as it
from collections import namedtuple
import operator
import types
import threading
from typing import (Any, Callable, Dict, Iterable, List, Tuple, Generic,
TypeVar, Set, Iterator, Sequence)
import weakref
from absl import logging
import numpy as np
from jax.config import config
Seq = Sequence
T = TypeVar("T")
def safe_zip(*args):
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
return list(zip(*args))
def safe_map(f, *args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
return list(map(f, *args))
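
# Example (an illustrative doctest-style sketch, not part of the original
# module): `safe_zip` and `safe_map` mirror `zip` and `map` but assert that
# all input sequences have equal length.
#
#   >>> safe_zip([1, 2, 3], ['a', 'b', 'c'])
#   [(1, 'a'), (2, 'b'), (3, 'c')]
#   >>> safe_map(lambda x, y: x + y, [1, 2], [10, 20])
#   [11, 22]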
def unzip2(xys):
"""Unzip sequence of length-2 tuples into two tuples."""
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-2 output.
xs = []
ys = []
for x, y in xys:
xs.append(x)
ys.append(y)
return tuple(xs), tuple(ys)
def unzip3(xyzs):
"""Unzip sequence of length-3 tuples into three tuples."""
# Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
# is too permissive about inputs, and does not guarantee a length-3 output.
xs = []
ys = []
zs = []
for x, y, z in xyzs:
xs.append(x)
ys.append(y)
zs.append(z)
return tuple(xs), tuple(ys), tuple(zs)
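
# Example (illustrative sketch): unlike `zip(*...)`, the unzips are eager,
# validate tuple arity, and always return the stated number of tuples.
#
#   >>> unzip2([(1, 'a'), (2, 'b')])
#   ((1, 2), ('a', 'b'))
#   >>> unzip3([(1, 'a', True), (2, 'b', False)])
#   ((1, 2), ('a', 'b'), (True, False))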
def subvals(lst, replace):
lst = list(lst)
for i, v in replace:
lst[i] = v
return tuple(lst)
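
# Example (illustrative sketch): `subvals` substitutes values at the given
# indices and returns a tuple.
#
#   >>> subvals((1, 2, 3), [(0, 'x'), (2, 'z')])
#   ('x', 2, 'z')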
def split_list(args: Sequence[T], ns: Sequence[int]) -> List[List[T]]:
args = list(args)
lists = []
for n in ns:
lists.append(args[:n])
args = args[n:]
lists.append(args)
return lists
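
# Example (illustrative sketch): note the trailing remainder list, which may
# be empty.
#
#   >>> split_list([1, 2, 3, 4, 5], [2, 1])
#   [[1, 2], [3], [4, 5]]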
def partition_list(bs: Sequence[bool], l: Sequence[T]) -> Tuple[List[T], List[T]]:
assert len(bs) == len(l)
lists = [], [] # type: ignore
for b, x in zip(bs, l):
lists[b].append(x)
return lists
def merge_lists(bs: Sequence[bool], l0: Sequence[T], l1: Sequence[T]) -> List[T]:
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
i0, i1 = iter(l0), iter(l1)
out = [next(i1) if b else next(i0) for b in bs]
sentinel = object()
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
return out
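
# Example (illustrative sketch): `merge_lists` is the inverse of
# `partition_list` for a fixed boolean mask.
#
#   >>> partition_list([True, False, True], ['a', 'b', 'c'])
#   (['b'], ['a', 'c'])
#   >>> merge_lists([True, False, True], ['b'], ['a', 'c'])
#   ['a', 'b', 'c']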
def split_dict(dct, names):
dct = dict(dct)
lst = [dct.pop(name) for name in names]
assert not dct
return lst
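
# Example (illustrative sketch): `split_dict` asserts that `names` covers
# every key, and returns the values in the order requested.
#
#   >>> split_dict(dict(a=1, b=2), ['b', 'a'])
#   [2, 1]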
def concatenate(xs: Iterable[Sequence[T]]) -> List[T]:
"""Concatenates/flattens a list of lists."""
return list(it.chain.from_iterable(xs))
flatten = concatenate
_unflatten_done = object()
def unflatten(xs: Iterable[T], ns: Sequence[int]) -> List[List[T]]:
"""Splits `xs` into subsequences of lengths `ns`.
Unlike `split_list`, the `sum(ns)` must be equal to `len(xs)`."""
xs_iter = iter(xs)
unflattened = [[next(xs_iter) for _ in range(n)] for n in ns]
assert next(xs_iter, _unflatten_done) is _unflatten_done
return unflattened
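
# Example (illustrative sketch): compare with `split_list`, which permits a
# remainder.
#
#   >>> unflatten([1, 2, 3, 4, 5, 6], [2, 1, 3])
#   [[1, 2], [3], [4, 5, 6]]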
def curry(f):
"""Curries arguments of f, returning a function on any remaining arguments.
For example:
>>> f = lambda x, y, z, w: x * y + z * w
>>> f(2,3,4,5)
26
>>> curry(f)(2)(3, 4, 5)
26
>>> curry(f)(2, 3)(4, 5)
26
>>> curry(f)(2, 3, 4, 5)()
26
"""
return partial(partial, f)
def toposort(end_nodes):
if not end_nodes: return []
end_nodes = _remove_duplicates(end_nodes)
child_counts = {}
stack = list(end_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(node.parents)
for node in end_nodes:
child_counts[id(node)] -= 1
sorted_nodes = []
childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
assert childless_nodes
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in node.parents:
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
check_toposort(sorted_nodes[::-1])
return sorted_nodes[::-1]
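
# Example (an illustrative sketch; `Node` here is a hypothetical stand-in for
# any object with a `.parents` attribute, such as a tracer). Parents always
# precede children in the returned order.
#
#   >>> class Node:
#   ...   def __init__(self, *parents): self.parents = list(parents)
#   >>> a = Node(); b = Node(a); c = Node(a, b)
#   >>> toposort([c]) == [a, b, c]
#   True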
def check_toposort(nodes):
visited = set()
for node in nodes:
assert all(id(parent) in visited for parent in node.parents)
visited.add(id(node))
def _remove_duplicates(node_list):
seen = set()
out = []
for n in node_list:
if id(n) not in seen:
seen.add(id(n))
out.append(n)
return out
def split_merge(predicate, xs):
sides = list(map(predicate, xs))
lhs = [x for x, s in zip(xs, sides) if s]
rhs = [x for x, s in zip(xs, sides) if not s]
def merge(new_lhs, new_rhs):
out = []
for s in sides:
if s:
out.append(new_lhs[0])
new_lhs = new_lhs[1:]
else:
out.append(new_rhs[0])
new_rhs = new_rhs[1:]
assert not new_rhs
assert not new_lhs
return out
return lhs, rhs, merge
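
# Example (illustrative sketch): `merge` reassembles replacement values in
# the original interleaved order.
#
#   >>> lhs, rhs, merge = split_merge(lambda x: x % 2 == 0, [1, 2, 3, 4])
#   >>> (lhs, rhs)
#   ([2, 4], [1, 3])
#   >>> merge(['a', 'b'], ['c', 'd'])
#   ['c', 'a', 'd', 'b']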
def cache(max_size=4096):
def wrap(f):
@functools.lru_cache(max_size)
def cached(_, *args, **kwargs):
return f(*args, **kwargs)
@functools.wraps(f)
def wrapper(*args, **kwargs):
if config.jax_check_tracer_leaks:
return f(*args, **kwargs)
else:
return cached(config._trace_context(), *args, **kwargs)
wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
return wrapper
return wrap
memoize = cache(max_size=None)
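
# Example (illustrative sketch): `cache` behaves like `functools.lru_cache`,
# except that entries are additionally keyed on the current trace context,
# and caching is bypassed entirely when `jax_check_tracer_leaks` is enabled.
#
#   >>> @cache(max_size=16)
#   ... def square(x):
#   ...   return x * x
#   >>> square(3)
#   9
#   >>> square(3)  # served from the cache; see square.cache_info()
#   9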
CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
def weakref_lru_cache(call: Callable, maxsize=2048):
"""
Least recently used cache decorator with weakref support.
The cache will take a weakref to the first argument of the wrapped function
and strong refs to all subsequent operations. In all other respects it should
behave similar to `functools.lru_cache`.
"""
cache: Dict[Any, Any] = {}
hits = misses = 0
lock = threading.Lock()
def remove_key(tctx, args, kwargs, weak_arg):
k = (weak_arg, tctx, args, kwargs)
try:
# This has a chance to race with the iteration in next(iter(cache)),
# but we cannot lock because GC can get triggered synchronously inside
# a critical section and will not relinquish control until the callback
# has finished. This would lead to a deadlock between this weakref
# cleanup function and any function below which locks.
del cache[k]
except KeyError:
pass
def wrapped(weak_arg, *args, **kwargs):
nonlocal hits, misses
if config.jax_check_tracer_leaks:
return call(weak_arg, *args, **kwargs)
kwargs_key = tuple(kwargs.items())
tctx = config._trace_context()
k = (weakref.ref(weak_arg,
functools.partial(remove_key, tctx, args, kwargs_key)),
tctx, args, kwargs_key)
with lock:
if k in cache:
hits += 1
result = cache[k]
# del and reinsert to bump key in the insertion order.
del cache[k]
cache[k] = result
return result
misses += 1
result = call(weak_arg, *args, **kwargs)
with lock:
cache[k] = result
      num_errors = 0
      while len(cache) > maxsize:
        try:
          del_k = next(iter(cache))
        except RuntimeError:
          # This happens if a weakref callback fires between `iter` and
          # `next`. Just ignore the error. WeakKeyDictionary handles this
          # by deferring the deletes, but that has a chance of leaking,
          # and this solution is easier.
          num_errors += 1
          if num_errors > len(cache):
            # This must be some other problem.
            raise
          else:
            continue
        del cache[del_k]
return result
def cache_info():
with lock:
return CacheInfo(hits, misses, maxsize, len(cache))
def cache_clear():
nonlocal hits, misses
with lock:
hits = misses = 0
cache.clear()
wrapped.cache_info = cache_info
wrapped.cache_clear = cache_clear
return wrapped
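
# Example (an illustrative sketch; `Key` is a hypothetical class, used
# because the first argument must be weak-referenceable, which rules out
# e.g. ints and strings).
#
#   >>> class Key: pass
#   >>> double = weakref_lru_cache(lambda key, x: x * 2)
#   >>> k = Key()
#   >>> double(k, 3)  # miss: computes and caches
#   6
#   >>> double(k, 3)  # hit
#   6
#   >>> del k  # the cache entry is dropped once `k` is garbage collected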
def prod(xs):
out = 1
for x in xs:
out *= x
return out
class Unhashable:
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __eq__(self, other):
return self.val == other.val
class Hashable:
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __hash__(self):
return hash(self.val)
def __eq__(self, other):
return self.val == other.val
class WrapKwArgs:
__slots__ = ["val"]
def __init__(self, val):
self.val = val
def __hash__(self):
return hash(tuple((k, v) for k, v in sorted(self.val.items())))
def __eq__(self, other):
return self.val == other.val
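
# Example (illustrative sketch): `Hashable` lets a value stand in as a dict
# or cache key, while `WrapKwArgs` hashes a kwargs dict order-insensitively.
#
#   >>> {Hashable((1, 2)): 'x'}[Hashable((1, 2))]
#   'x'
#   >>> hash(WrapKwArgs(dict(b=2, a=1))) == hash(WrapKwArgs(dict(a=1, b=2)))
#   True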
def get_module_functions(module):
"""Finds functions in module.
Args:
module: A Python module.
Returns:
module_fns: A dict of names mapped to functions, builtins or ufuncs in `module`.
"""
module_fns = {}
for key in dir(module):
    # Omit the module-level __getattr__ and __dir__, which were added in
    # Python 3.7: https://www.python.org/dev/peps/pep-0562/
if key in ('__getattr__', '__dir__'):
continue
attr = getattr(module, key)
if isinstance(
attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
module_fns[key] = attr
return module_fns
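
# Example (illustrative sketch, using numpy, which this module already
# imports as `np`):
#
#   >>> fns = get_module_functions(np)
#   >>> 'sin' in fns and isinstance(fns['sin'], np.ufunc)
#   True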
def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def new_name_stack(name: str = ''):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
name_stack = source_info_util.NameStack()
if name:
name_stack = name_stack.extend(name)
return name_stack
return name + '/'
def extend_name_stack(stack, name: str):
if config.jax_experimental_name_stack:
from jax._src import source_info_util
assert isinstance(stack, source_info_util.NameStack), stack
return stack.extend(name)
assert isinstance(stack, str)
return stack + name + '/'
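
# Example (an illustrative sketch of the legacy string-based path, i.e. with
# `jax_experimental_name_stack` disabled; with the flag enabled a
# `source_info_util.NameStack` is used instead):
#
#   >>> wrap_name('f', 'vmap')
#   'vmap(f)'
#   >>> extend_name_stack(new_name_stack('jit'), 'relu')
#   'jit/relu/'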
def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
if axis < 0:
axis = axis + num_dims
return axis
def moveaxis(x, src, dst):
if src == dst:
return x
if isinstance(src, int):
src = (src,)
if isinstance(dst, int):
dst = (dst,)
src = [canonicalize_axis(a, x.ndim) for a in src]
dst = [canonicalize_axis(a, x.ndim) for a in dst]
perm = [i for i in range(np.ndim(x)) if i not in src]
for d, s in sorted(zip(dst, src)):
perm.insert(d, s)
return x.transpose(perm)
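
# Example (illustrative sketch): negative axes are canonicalized first, so
# `moveaxis` accepts the same axis conventions as `np.moveaxis`.
#
#   >>> canonicalize_axis(-1, 3)
#   2
#   >>> moveaxis(np.zeros((2, 3, 4)), 0, -1).shape
#   (3, 4, 2)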
def ceil_of_ratio(x, y):
return -(-x // y)
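
# Example (illustrative sketch): ceiling division via negated floor division.
#
#   >>> ceil_of_ratio(7, 2)
#   4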
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
"""
Like functools.wraps, but with finer-grained control over the name and docstring
of the resulting function.
"""
try:
name = getattr(wrapped, "__name__", "<unnamed function>")
doc = getattr(wrapped, "__doc__", "") or ""
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
fun.__name__ = namestr.format(fun=name)
fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
fun.__doc__ = docstr.format(fun=name, doc=doc, **kwargs)
fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
fun.__wrapped__ = wrapped
finally:
return fun
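
# Example (an illustrative sketch; because `wraps` is curried it can be used
# as a decorator factory, and `my_sin` is a hypothetical function):
#
#   >>> @wraps(np.sin, docstr="JAX version of {fun}.\n{doc}")
#   ... def my_sin(x):
#   ...   return x  # placeholder body
#   >>> my_sin.__name__
#   'sin'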
# NOTE: Ideally we would annotate both the argument and return type as NoReturn
# but it seems like pytype doesn't support that...
def assert_unreachable(x):
raise AssertionError(f"Unhandled case: {type(x).__name__}")
def tuple_insert(t, idx, val):
assert 0 <= idx <= len(t), (idx, len(t))
return t[:idx] + (val,) + t[idx:]
def tuple_delete(t, idx):
assert 0 <= idx < len(t), (idx, len(t))
return t[:idx] + t[idx + 1:]
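
# Example (illustrative sketch):
#
#   >>> tuple_insert((1, 2, 3), 1, 'x')
#   (1, 'x', 2, 3)
#   >>> tuple_delete((1, 2, 3), 1)
#   (1, 3)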
# TODO(mattjj): replace with dataclass when Python 2 support is removed
def taggedtuple(name, fields) -> Callable[..., Any]:
"""Lightweight version of namedtuple where equality depends on the type."""
def __new__(cls, *xs):
return tuple.__new__(cls, (cls,) + xs)
def __repr__(self):
return '{}{}'.format(name, tuple.__str__(self[1:]))
class_namespace = {'__new__' : __new__, '__repr__': __repr__}
for i, f in enumerate(fields):
class_namespace[f] = property(operator.itemgetter(i+1)) # type: ignore
return type(name, (tuple,), class_namespace)
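
# Example (illustrative sketch): unlike `namedtuple`, equality depends on
# the type, because the class itself is stored as the first tuple element.
#
#   >>> Point = taggedtuple('Point', ('x', 'y'))
#   >>> Pair = taggedtuple('Pair', ('x', 'y'))
#   >>> p = Point(1, 2)
#   >>> (p.x, p.y)
#   (1, 2)
#   >>> Point(1, 2) == Pair(1, 2)
#   False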
class HashableFunction:
"""Decouples function equality and hash from its identity.
Local lambdas and functiond defs are reallocated on each function call, making
the functions created on different calls compare as unequal. This breaks our
caching logic, which should really only care about comparing the semantics and
not actual identity.
This class makes it possible to compare different functions based on their
semantics. The parts that are taken into account are: the bytecode of
the wrapped function (which is cached by the CPython interpreter and is stable
across the invocations of the surrounding function), and `closure` which should
contain all values in scope that affect the function semantics. In particular
`closure` should contain all elements of the function closure, or it should be
possible to derive the relevant elements of the true function closure based
solely on the contents of the `closure` argument (e.g. in case some closed-over
values are not hashable, but are entirely determined by hashable locals).
"""
def __init__(self, f, closure):
self.f = f
self.closure = closure
def __eq__(self, other):
return (type(other) is HashableFunction and
self.f.__code__ == other.f.__code__ and
self.closure == other.closure)
def __hash__(self):
return hash((self.f.__code__, self.closure))
def __call__(self, *args, **kwargs):
return self.f(*args, **kwargs)
def __repr__(self):
return f'<hashable {self.f.__name__} with closure={self.closure}>'
def as_hashable_function(closure):
return lambda f: HashableFunction(f, closure)
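
# Example (an illustrative sketch; `make` is a hypothetical factory): the
# lambda's code object is shared across calls of `make`, so equality reduces
# to comparing the `closure` values.
#
#   >>> def make(n):
#   ...   return HashableFunction(lambda x: x + n, closure=n)
#   >>> make(1) == make(1)
#   True
#   >>> make(1) == make(2)
#   False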
def maybe_named_axis(axis, if_pos, if_named):
try:
pos = operator.index(axis)
named = False
except TypeError:
named = True
return if_named(axis) if named else if_pos(pos)
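
# Example (illustrative sketch): integer-like axes go to `if_pos`; anything
# else (e.g. a named axis) goes to `if_named`.
#
#   >>> maybe_named_axis(0, lambda p: ('pos', p), lambda n: ('named', n))
#   ('pos', 0)
#   >>> maybe_named_axis('i', lambda p: ('pos', p), lambda n: ('named', n))
#   ('named', 'i')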
def distributed_debug_log(*pairs):
"""Format and log `pairs` if config.jax_distributed_debug is enabled.
Args:
pairs: A sequence of label/value pairs to log. The first pair is treated as
a heading for subsequent pairs.
"""
if config.jax_distributed_debug:
lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
try:
lines.append(f"{pairs[0][0]}: {pairs[0][1]}")
for label, value in pairs[1:]:
lines.append(f" {label}: {value}")
except Exception as e:
lines.append("DISTRIBUTED_DEBUG logging failed!")
lines.append(f"{e}")
lines.append("DISTRIBUTED_DEBUG_END")
logging.warning("\n".join(lines))
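
# Example (illustrative sketch): the first pair acts as a heading, and output
# is only emitted when `jax_distributed_debug` is enabled.
#
#   >>> distributed_debug_log(("DEVICE ASSIGNMENT", "pmap"), ("devices", 8))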
class OrderedSet(Generic[T]):
elts_set: Set[T]
elts_list: List[T]
def __init__(self):
self.elts_set = set()
self.elts_list = []
def add(self, elt: T) -> None:
if elt not in self.elts_set:
self.elts_set.add(elt)
self.elts_list.append(elt)
def update(self, elts: Seq[T]) -> None:
for e in elts:
self.add(e)
def __iter__(self) -> Iterator[T]:
return iter(self.elts_list)
def __len__(self) -> int:
return len(self.elts_list)
def __contains__(self, elt: T) -> bool:
return elt in self.elts_set
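
# Example (illustrative sketch): insertion order is preserved and duplicates
# are dropped.
#
#   >>> s = OrderedSet()
#   >>> s.update([3, 1, 3, 2, 1])
#   >>> list(s)
#   [3, 1, 2]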