rocm_jax/jax/api_util.py
Tom Hennigan 6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00

166 lines
5.7 KiB
Python

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tree_util import (build_tree, tree_flatten, tree_unflatten,
treedef_is_leaf, tree_multimap, _replace_nones,
tree_structure)
from . import linear_util as lu
from .util import safe_map, unzip2, partial, curry, WrapHashably, Hashable
from .core import unit
from typing import Tuple
map = safe_map
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
try:
fun.__name__ = namestr.format(fun=get_name(wrapped))
fun.__module__ = get_module(wrapped)
fun.__doc__ = docstr.format(fun=get_name(wrapped), doc=get_doc(wrapped), **kwargs)
fun.__wrapped__ = wrapped
finally:
return fun
def get_name(fun): return getattr(fun, "__name__", "<unnamed function>")
def get_module(fun): return getattr(fun, "__module__", "<unknown module>")
def get_doc(fun): return getattr(fun, "__doc__", "")
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans)
def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
yield tree_flatten(ans)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
if in_tree != in_tree_expected:
raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
ans = fun(*args)
return tree_unflatten(out_tree, ans)
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans, aux = yield py_args, {}
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
def argnums_partial(f, dyn_argnums, args):
if isinstance(dyn_argnums, int):
dyn_argnums = (dyn_argnums,)
else:
dyn_argnums = tuple(dyn_argnums)
fixed_args = tuple([unit if i in dyn_argnums else wrap_hashably(arg)
for i, arg in enumerate(args)])
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
"""Returns a tuple with a boolean value for each leaf in args."""
res = []
for i, arg in enumerate(args):
donate = bool(i in donate_argnums)
res.extend((donate,) * tree_structure(arg).num_leaves)
res.extend((False,) * tree_structure(kwargs).num_leaves)
return tuple(res)
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
"""Shifts donate to account for static.
>>> rebase_donate_argnums((3, 4), (0, 1))
(1, 2)
Args:
donate_argnums: An iterable of ints.
static_argnums: An iterable of ints.
Returns:
A tuple of unique, sorted integer values based on donate_argnums with each
element offset to account for static_argnums.
"""
if not (static_argnums or donate_argnums):
return tuple(sorted(donate_argnums))
static_argnums = sorted(set(static_argnums))
donate_argnums = sorted(set(donate_argnums))
i = j = o = 0
out = []
while j < len(donate_argnums):
if i < len(static_argnums) and static_argnums[i] == donate_argnums[j]:
raise ValueError(f"`static_argnums` {static_argnums} and "
f"`donate_argnums` {donate_argnums} cannot intersect.")
if i < len(static_argnums) and static_argnums[i] < donate_argnums[j]:
o += 1
i += 1
else:
out.append(donate_argnums[j] - o)
j += 1
return tuple(out)
def wrap_hashably(arg):
try:
hash(arg)
except TypeError:
return WrapHashably(arg) # e.g. ndarrays, DeviceArrays
else:
return Hashable(arg)
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
args = [None if arg is unit else arg.val for arg in fixed_args]
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
ans = yield args, kwargs
yield ans
def flatten_axes(treedef, axis_tree):
# given an axis spec tree axis_tree (a pytree with integers and Nones at the
# leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
# the given treedef, build a complete axis spec tree with the same structure
# and return the flattened result
# TODO(mattjj,phawkins): improve this implementation
proxy = object()
dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
axes = []
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
try:
tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
except ValueError as e:
msg = ("axes specification must be a tree prefix of the corresponding "
"value, got specification {} for value {}.")
raise ValueError(msg.format(axis_tree, treedef)) from e
axes = [None if a is proxy else a for a in axes]
assert len(axes) == treedef.num_leaves
return axes