2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-12-14 14:52:51 -08:00
|
|
|
import operator
|
2021-05-01 12:28:12 -07:00
|
|
|
from functools import partial
|
|
|
|
from typing import Any, Dict, Iterable, Tuple, Union, Optional
|
2020-10-20 22:22:33 +02:00
|
|
|
|
2021-02-10 16:22:29 -08:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from . import core
|
2021-04-07 19:35:17 -07:00
|
|
|
from ._src import dtypes
|
2021-03-24 12:00:12 -07:00
|
|
|
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap,
|
2021-01-12 19:37:19 -08:00
|
|
|
tree_structure, treedef_children, treedef_is_leaf)
|
2021-03-24 12:00:12 -07:00
|
|
|
from ._src.tree_util import _replace_nones
|
2020-01-05 04:35:34 +01:00
|
|
|
from . import linear_util as lu
|
2021-03-29 13:52:39 -07:00
|
|
|
from ._src.util import safe_map, WrapHashably, WrapKwArgs, Hashable
|
2020-01-15 15:00:38 -08:00
|
|
|
from .core import unit
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-11-04 09:01:18 -08:00
|
|
|
from ._src import traceback_util
|
2020-10-26 10:03:06 -07:00
|
|
|
# Register this module's file with traceback_util. Judging by the name, this
# presumably hides frames from this file in user-facing tracebacks -- see
# _src/traceback_util to confirm.
traceback_util.register_exclusion(__file__)
|
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Shadow the builtin `map` for the rest of this module with `safe_map` from
# _src.util (presumably a map that checks its sequence arguments have equal
# lengths -- NOTE(review): confirm in _src/util.py).
map = safe_map
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-12-14 14:52:51 -08:00
|
|
|
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
|
|
|
|
"""Ensure x is either an index or a tuple of indices."""
|
|
|
|
try:
|
|
|
|
return operator.index(x)
|
|
|
|
except TypeError:
|
|
|
|
return tuple(map(operator.index, x))
|
|
|
|
|
|
|
|
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
|
|
|
|
"""Convert x to a tuple of indices."""
|
|
|
|
try:
|
|
|
|
return (operator.index(x),)
|
|
|
|
except TypeError:
|
|
|
|
return tuple(map(operator.index, x))
|
2019-01-03 16:14:30 -08:00
|
|
|
|
2021-03-29 13:52:39 -07:00
|
|
|
def _ensure_str(x: str) -> str:
|
|
|
|
if not isinstance(x, str):
|
|
|
|
raise TypeError(f"argument is not a string: {x}")
|
|
|
|
return x
|
|
|
|
|
|
|
|
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
|
|
|
|
"""Convert x to a tuple of strings."""
|
|
|
|
if isinstance(x, str):
|
|
|
|
return (x,)
|
|
|
|
else:
|
|
|
|
return tuple(map(_ensure_str, x))
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Flattening transformation for functions taking ``(*args, **kwargs)``.

  ``args_flat`` is unflattened with ``in_tree`` into a ``(py_args, py_kwargs)``
  pair, which is yielded to linear_util (presumably so that the wrapped
  function is called as ``fun(*py_args, **py_kwargs)``). The value sent back
  is flattened, and ``(leaves, treedef)`` is yielded. Under
  ``transformation_with_aux`` the treedef is presumably surfaced as the
  auxiliary output.
  """
  pos_args, kw_args = tree_unflatten(in_tree, args_flat)
  out = yield pos_args, kw_args
  out_leaves, out_tree = tree_flatten(out)
  yield out_leaves, out_tree
|
2019-07-25 12:41:11 -07:00
|
|
|
|
2019-07-26 23:17:21 -04:00
|
|
|
def apply_flat_fun(fun, io_tree, *py_args):
  """Call the flat function ``fun`` on pytree positional arguments.

  Args:
    fun: callable taking flat leaves and returning flat outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of treedefs.
    *py_args: positional pytree arguments. They are flattened together with an
      empty kwargs dict, i.e. in the same ``(args, kwargs)`` layout that
      ``flatten_fun`` unflattens.

  Returns:
    The outputs of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the arguments' treedef differs from ``in_tree_expected``.
  """
  expected_in_tree, out_tree = io_tree
  flat_args, actual_in_tree = tree_flatten((py_args, {}))
  if actual_in_tree != expected_in_tree:
    raise TypeError(
        "Expected {}, got {}".format(expected_in_tree, actual_in_tree))
  flat_out = fun(*flat_args)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Flattening transformation for functions called with positional args only.

  ``in_tree`` describes the positional arguments alone. The unflattened
  arguments are yielded together with an empty kwargs dict. The value sent
  back is flattened, and ``(leaves, treedef)`` is yielded as the final
  (output, aux) pair.
  """
  pos_args = tree_unflatten(in_tree, args_flat)
  out = yield pos_args, {}
  out_leaves, out_tree = tree_flatten(out)
  yield out_leaves, out_tree
|
|
|
|
|
|
|
|
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call the flat function ``fun`` on a pytree of positional arguments.

  Args:
    fun: callable taking flat leaves and returning flat outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of treedefs.
    py_args: the positional arguments, flattened as a single pytree (no
      kwargs dict is added).

  Returns:
    The outputs of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the arguments' treedef differs from ``in_tree_expected``.
  """
  expected_in_tree, out_tree = io_tree
  flat_args, actual_in_tree = tree_flatten(py_args)
  if actual_in_tree != expected_in_tree:
    raise TypeError(
        "Expected {}, got {}".format(expected_in_tree, actual_in_tree))
  flat_out = fun(*flat_args)
  return tree_unflatten(out_tree, flat_out)
|
|
|
|
|
2021-05-01 12:28:12 -07:00
|
|
|
# Alias used for annotations only; treedefs are treated as opaque here.
PyTreeDef = Any


def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  """Recover the input treedef recorded by ``flatten_fun``/``flatten_fun_nokwargs``.

  This relies on internal details of linear_util's ``WrappedFun``, namely its
  ``transforms`` sequence of ``(transform, params)`` pairs. That is acceptable
  because it only serves to produce better user error messages.

  Returns:
    ``(in_tree, has_kwargs)`` if exactly one of the two flattening
    transformations was applied to ``fn``. Here ``has_kwargs`` is True for
    ``flatten_fun`` and False for ``flatten_fun_nokwargs``. Returns None
    otherwise. The None case can occur e.g. when ``core.eval_jaxpr``
    encounters a call primitive, but at that point jaxprs are only being
    round-tripped and the user errors in question are impossible.
  """
  # Both decorated transformations must be ``partial`` objects with exactly
  # one bound argument (presumably the undecorated generator), which is the
  # identity recorded in ``fn.transforms``.
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  kwargs_xform = flatten_fun.args[0]
  flat_xforms = {kwargs_xform, flatten_fun_nokwargs.args[0]}
  try:
    # Lazily scan the transforms and unpack exactly one match. Zero matches,
    # or more than one, raise ValueError.
    matches = ((params[0], xform is kwargs_xform)
               for xform, params in fn.transforms if xform in flat_xforms)
    (in_tree, has_kwargs), = matches
  except ValueError:
    return None
  return in_tree, has_kwargs
|
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Variant of ``flatten_fun_nokwargs`` for functions returning ``(ans, aux)``.

  The wrapped function must return a two-element list or tuple. Each element
  is flattened separately. The pair of flat leaf lists is yielded as the
  output and the pair of treedefs as the auxiliary value.

  Raises:
    TypeError: if the function's result is not a two-element list or tuple.
  """
  pos_args = tree_unflatten(in_tree, args_flat)
  result = yield pos_args, {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  ans, aux = result
  ans_flat, ans_tree = tree_flatten(ans)
  aux_flat, aux_tree = tree_flatten(aux)
  yield (ans_flat, aux_flat), (ans_tree, aux_tree)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
2020-10-20 22:22:33 +02:00
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def argnums_partial(f, dyn_argnums, args):
  """Partially apply ``f`` to every positional argument not in ``dyn_argnums``.

  Args:
    f: the function to partially apply (presumably a ``lu.WrappedFun``).
    dyn_argnums: an index, or an iterable of indices, of the arguments that
      stay dynamic.
    args: the full sequence of positional arguments.

  Returns:
    A pair ``(wrapped, dyn_args)``. ``wrapped`` takes only the dynamic
    arguments, in ``dyn_argnums`` order, and ``dyn_args`` holds their values.
    Every other argument is fixed into ``wrapped`` via ``wrap_hashably``.
  """
  dyn_argnums = _ensure_index_tuple(dyn_argnums)
  fixed_args = []
  for pos, arg in enumerate(args):
    # ``unit`` marks a slot that will be filled by a dynamic argument.
    fixed_args.append(unit if pos in dyn_argnums else wrap_hashably(arg))
  dyn_args = tuple(args[pos] for pos in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
2020-09-18 17:39:05 -07:00
|
|
|
|
2020-10-20 22:22:33 +02:00
|
|
|
|
|
|
|
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any, ...], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Every position listed in ``static_argnums`` is fixed into ``f`` (wrapped
  with ``Hashable``). All other positions remain dynamic.

  Args:
    f: the function to partially apply.
    static_argnums: indices into ``args`` to treat as static.
    args: the full tuple of positional arguments (any length).
    allow_invalid: if True, static indices ``>= len(args)`` are skipped.
      Otherwise, indexing ``args`` with such an index raises ``IndexError``.

  Returns:
    ``(f, args)`` unchanged if ``static_argnums`` is empty. Otherwise, a pair
    of the partially applied function and the tuple of dynamic arguments.

  Raises:
    ValueError: if any static argument is not hashable.
  """
  if not static_argnums:
    return f, args
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  # ``unit`` marks slots that are filled by dynamic arguments.
  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      # Non-hashable static arguments used to be wrapped and compared by
      # identity. That silently forced a recompilation for every new instance
      # (e.g. a fresh list or array with the same contents), so they are now
      # rejected outright.
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    else:
      fixed_args[i] = Hashable(static_arg)  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
|
|
|
|
|
|
|
|
2020-09-18 17:39:05 -07:00
|
|
|
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Rebuild the full positional argument list from fixed and dynamic parts.

  ``fixed_args`` holds one slot per argument position. A slot that is
  ``unit`` is a placeholder: it starts as ``None`` and is overwritten by the
  dynamic argument whose index in ``dyn_argnums`` names that position. Any
  other slot is a wrapper whose ``.val`` is the fixed argument value.
  """
  full_args = [None if slot is unit else slot.val for slot in fixed_args]
  for position, value in zip(dyn_argnums, dyn_args):
    full_args[position] = value
  result = yield full_args, kwargs
  yield result
|
|
|
|
|
2021-03-29 13:52:39 -07:00
|
|
|
|
|
|
|
def argnames_partial(f, dyn_argnames, kwargs):
  """Partially apply ``f`` to every keyword argument not in ``dyn_argnames``.

  Args:
    f: the function to partially apply.
    dyn_argnames: a name, or an iterable of names, of keyword arguments that
      stay dynamic. Each name must be a key of ``kwargs`` (else ``KeyError``).
    kwargs: the full mapping of keyword arguments.

  Returns:
    A pair ``(wrapped, dyn_kwargs)``. ``wrapped`` takes only the dynamic
    keyword arguments, and ``dyn_kwargs`` maps their names to their values.
    Every other keyword argument is fixed via ``wrap_hashably``.
  """
  dyn_argnames = _ensure_str_tuple(dyn_argnames)
  # ``_argnames_partial`` reads ``fixed_kwargs.val.items()``, so the wrapped
  # value must be a mapping, as in ``argnames_partial_except``. A tuple of
  # (key, value) pairs has no ``.items()`` and would fail there.
  fixed_kwargs = {k: unit if k in dyn_argnames else wrap_hashably(v)
                  for k, v in kwargs.items()}
  dyn_kwargs = {k: kwargs[k] for k in dyn_argnames}
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  """Fix the keyword arguments named in ``static_argnames`` into ``f``.

  Names in ``static_argnames`` that do not appear in ``kwargs`` are ignored.
  Every static value must be hashable; it is wrapped with ``Hashable``.

  Returns:
    ``(f, kwargs)`` unchanged if ``static_argnames`` is empty. Otherwise, a
    pair of the partially applied function and the dict of dynamic keyword
    arguments.

  Raises:
    ValueError: if any static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

  fixed_kwargs: Dict[str, Any] = {}
  for name, value in kwargs.items():
    if name in dyn_kwargs:
      # Placeholder: the dynamic value is supplied at call time.
      fixed_kwargs[name] = unit
      continue
    try:
      hash(value)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (name {name}) of type "
          f"{type(value)} for function {f.__name__} is non-hashable.")
    fixed_kwargs[name] = Hashable(value)  # type: ignore

  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Rebuild the full keyword-argument dict from fixed and dynamic parts.

  ``fixed_kwargs.val`` maps each name either to ``unit`` (a placeholder,
  which becomes ``None`` unless ``dyn_kwargs`` supplies the value) or to a
  wrapper whose ``.val`` is the fixed value. Dynamic values take precedence.
  """
  kwargs = {}
  for name, slot in fixed_kwargs.val.items():
    kwargs[name] = None if slot is unit else slot.val
  kwargs.update(dyn_kwargs)
  result = yield args, kwargs
  yield result
|
|
|
|
|
|
|
|
|
Add support for buffer donation in `jit` and `pmap`. (#2936)
For a computation of the form:
>>> f = lambda x: x ** 2
>>> f = jax.jit(f)
>>> while run:
... x = f(x)
JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:
1. Users at the limit of available device memory are constrained by the additional
copy of their parameters and other state while they typically only require
one copy. This typically frees 100M+ of device memory and is a critical
optimization for larger models to match state of the art performance in
other frameworks.
2. This constant alloc/free of the input/output buffers can cause memory
fragmentation on some platforms (although having a reusing allocator and
limiting run-ahead may be a better solution for this problem).
We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:
>>> f = lambda x: x ** 2
>>> f = jit(f, donate_argnums=0)
>>> while run:
... x = f(x)
JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.
If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:
>>> y = f(x)
>>> jax.device_get(x)
...
RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.
The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.
One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:
>>> @partial(jit, donate_argnums=0)
... def move(x):
... # Do something complex enough for JAX to just optimize it away.
... return tree_map(lambda x: x + x - x, x)
>>> def safe_eager_uniform(key, *a, **k):
... assert hasattr(key, 'device_buffer'), "random must run eagerly"
... key = move(key)
... return jax.random.uniform(key, *a, **k)
This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 23:00:16 +01:00
|
|
|
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Return one donation flag per pytree leaf of ``args``, then of ``kwargs``.

  Every leaf of positional argument ``i`` is flagged True exactly when ``i``
  is in ``donate_argnums``. Leaves of ``kwargs`` are never donated.
  """
  flags = []
  for idx, arg in enumerate(args):
    donated = bool(idx in donate_argnums)
    flags.extend([donated] * tree_structure(arg).num_leaves)
  flags.extend([False] * tree_structure(kwargs).num_leaves)
  return tuple(flags)
|
|
|
|
|
|
|
|
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Shifts donate to account for static.

  Each donated index is reduced by the number of static indices below it, so
  that it points at the same argument once the static arguments have been
  removed from the argument list.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if ``donate_argnums`` and ``static_argnums`` share an index.
  """
  if not static_argnums and not donate_argnums:
    return tuple(sorted(donate_argnums))

  statics = sorted(set(static_argnums))
  donates = sorted(set(donate_argnums))
  static_lookup = set(statics)
  if any(d in static_lookup for d in donates):
    raise ValueError(f"`static_argnums` {statics} and "
                     f"`donate_argnums` {donates} cannot intersect.")
  # Shift each donated index left by the count of static indices preceding it.
  return tuple(d - sum(1 for s in statics if s < d) for d in donates)
|
|
|
|
|
2020-01-15 15:00:38 -08:00
|
|
|
def wrap_hashably(arg):
  """Wrap ``arg`` so it can take part in a cache key.

  Values that support ``hash`` are wrapped in ``Hashable``. Values for which
  ``hash`` raises ``TypeError`` (e.g. ndarrays, DeviceArrays) fall back to
  ``WrapHashably``, which hashes and compares by object identity.
  """
  try:
    hash(arg)
  except TypeError:
    is_hashable = False
  else:
    is_hashable = True
  return Hashable(arg) if is_hashable else WrapHashably(arg)
|
|
|
|
|
2021-01-12 19:37:19 -08:00
|
|
|
def flatten_axes(name, treedef, axis_tree, *, kws=False):
  """Broadcast an axis-spec prefix tree to the leaves of ``treedef``.

  ``axis_tree`` is a pytree with ints or ``None`` at its leaves (``None`` is
  treated as a leaf here) that must be a tree prefix of ``treedef``. Each leaf
  of ``axis_tree`` is repeated once for every leaf of ``treedef`` beneath it.
  The resulting flat list, with one entry per leaf of ``treedef``, is
  returned.

  Args:
    name: label inserted at the start of the error message.
    treedef: treedef of the value the axis spec applies to.
    axis_tree: the axis specification.
    kws: set when ``treedef``/``axis_tree`` are two-part (positional, keyword)
      trees. The error message then describes only the positional part.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # TODO(mattjj,phawkins): improve this implementation
  # ``proxy`` stands in for None so that Nones are matched as leaves; it is
  # mapped back to None at the end.
  proxy = object()
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []

  def add_leaves(axis, subtree):
    axes.extend([axis] * len(tree_flatten(subtree)[0]))

  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # Keyword arguments are part of the tree; restrict the error message to
      # the positional arguments.
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
|
2021-02-10 16:22:29 -08:00
|
|
|
|
|
|
|
def _dtype(x):
  """Return the result dtype of ``x``, falling back to ``x.dtype``.

  ``dtypes.result_type`` is first applied to ``x`` itself. If that raises
  ``ValueError``, it is applied to ``x.dtype`` instead, which raises
  ``AttributeError`` if ``x`` has no ``dtype`` attribute.
  """
  try:
    return dtypes.result_type(x)
  except ValueError:
    return dtypes.result_type(x.dtype)
|
|
|
|
|
|
|
|
def shaped_abstractify(x):
  """Return a shaped abstract value describing ``x``.

  If ``core.get_aval`` accepts ``x``, its aval is raised to shaped level.
  If it raises ``TypeError``, a ``core.ShapedArray`` is built instead from
  ``np.shape(x)`` and ``_dtype(x)``, plus the optional ``weak_type`` and
  ``named_shape`` attributes (defaulting to ``False`` and ``{}``).
  """
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass

  is_weak = getattr(x, 'weak_type', False)
  axis_names = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x),
                          weak_type=is_weak, named_shape=axis_names)
|
2021-04-01 00:05:00 -07:00
|
|
|
|
|
|
|
# Decorator hook that makes it easier to monkey-patch JAX APIs. The default
# implementation is a no-op; it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  """Return ``fun`` unchanged.

  ``tag`` is accepted for the benefit of monkey-patched replacements and is
  ignored by this default implementation.
  """
  del tag  # unused by the default implementation
  return fun
|