# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from functools import partial
import itertools as it
import operator
import types
from typing import (Any, Callable, List, Tuple, Generic, TypeVar, Set, Iterator,
                    Sequence)

from absl import logging
import numpy as np

from jax.config import config

Seq = Sequence

def safe_zip(*args):
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(zip(*args))

def safe_map(f, *args):
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(map(f, *args))

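# Illustrative usage sketch (the helper name `_safe_zip_map_example` is
# hypothetical, added for exposition): unlike the builtins, safe_zip and
# safe_map assert that all input sequences have equal length instead of
# silently truncating to the shortest one.
def _safe_zip_map_example():
  assert safe_zip([1, 2], ['a', 'b']) == [(1, 'a'), (2, 'b')]
  assert safe_map(operator.add, [1, 2], [10, 20]) == [11, 22]
  # safe_zip([1, 2], ['a']) would raise an AssertionError here.
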
def unzip2(xys):
  xs = []
  ys = []
  for x, y in xys:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

def unzip3(xyzs):
  xs = []
  ys = []
  zs = []
  for x, y, z in xyzs:
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(xs), tuple(ys), tuple(zs)

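# Illustrative usage sketch (hypothetical helper, added for exposition):
# unzip2/unzip3 invert zip, splitting a sequence of pairs/triples into
# parallel tuples.
def _unzip_example():
  xs, ys = unzip2([(1, 'a'), (2, 'b')])
  assert xs == (1, 2) and ys == ('a', 'b')
  xs, ys, zs = unzip3([(1, 'a', True), (2, 'b', False)])
  assert zs == (True, False)
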
def subvals(lst, replace):
  lst = list(lst)
  for i, v in replace:
    lst[i] = v
  return tuple(lst)

def split_list(args, ns):
  assert type(ns) is list
  args = list(args)
  lists = []
  for n in ns:
    lists.append(args[:n])
    args = args[n:]
  lists.append(args)
  return lists

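# Illustrative usage sketch (hypothetical helper, added for exposition):
# subvals substitutes values at the given indices; split_list consumes prefix
# lengths, leaving the remainder in a trailing list.
def _subvals_split_list_example():
  assert subvals(('a', 'b', 'c'), [(1, 'B')]) == ('a', 'B', 'c')
  assert split_list([1, 2, 3, 4, 5], [2, 2]) == [[1, 2], [3, 4], [5]]
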
def partition_list(bs: Seq[bool], l: Seq[Any]) -> Tuple[List[Any], List[Any]]:
  assert len(bs) == len(l)
  lists = [], []  # type: ignore
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lists

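# Illustrative usage sketch (hypothetical helper, added for exposition): the
# boolean is used directly as an index, so elements with a False flag land in
# the first list and those with a True flag in the second.
def _partition_list_example():
  falses, trues = partition_list([True, False, True], ['a', 'b', 'c'])
  assert falses == ['b'] and trues == ['a', 'c']
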
def split_dict(dct, names):
  dct = dict(dct)
  lst = [dct.pop(name) for name in names]
  assert not dct
  return lst

def concatenate(xs):
  return list(it.chain.from_iterable(xs))

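# Illustrative usage sketch (hypothetical helper, added for exposition):
# split_dict requires `names` to cover every key and returns the values in
# the order the names are given.
def _split_dict_concatenate_example():
  assert split_dict({'a': 1, 'b': 2}, ['b', 'a']) == [2, 1]
  assert concatenate([[1, 2], [], [3]]) == [1, 2, 3]
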
def curry(f):
  """Curries arguments of f, returning a function on any remaining arguments.

  For example:

  >>> f = lambda x, y, z, w: x * y + z * w
  >>> f(2,3,4,5)
  26
  >>> curry(f)(2)(3, 4, 5)
  26
  >>> curry(f)(2, 3)(4, 5)
  26
  >>> curry(f)(2, 3, 4, 5)()
  26
  """
  return partial(partial, f)

def toposort(end_nodes):
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)

  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  for node in end_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  check_toposort(sorted_nodes[::-1])
  return sorted_nodes[::-1]

def check_toposort(nodes):
  visited = set()
  for node in nodes:
    assert all(id(parent) in visited for parent in node.parents)
    visited.add(id(node))

def _remove_duplicates(node_list):
  seen = set()
  out = []
  for n in node_list:
    if id(n) not in seen:
      seen.add(id(n))
      out.append(n)
  return out

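# Illustrative usage sketch (the Node class and helper are hypothetical, added
# for exposition): toposort works on any objects exposing a `parents`
# sequence, and returns ancestors before descendants.
def _toposort_example():
  class Node:
    def __init__(self, parents):
      self.parents = parents
  a = Node([])
  b = Node([a])
  c = Node([a, b])
  assert toposort([c]) == [a, b, c]
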
def split_merge(predicate, xs):
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]
  def merge(new_lhs, new_rhs):
    out = []
    for s in sides:
      if s:
        out.append(new_lhs[0])
        new_lhs = new_lhs[1:]
      else:
        out.append(new_rhs[0])
        new_rhs = new_rhs[1:]
    assert not new_rhs
    assert not new_lhs
    return out

  return lhs, rhs, merge

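# Illustrative usage sketch (hypothetical helper, added for exposition): the
# returned `merge` closure reinterleaves replacement values back into the
# original positions recorded by the predicate.
def _split_merge_example():
  evens, odds, merge = split_merge(lambda x: x % 2 == 0, [1, 2, 3, 4])
  assert evens == [2, 4] and odds == [1, 3]
  assert merge(['two', 'four'], ['one', 'three']) == ['one', 'two', 'three', 'four']
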
def cache(max_size=4096):
  def wrap(f):
    @functools.lru_cache(max_size)
    def cached(_, *args, **kwargs):
      return f(*args, **kwargs)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      if config.jax_check_tracer_leaks:
        return f(*args, **kwargs)
      else:
        return cached(config._trace_context(), *args, **kwargs)

    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper
  return wrap

def memoize(f):
  @functools.lru_cache(None)
  def memoized(_, *args, **kwargs):
    return f(*args, **kwargs)

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return memoized(config._trace_context(), *args, **kwargs)

  wrapper.cache_clear = memoized.cache_clear
  wrapper.cache_info = memoized.cache_info
  return wrapper

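# Illustrative usage sketch (hypothetical helper, added for exposition): both
# decorators key the underlying lru_cache on the current trace context, so
# entries from different trace contexts never collide; `cache` also bypasses
# caching entirely when tracer leak checking is enabled.
def _cache_example():
  @cache(max_size=16)
  def square(x):
    return x * x
  assert square(3) == 9
  assert square(3) == 9  # second call is a cache hit
  square.cache_clear()   # functools-style cache controls are re-exported
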
def prod(xs):
  out = 1
  for x in xs:
    out *= x
  return out

class Unhashable:
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __eq__(self, other):
    return self.val == other.val

class Hashable:
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    return self.val == other.val

class WrapKwArgs:
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(tuple((k, v) for k, v in sorted(self.val.items())))

  def __eq__(self, other):
    return self.val == other.val

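# Illustrative usage sketch (hypothetical helper, added for exposition):
# WrapKwArgs makes a dict of keyword arguments usable as a cache key by
# hashing its sorted items, so insertion order does not matter.
def _wrap_kwargs_example():
  a = WrapKwArgs({'x': 1, 'y': 2})
  b = WrapKwArgs({'y': 2, 'x': 1})
  assert a == b and hash(a) == hash(b)
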
def get_module_functions(module):
  """Finds functions in module.

  Args:
    module: A Python module.

  Returns:
    module_fns: A dict mapping names to the functions, builtins, or ufuncs
      in `module`.
  """
  module_fns = {}
  for key in dir(module):
    # Omit the module-level __getattr__ and __dir__ hooks, which were added
    # in Python 3.7: https://www.python.org/dev/peps/pep-0562/
    if key in ('__getattr__', '__dir__'):
      continue
    attr = getattr(module, key)
    if isinstance(
        attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
      module_fns[key] = attr
  return module_fns

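# Illustrative usage sketch (hypothetical helper, added for exposition): only
# plain functions, builtins, and NumPy ufuncs are collected; other module
# attributes are ignored.
def _get_module_functions_example():
  fns = get_module_functions(operator)
  assert 'add' in fns and '__name__' not in fns
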
def wrap_name(name, transform_name):
  return transform_name + '(' + name + ')'

def extend_name_stack(stack, name=''):
  return stack + name + '/'

def canonicalize_axis(axis, num_dims) -> int:
  """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
  axis = operator.index(axis)
  if not -num_dims <= axis < num_dims:
    raise ValueError(
        "axis {} is out of bounds for array of dimension {}".format(
            axis, num_dims))
  if axis < 0:
    axis = axis + num_dims
  return axis

def moveaxis(x, src, dst):
  if src == dst:
    return x
  if isinstance(src, int):
    src = (src,)
  if isinstance(dst, int):
    dst = (dst,)
  src = [canonicalize_axis(a, x.ndim) for a in src]
  dst = [canonicalize_axis(a, x.ndim) for a in dst]
  perm = [i for i in range(np.ndim(x)) if i not in src]
  for d, s in sorted(zip(dst, src)):
    perm.insert(d, s)
  return x.transpose(perm)

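# Illustrative usage sketch (hypothetical helper, added for exposition):
# negative axes are wrapped into range before the permutation is built.
def _axis_utils_example():
  assert canonicalize_axis(-1, 3) == 2
  x = np.zeros((2, 3, 4))
  assert moveaxis(x, 0, -1).shape == (3, 4, 2)
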
def ceil_of_ratio(x, y):
  # Ceiling division, expressed via floor division on negated operands.
  return -(-x // y)

@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Like functools.wraps, but with finer-grained control over the name and
  docstring of the resulting function.
  """
  try:
    name = getattr(wrapped, "__name__", "<unnamed function>")
    doc = getattr(wrapped, "__doc__", "") or ""
    fun.__dict__.update(getattr(wrapped, "__dict__", {}))
    fun.__annotations__ = getattr(wrapped, "__annotations__", {})
    fun.__name__ = namestr.format(fun=name)
    fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
    fun.__doc__ = docstr.format(fun=name, doc=doc, **kwargs)
    fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
    fun.__wrapped__ = wrapped
  finally:
    # Attribute copying is best-effort: the `return` in this `finally` block
    # swallows any exception raised above, so `fun` is always returned.
    return fun

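# Illustrative usage sketch (hypothetical helpers, added for exposition): the
# format strings let a transform rename and re-document the wrapper.
def _wraps_example():
  def base(x):
    """Adds one."""
    return x + 1
  @wraps(base, docstr="Wrapped: {doc}")
  def wrapper(x):
    return base(x)
  assert wrapper.__name__ == 'base'
  assert wrapper.__doc__ == 'Wrapped: Adds one.'
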
# NOTE: Ideally we would annotate both the argument and return type as NoReturn
# but it seems like pytype doesn't support that...
def assert_unreachable(x):
  raise AssertionError(f"Unhandled case: {type(x).__name__}")

def tuple_insert(t, idx, val):
  assert 0 <= idx <= len(t), (idx, len(t))
  return t[:idx] + (val,) + t[idx:]

def tuple_delete(t, idx):
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + t[idx + 1:]

# TODO(mattjj): replace with dataclass when Python 2 support is removed
def taggedtuple(name, fields) -> Callable[..., Any]:
  """Lightweight version of namedtuple where equality depends on the type."""
  def __new__(cls, *xs):
    return tuple.__new__(cls, (cls,) + xs)
  def __repr__(self):
    return '{}{}'.format(name, tuple.__str__(self[1:]))
  class_namespace = {'__new__' : __new__, '__repr__': __repr__}
  for i, f in enumerate(fields):
    class_namespace[f] = property(operator.itemgetter(i+1))  # type: ignore
  return type(name, (tuple,), class_namespace)

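# Illustrative usage sketch (hypothetical helper, added for exposition): the
# class object itself is stored as the first tuple element, so instances of
# two taggedtuples with identical fields still compare unequal.
def _taggedtuple_example():
  Point = taggedtuple('Point', ('x', 'y'))
  Pixel = taggedtuple('Pixel', ('x', 'y'))
  p, q = Point(1, 2), Pixel(1, 2)
  assert p.x == 1 and p != q
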
class HashableFunction:
  """Decouples function equality and hash from its identity.

  Local lambdas and function defs are reallocated on each function call, making
  the functions created on different calls compare as unequal. This breaks our
  caching logic, which should really only care about comparing the semantics
  and not actual identity.

  This class makes it possible to compare different functions based on their
  semantics. The parts that are taken into account are: the bytecode of the
  wrapped function (which is cached by the CPython interpreter and is stable
  across the invocations of the surrounding function), and `closure` which
  should contain all values in scope that affect the function semantics. In
  particular `closure` should contain all elements of the function closure, or
  it should be possible to derive the relevant elements of the true function
  closure based solely on the contents of the `closure` argument (e.g. in case
  some closed-over values are not hashable, but are entirely determined by
  hashable locals).
  """

  def __init__(self, f, closure):
    self.f = f
    self.closure = closure

  def __eq__(self, other):
    return (type(other) is HashableFunction and
            self.f.__code__ == other.f.__code__ and
            self.closure == other.closure)

  def __hash__(self):
    return hash((self.f.__code__, self.closure))

  def __call__(self, *args, **kwargs):
    return self.f(*args, **kwargs)

  def __repr__(self):
    return f'<hashable {self.f.__name__} with closure={self.closure}>'

def as_hashable_function(closure):
  return lambda f: HashableFunction(f, closure)

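# Illustrative usage sketch (hypothetical helper, added for exposition): two
# closures built on separate calls share bytecode, so they compare (and hash)
# equal whenever their declared closures match.
def _hashable_function_example():
  def make(k):
    return HashableFunction(lambda x: x + k, closure=(k,))
  f, g = make(1), make(1)
  assert f is not g and f == g and hash(f) == hash(g)
  assert f(2) == 3
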
def maybe_named_axis(axis, if_pos, if_named):
  try:
    pos = operator.index(axis)
    named = False
  except TypeError:
    named = True
  return if_named(axis) if named else if_pos(pos)

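# Illustrative usage sketch (hypothetical helpers, added for exposition):
# integer-like axes take the positional branch; anything rejected by
# operator.index (e.g. an axis name string) takes the named branch.
def _maybe_named_axis_example():
  on_pos = lambda p: ('pos', p)
  on_named = lambda n: ('named', n)
  assert maybe_named_axis(0, on_pos, on_named) == ('pos', 0)
  assert maybe_named_axis('batch', on_pos, on_named) == ('named', 'batch')
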
def distributed_debug_log(*pairs):
  """Format and log `pairs` if config.jax_distributed_debug is enabled.

  Args:
    pairs: A sequence of label/value pairs to log. The first pair is treated as
      a heading for subsequent pairs.
  """
  if config.jax_distributed_debug:
    lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
    try:
      lines.append(f"{pairs[0][0]}: {pairs[0][1]}")
      for label, value in pairs[1:]:
        lines.append(f"  {label}: {value}")
    except Exception as e:
      lines.append("DISTRIBUTED_DEBUG logging failed!")
      lines.append(f"{e}")
    lines.append("DISTRIBUTED_DEBUG_END")
    logging.warning("\n".join(lines))

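# Illustrative usage sketch (hypothetical labels, added for exposition): when
# the jax_distributed_debug option is enabled, this emits one WARNING-level
# block bracketed by DISTRIBUTED_DEBUG_BEGIN/END markers for easy grepping.
def _distributed_debug_log_example():
  distributed_debug_log(("Running pmapped function", "my_fn"),
                        ("devices", None),
                        ("abstract args", []))
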
T = TypeVar('T')

class OrderedSet(Generic[T]):
  elts_set: Set[T]
  elts_list: List[T]

  def __init__(self):
    self.elts_set = set()
    self.elts_list = []

  def add(self, elt: T) -> None:
    if elt not in self.elts_set:
      self.elts_set.add(elt)
      self.elts_list.append(elt)

  def update(self, elts: Seq[T]) -> None:
    for e in elts:
      self.add(e)

  def __iter__(self) -> Iterator[T]:
    return iter(self.elts_list)

  def __len__(self) -> int:
    return len(self.elts_list)

  def __contains__(self, elt: T) -> bool:
    return elt in self.elts_set

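# Illustrative usage sketch (hypothetical helper, added for exposition):
# OrderedSet deduplicates while preserving first-insertion order, backed by a
# set for O(1) membership checks and a list for iteration order.
def _ordered_set_example():
  s: OrderedSet[int] = OrderedSet()
  s.update([3, 1, 3, 2, 1])
  assert list(s) == [3, 1, 2] and len(s) == 3 and 2 in s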