# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
from collections.abc import Callable, Iterable, Iterator, Sequence
import functools
from functools import partial
import itertools as it
import logging
import operator
from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast)
import weakref

import numpy as np

from jax._src import config
from jax._src.lib import xla_client as xc
from jax._src.lib import utils as jaxlib_utils

logger = logging.getLogger(__name__)

Seq = Sequence

T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")


if TYPE_CHECKING:
  # safe_zip cannot yet be fully annotated, so we use a strategy similar
  # to that used for builtins.zip in python/typeshed. This supports
  # return types matching input types for up to three arguments.
  @overload
  def safe_zip(__arg1: Iterable[T1]) -> list[tuple[T1]]: ...
  @overload
  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[tuple[T1, T2]]: ...
  @overload
  def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[tuple[T1, T2, T3]]: ...
  @overload
  def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ...

  def safe_zip(*args):
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
      assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
    return list(zip(*args))

else:
  safe_zip = jaxlib_utils.safe_zip
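
# A minimal illustration (hypothetical values) of why the "safe" variant
# exists: the builtin silently truncates, while safe_zip checks lengths first.
#
#   zip([1, 2, 3], ['a', 'b'])       # yields 2 pairs, silently dropping the 3
#   safe_zip([1, 2], ['a', 'b'])     # [(1, 'a'), (2, 'b')]
#   safe_zip([1, 2, 3], ['a', 'b'])  # fails: length mismatch: [3, 2]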


if TYPE_CHECKING:
  # safe_map cannot yet be fully annotated, so we use a strategy similar
  # to that used for builtins.map in python/typeshed. This supports
  # checking input types for the callable with up to three arguments.
  @overload
  def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> list[T]: ...
  @overload
  def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> list[T]: ...
  @overload
  def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> list[T]: ...
  @overload
  def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[T]: ...

  def safe_map(f, *args):
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
      assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
    return list(map(f, *args))

else:
  safe_map = jaxlib_utils.safe_map
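
# Likewise for safe_map (illustrative, hypothetical values):
#
#   safe_map(str, [1, 2, 3])      # ['1', '2', '3']
#   safe_map(pow, [2, 3], [3])    # fails: length mismatch: [2, 1]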


def unzip2(xys: Iterable[tuple[T1, T2]]
    ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
  """Unzip sequence of length-2 tuples into two tuples."""
  # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
  # is too permissive about inputs, and does not guarantee a length-2 output.
  xs: list[T1] = []
  ys: list[T2] = []
  for x, y in xys:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)


def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
    ) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
  """Unzip sequence of length-3 tuples into three tuples."""
  # Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
  # is too permissive about inputs, and does not guarantee a length-3 output.
  xs: list[T1] = []
  ys: list[T2] = []
  zs: list[T3] = []
  for x, y, z in xyzs:
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(xs), tuple(ys), tuple(zs)
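
# Illustrative behavior (hypothetical values):
#
#   unzip2([(1, 'a'), (2, 'b')])  # ((1, 2), ('a', 'b'))
#   unzip3([(1, 'a', True)])      # ((1,), ('a',), (True,))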


def subvals(lst, replace):
  # Returns a tuple copy of `lst` with lst[i] replaced by v for each (i, v)
  # pair in `replace`.
  lst = list(lst)
  for i, v in replace:
    lst[i] = v
  return tuple(lst)


def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
  # Splits `args` into len(ns) + 1 lists; the last list holds the remainder.
  args = list(args)
  lists = []
  for n in ns:
    lists.append(args[:n])
    args = args[n:]
  lists.append(args)
  return lists
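
# Illustrative behavior (hypothetical values); note the trailing remainder
# list, which makes the result one element longer than `ns`:
#
#   split_list([1, 2, 3, 4, 5], [2, 1])  # [[1, 2], [3], [4, 5]]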


def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
  # Like split_list, but `ns` must cover `args` exactly; no remainder list.
  args = list(args)
  assert sum(ns) == len(args)
  lists = []
  for n in ns:
    lists.append(args[:n])
    args = args[n:]
  return lists


def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
  assert len(bs) == len(l)
  lists = [], []  # type: ignore
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lists


def merge_lists(bs: Sequence[bool],
                l0: Sequence[T1],
                l1: Sequence[T2]
                ) -> list[T1 | T2]:
  assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
  i0, i1 = iter(l0), iter(l1)
  out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
  sentinel = object()
  assert next(i0, sentinel) is next(i1, sentinel) is sentinel
  return out
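
# partition_list and merge_lists are inverses; an illustrative round trip
# (hypothetical values):
#
#   bs = [False, True, False]
#   l0, l1 = partition_list(bs, ['a', 'b', 'c'])  # (['a', 'c'], ['b'])
#   merge_lists(bs, l0, l1)                       # ['a', 'b', 'c']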


def subs_list(
    subs: Sequence[int | None], src: Sequence[T], base: Sequence[T],
) -> list[T]:
  # For each entry of `subs`, substitutes src[i] when the entry is an index i,
  # otherwise consumes the next element of `base`; `base` must be consumed
  # exactly.
  base_ = iter(base)
  out = [src[i] if i is not None else next(base_) for i in subs]
  sentinel = object()
  assert next(base_, sentinel) is sentinel
  return out
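
# Illustrative behavior (hypothetical values): each None consumes the next
# element of `base`, each index selects from `src`:
#
#   subs_list([None, 0, None], src=['x', 'y'], base=['a', 'b'])
#   # -> ['a', 'x', 'b']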


def subs_list2(
    subs1: Sequence[int | None], subs2: Sequence[int | None],
    src1: Sequence[T], src2: Sequence[T], base: Sequence[T],
) -> list[T]:
  assert len(subs1) == len(subs2)
  base_ = iter(base)
  out = [src1[f1] if f1 is not None else src2[f2] if f2 is not None else
         next(base_) for f1, f2 in zip(subs1, subs2)]
  sentinel = object()
  assert next(base_, sentinel) is sentinel
  return out


def split_dict(dct, names):
  dct = dict(dct)
  lst = [dct.pop(name) for name in names]
  assert not dct
  return lst


def concatenate(xs: Iterable[Sequence[T]]) -> list[T]:
  """Concatenates/flattens a list of lists."""
  return list(it.chain.from_iterable(xs))


flatten = concatenate


_unflatten_done = object()


def unflatten(xs: Iterable[T], ns: Sequence[int]) -> list[list[T]]:
  """Splits `xs` into subsequences of lengths `ns`.

  Unlike `split_list`, `sum(ns)` must be equal to `len(xs)`."""
  xs_iter = iter(xs)
  unflattened = [[next(xs_iter) for _ in range(n)] for n in ns]
  assert next(xs_iter, _unflatten_done) is _unflatten_done
  return unflattened
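
# Illustrative behavior (hypothetical values):
#
#   unflatten(range(5), [2, 3])  # [[0, 1], [2, 3, 4]]
#   unflatten(range(5), [2, 2])  # fails: one element left unconsumed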


def curry(f):
  """Curries arguments of f, returning a function on any remaining arguments.

  For example:
  >>> f = lambda x, y, z, w: x * y + z * w
  >>> f(2,3,4,5)
  26
  >>> curry(f)(2)(3, 4, 5)
  26
  >>> curry(f)(2, 3)(4, 5)
  26
  >>> curry(f)(2, 3, 4, 5)()
  26
  """
  return wraps(f)(partial(partial, f))


def toposort(end_nodes):
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)

  # First pass: count how many children each reachable node has, walking the
  # parent links depth-first from the end nodes.
  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  for node in end_nodes:
    child_counts[id(node)] -= 1

  # Second pass: repeatedly peel off childless nodes (Kahn's algorithm), then
  # reverse so that parents precede children.
  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1
  sorted_nodes = sorted_nodes[::-1]

  check_toposort(sorted_nodes)
  return sorted_nodes
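
# An illustrative sketch: toposort works on any objects exposing a `.parents`
# attribute. `Node` here is hypothetical, purely for demonstration:
#
#   class Node:
#     def __init__(self, *parents): self.parents = parents
#   a = Node(); b = Node(a); c = Node(a, b)
#   toposort([c])  # [a, b, c]: every node appears after its parents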


def check_toposort(nodes):
  visited = set()
  for node in nodes:
    assert all(id(parent) in visited for parent in node.parents)
    visited.add(id(node))


def _remove_duplicates(node_list):
  seen = set()
  out = []
  for n in node_list:
    if id(n) not in seen:
      seen.add(id(n))
      out.append(n)
  return out


def split_merge(predicate, xs):
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]
  def merge(new_lhs, new_rhs):
    out = []
    for s in sides:
      if s:
        out.append(new_lhs[0])
        new_lhs = new_lhs[1:]
      else:
        out.append(new_rhs[0])
        new_rhs = new_rhs[1:]
    assert not new_rhs
    assert not new_lhs
    return out

  return lhs, rhs, merge
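
# Illustrative behavior (hypothetical values):
#
#   evens, odds, merge = split_merge(lambda n: n % 2 == 0, [1, 2, 3, 4])
#   # evens == [2, 4], odds == [1, 3]
#   merge([20, 40], [10, 30])  # [10, 20, 30, 40]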


def _ignore(): return None


def cache(max_size=4096, trace_context_in_key=True):
  def wrap(f):
    @functools.lru_cache(max_size)
    def cached(_, *args, **kwargs):
      return f(*args, **kwargs)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      if config.check_tracer_leaks.value:
        return f(*args, **kwargs)
      return cached(config.trace_context() if trace_context_in_key else _ignore(),
                    *args, **kwargs)

    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    cache_clearing_funs.add(wrapper.cache_clear)
    return wrapper
  return wrap
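
# Illustrative usage (a sketch; `square` is hypothetical). The decorated
# function gains lru_cache-style introspection and participates in
# clear_all_caches():
#
#   @cache(max_size=128)
#   def square(x: int) -> int:
#     return x * x
#
#   square(3); square(3)   # the second call is a cache hit
#   square.cache_clear()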


cache_clearing_funs = weakref.WeakSet()  # type: ignore


def clear_all_caches():
  global cache_clearing_funs
  for clear in cache_clearing_funs:
    clear()


memoize = cache(max_size=None)


def weakref_lru_cache(call: Callable, maxsize=2048,
                      trace_context_in_key: bool = True):
  """
  Least recently used cache decorator with weakref support.

  The cache will take a weakref to the first argument of the wrapped function
  and strong refs to all subsequent arguments. In all other respects it should
  behave similarly to `functools.lru_cache`.
  """
  global _weakref_lru_caches
  cached_call = xc.weakref_lru_cache(
      config.trace_context if trace_context_in_key else _ignore,
      call, maxsize)
  _weakref_lru_caches.add(cached_call)
  return cached_call
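
# Illustrative usage (a sketch; `lower` and `some_obj` are hypothetical). The
# first argument is held weakly, so its cache entries disappear once it is
# garbage collected:
#
#   def lower(obj, *args): ...
#   cached_lower = weakref_lru_cache(lower)
#   cached_lower(some_obj, 1, 2)  # cached until some_obj is collected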


_weakref_lru_caches = weakref.WeakSet()  # type: ignore


def clear_all_weakref_lru_caches():
  for cached_call in _weakref_lru_caches:
    cached_call.cache_clear()


class Unhashable:
  # Defining __eq__ without __hash__ deliberately makes instances unhashable.
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __eq__(self, other):
    return self.val == other.val


class Hashable:
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    return self.val == other.val


class WrapKwArgs:
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(tuple((k, v) for k, v in sorted(self.val.items())))

  def __eq__(self, other):
    return self.val == other.val


def wrap_name(name, transform_name):
  return transform_name + '(' + name + ')'


def fun_name(fun: Callable):
  return getattr(fun, "__name__", "<unnamed function>")


def fun_qual_name(fun: Callable):
  return getattr(fun, "__qualname__",
                 getattr(fun, "__name__", "<unnamed function>"))


def canonicalize_axis(axis, num_dims) -> int:
  """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
  axis = operator.index(axis)
  if not -num_dims <= axis < num_dims:
    raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
  if axis < 0:
    axis = axis + num_dims
  return axis
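
# Illustrative behavior (hypothetical values):
#
#   canonicalize_axis(-1, 3)  # 2
#   canonicalize_axis(2, 3)   # 2
#   canonicalize_axis(3, 3)   # ValueError: axis 3 is out of bounds ...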


def moveaxis(x, src, dst):
  if src == dst:
    return x
  if isinstance(src, int):
    src = (src,)
  if isinstance(dst, int):
    dst = (dst,)
  src = [canonicalize_axis(a, x.ndim) for a in src]
  dst = [canonicalize_axis(a, x.ndim) for a in dst]
  perm = [i for i in range(np.ndim(x)) if i not in src]
  for d, s in sorted(zip(dst, src)):
    perm.insert(d, s)
  return x.transpose(perm)
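
# Illustrative behavior (hypothetical values):
#
#   x = np.zeros((2, 3, 4))
#   moveaxis(x, 0, -1).shape           # (3, 4, 2)
#   moveaxis(x, (0, 1), (1, 0)).shape  # (3, 2, 4)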


def ceil_of_ratio(x, y):
  # Ceiling division, via floor division of the negated numerator.
  return -(-x // y)


def wraps(
    wrapped: Callable,
    namestr: str | None = None,
    docstr: str | None = None,
    **kwargs,
) -> Callable[[T], T]:
  """
  Like functools.wraps, but with finer-grained control over the name and
  docstring of the resulting function.
  """
  def wrapper(fun: T) -> T:
    try:
      name = fun_name(wrapped)
      doc = getattr(wrapped, "__doc__", "") or ""
      fun.__dict__.update(getattr(wrapped, "__dict__", {}))
      fun.__annotations__ = getattr(wrapped, "__annotations__", {})
      fun.__name__ = name if namestr is None else namestr.format(fun=name)
      fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
      fun.__doc__ = (doc if docstr is None
                     else docstr.format(fun=name, doc=doc, **kwargs))
      fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
      fun.__wrapped__ = wrapped
    finally:
      return fun
  return wrapper
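
# Illustrative usage (a sketch; names are hypothetical). The `{fun}` and
# `{doc}` placeholders are filled in from the wrapped function:
#
#   def base(x):
#     """Adds one."""
#     return x + 1
#
#   @wraps(base, namestr="checked_{fun}", docstr="Checked version of {fun}. {doc}")
#   def derived(x):
#     return base(x)
#
#   derived.__name__  # 'checked_base'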


# NOTE: Ideally we would annotate both the argument and return type as NoReturn
# but it seems like pytype doesn't support that...
def assert_unreachable(x):
  raise AssertionError(f"Unhandled case: {type(x).__name__}")


def tuple_insert(t, idx, val):
  assert 0 <= idx <= len(t), (idx, len(t))
  return t[:idx] + (val,) + t[idx:]


def tuple_delete(t, idx):
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + t[idx + 1:]


def tuple_update(t, idx, val):
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + (val,) + t[idx+1:]


def tuple_replace(tupl, index, item):
  # unlike tuple_update, works with negative indices as well
  return tupl[:index] + (item,) + tupl[index:][1:]
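
# Illustrative behavior (hypothetical values):
#
#   tuple_replace((1, 2, 3), -1, 9)  # (1, 2, 9)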


class HashableFunction:
  """Decouples function equality and hash from its identity.

  Local lambdas and function defs are reallocated on each function call, making
  the functions created on different calls compare as unequal. This breaks our
  caching logic, which should really only care about comparing the semantics
  and not actual identity.

  This class makes it possible to compare different functions based on their
  semantics. The parts that are taken into account are: the bytecode of the
  wrapped function (which is cached by the CPython interpreter and is stable
  across the invocations of the surrounding function), and `closure`, which
  should contain all values in scope that affect the function semantics. In
  particular `closure` should contain all elements of the function closure, or
  it should be possible to derive the relevant elements of the true function
  closure based solely on the contents of the `closure` argument (e.g. in case
  some closed-over values are not hashable, but are entirely determined by
  hashable locals).
  """

  def __init__(self, f, closure):
    self.f = f
    self.closure = closure

  def __eq__(self, other):
    return (type(other) is HashableFunction and
            self.f.__code__ == other.f.__code__ and
            self.closure == other.closure)

  def __hash__(self):
    return hash((self.f.__code__, self.closure))

  def __call__(self, *args, **kwargs):
    return self.f(*args, **kwargs)

  def __repr__(self):
    return f'<hashable {self.f.__name__} with closure={self.closure}>'


def as_hashable_function(closure):
  return lambda f: HashableFunction(f, closure)
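
# An illustrative sketch: lambdas created on different calls share bytecode,
# so equality reduces to comparing the declared closure:
#
#   def make(k):
#     return HashableFunction(lambda x: x + k, closure=(k,))
#   make(1) == make(1)  # True
#   make(1) == make(2)  # False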


class HashablePartial:
  def __init__(self, f, *args, **kwargs):
    self.f = f
    self.args = args
    self.kwargs = kwargs

  def __eq__(self, other):
    return (type(other) is HashablePartial and
            self.f.__code__ == other.f.__code__ and
            self.args == other.args and self.kwargs == other.kwargs)

  def __hash__(self):
    return hash(
        (
            self.f.__code__,
            self.args,
            tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])),
        ),
    )

  def __call__(self, *args, **kwargs):
    return self.f(*self.args, *args, **self.kwargs, **kwargs)
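
# Illustrative behavior (a sketch; `add` is hypothetical):
#
#   def add(x, y): return x + y
#   p1, p2 = HashablePartial(add, 1), HashablePartial(add, 1)
#   p1 == p2 and hash(p1) == hash(p2)  # True
#   p1(41)                             # 42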


def maybe_named_axis(axis, if_pos, if_named):
  try:
    pos = operator.index(axis)
    named = False
  except TypeError:
    named = True
  return if_named(axis) if named else if_pos(pos)


def distributed_debug_log(*pairs):
  """Format and log `pairs` if config.jax_distributed_debug is enabled.

  Args:
    pairs: A sequence of label/value pairs to log. The first pair is treated as
      a heading for subsequent pairs.
  """
  if config.distributed_debug.value:
    lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
    try:
      lines.append(f"{pairs[0][0]}: {pairs[0][1]}")
      for label, value in pairs[1:]:
        lines.append(f"  {label}: {value}")
    except Exception as e:
      lines.append("DISTRIBUTED_DEBUG logging failed!")
      lines.append(f"{e}")
    lines.append("DISTRIBUTED_DEBUG_END")
    logger.warning("\n".join(lines))
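
# Illustrative usage (a sketch; the values are hypothetical):
#
#   distributed_debug_log(("Running pmapped function", name),
#                         ("python function", fun),
#                         ("abstract args", avals))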


def stable_unique(it: Iterable[T]) -> Iterable[T]:
  """Returns unique elements from `it` in the order of occurrence.

  The elements must be hashable.
  """
  return dict.fromkeys(it).keys()
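
# Illustrative behavior (hypothetical values), relying on dicts preserving
# insertion order:
#
#   list(stable_unique([3, 1, 3, 2, 1]))  # [3, 1, 2]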


class OrderedSet(Generic[T]):
  elts_set: set[T]
  elts_list: list[T]

  def __init__(self):
    self.elts_set = set()
    self.elts_list = []

  def add(self, elt: T) -> None:
    if elt not in self.elts_set:
      self.elts_set.add(elt)
      self.elts_list.append(elt)

  def update(self, elts: Seq[T]) -> None:
    for e in elts:
      self.add(e)

  def __iter__(self) -> Iterator[T]:
    return iter(self.elts_list)

  def __len__(self) -> int:
    return len(self.elts_list)

  def __contains__(self, elt: T) -> bool:
    return elt in self.elts_set


class HashableWrapper:
  x: Any
  hash: int | None

  def __init__(self, x):
    self.x = x
    try: self.hash = hash(x)
    except Exception: self.hash = None

  def __hash__(self):
    return self.hash if self.hash is not None else id(self.x)

  def __eq__(self, other):
    if not isinstance(other, HashableWrapper):
      return False
    return self.x == other.x if self.hash is not None else self.x is other.x


def _original_func(f):
  if isinstance(f, property):
    return cast(property, f).fget
  elif isinstance(f, functools.cached_property):
    return f.func
  return f


def set_module(module: str) -> Callable[[T], T]:
  def wrapper(func: T) -> T:
    if module is not None:
      func.__module__ = module
    return func
  return wrapper


def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]:
  """A decorator replacing a Python class with its C++ version at runtime."""

  def wrapper(cls):
    if cpp_cls is None:
      return cls

    exclude_methods = {'__module__', '__dict__', '__doc__'}

    originals = {}
    for attr_name, attr in cls.__dict__.items():
      if attr_name not in exclude_methods:
        if hasattr(_original_func(attr), "_use_cpp"):
          originals[attr_name] = attr
        else:
          setattr(cpp_cls, attr_name, attr)

    cpp_cls.__doc__ = cls.__doc__
    # TODO(pschuh): Remove once fastpath is gone.
    cpp_cls._original_py_fns = originals
    return cpp_cls

  return wrapper


def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]:
  """A decorator excluding methods from the set forwarded to the C++ class."""
  if not isinstance(is_enabled, bool):
    raise TypeError("``is_enabled`` must be a bool")
  def decorator(f):
    if is_enabled:
      original_func = _original_func(f)
      original_func._use_cpp = True
    return f
  return decorator
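
# An illustrative sketch of how the two decorators compose (`CppThing` is a
# hypothetical C++-backed class):
#
#   @use_cpp_class(CppThing)
#   class Thing:
#     def py_method(self): ...    # copied onto CppThing
#
#     @use_cpp_method()
#     def fast_method(self): ...  # not copied: CppThing's own version is used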


try:
  # numpy 1.25.0 or newer
  NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning
except AttributeError:
  # legacy numpy
  NumpyComplexWarning = np.ComplexWarning


class StrictABCMeta(abc.ABCMeta):
  """A variant of `abc.ABCMeta` which does not allow virtual subclasses.

  Virtual subclass support requires `abc.ABCMeta` to round-trip through
  pure Python when doing instance/subclass checks. This is fine for ABCs
  which need virtual subclasses, but is wasteful for the ones which don't.
  """
  def register(cls, subclass):
    del subclass  # Unused.
    raise NotImplementedError(f"{cls} does not support virtual subclasses")

  __instancecheck__ = type.__instancecheck__  # type: ignore[assignment]
  __subclasscheck__ = type.__subclasscheck__  # type: ignore[assignment]


class StrictABC(metaclass=StrictABCMeta):
  __slots__ = ()