rocm_jax/jax/api_util.py

306 lines
11 KiB
Python
Raw Normal View History

2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from functools import partial
from typing import Any, Dict, Iterable, Tuple, Union, Optional
import numpy as np
from . import core
from ._src import dtypes
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap,
tree_structure, treedef_children, treedef_is_leaf)
from ._src.tree_util import _replace_nones
from . import linear_util as lu
2021-03-29 13:52:39 -07:00
from ._src.util import safe_map, WrapHashably, WrapKwArgs, Hashable
from .core import unit
2018-11-17 18:03:33 -08:00
from ._src import traceback_util
# Pass this file's path to traceback_util.register_exclusion.
# NOTE(review): presumably this hides frames from this file in filtered
# tracebacks -- confirm against traceback_util.
traceback_util.register_exclusion(__file__)
# From here on, `map` in this module is safe_map (imported from ._src.util),
# shadowing the builtin.
map = safe_map
2018-11-17 18:03:33 -08:00
def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
"""Ensure x is either an index or a tuple of indices."""
try:
return operator.index(x)
except TypeError:
return tuple(map(operator.index, x))
def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
"""Convert x to a tuple of indices."""
try:
return (operator.index(x),)
except TypeError:
return tuple(map(operator.index, x))
2021-03-29 13:52:39 -07:00
def _ensure_str(x: str) -> str:
if not isinstance(x, str):
raise TypeError(f"argument is not a string: {x}")
return x
def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
"""Convert x to a tuple of strings."""
if isinstance(x, str):
return (x,)
else:
return tuple(map(_ensure_str, x))
@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  """Adapt a function taking (args, kwargs) pytrees to flat inputs and outputs.

  ``args_flat`` is unflattened with ``in_tree`` into a ``(args, kwargs)`` pair,
  which is yielded to the wrapped function. The value sent back is flattened
  with ``tree_flatten`` and the resulting ``(leaves, treedef)`` pair is yielded.
  NOTE(review): presumably ``lu.transformation_with_aux`` treats the second
  element of that pair as the auxiliary output -- confirm in linear_util.
  """
  positional, keyword = tree_unflatten(in_tree, args_flat)
  result = yield positional, keyword
  yield tree_flatten(result)
2019-07-26 23:17:21 -04:00
def apply_flat_fun(fun, io_tree, *py_args):
  """Call a flat-argument function with pytree positional arguments.

  Args:
    fun: callable accepting the flattened leaves as positional arguments and
      returning the flat outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of tree definitions.
    *py_args: positional pytree arguments; they are flattened together with an
      empty kwargs dict, i.e. as ``(py_args, {})``.

  Returns:
    The result of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the structure of ``(py_args, {})`` differs from
      ``in_tree_expected``.
  """
  in_tree_expected, out_tree = io_tree
  args, in_tree = tree_flatten((py_args, {}))
  if in_tree != in_tree_expected:
    raise TypeError(f"Expected {in_tree_expected}, got {in_tree}")
  ans = fun(*args)
  return tree_unflatten(out_tree, ans)
@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  """Adapt a positional-only pytree function to flat inputs and outputs.

  ``args_flat`` is unflattened with ``in_tree`` into the positional arguments,
  which are yielded together with an empty kwargs dict. The value sent back is
  flattened with ``tree_flatten`` and the resulting ``(leaves, treedef)`` pair
  is yielded.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  yield tree_flatten(result)
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  """Call a flat-argument function with a pytree of positional arguments.

  Args:
    fun: callable accepting the flattened leaves as positional arguments and
      returning the flat outputs.
    io_tree: pair ``(in_tree_expected, out_tree)`` of tree definitions.
    py_args: pytree of arguments; flattened as-is (no kwargs dict is added).

  Returns:
    The result of ``fun`` unflattened with ``out_tree``.

  Raises:
    TypeError: if the structure of ``py_args`` differs from
      ``in_tree_expected``.
  """
  expected_tree, out_tree = io_tree
  leaves, actual_tree = tree_flatten(py_args)
  if actual_tree != expected_tree:
    raise TypeError(f"Expected {expected_tree}, got {actual_tree}")
  flat_result = fun(*leaves)
  return tree_unflatten(out_tree, flat_result)
# Annotation-only alias; it is plain `Any`, so it documents intent without
# constraining the type.
PyTreeDef = Any
def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  """Recover the input tree that ``fn`` was flattened with, if there is one.

  Returns:
    ``(in_tree, has_kwargs)`` when exactly one entry of ``fn.transforms`` is
    ``flatten_fun`` or ``flatten_fun_nokwargs``; ``has_kwargs`` is true for
    ``flatten_fun``. Returns ``None`` when the lookup raises ``ValueError``,
    e.g. when there is no such entry or more than one.
  """
  # This implementation relies on internal details of linear_util.py's
  # WrappedFun, but it's for the worthy cause of better user error messages.
  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
  # core.eval_jaxpr encounters a call primitive (though at that point we're just
  # round-tripping jaxprs and the user errors in question are impossible).
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  with_kwargs_gen = flatten_fun.args[0]
  nokwargs_gen = flatten_fun_nokwargs.args[0]
  flat_xforms = {with_kwargs_gen, nokwargs_gen}
  try:
    # The one-element unpacking raises ValueError unless exactly one transform
    # matches.
    (in_tree, has_kwargs), = ((xform_args[0], xform is with_kwargs_gen)
                              for xform, xform_args in fn.transforms
                              if xform in flat_xforms)
  except ValueError:
    return None
  return in_tree, has_kwargs
@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  """Flatten a positional-only function that returns an ``(ans, aux)`` pair.

  ``args_flat`` is unflattened with ``in_tree`` and yielded with an empty
  kwargs dict. The value sent back must be a two-element list or tuple; both
  elements are flattened separately, and the flat values and the tree
  definitions are yielded as two pairs.

  Raises:
    TypeError: if the wrapped function's result is not a two-element list or
      tuple.
  """
  positional = tree_unflatten(in_tree, args_flat)
  result = yield positional, {}
  is_pair = isinstance(result, (list, tuple)) and len(result) == 2
  if not is_pair:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(result)} with value {repr(result)}")
  out, aux = result
  out_flat, out_tree = tree_flatten(out)
  aux_flat, aux_tree = tree_flatten(aux)
  yield (out_flat, aux_flat), (out_tree, aux_tree)
def argnums_partial(f, dyn_argnums, args):
  """Partially apply ``f``, leaving only the ``dyn_argnums`` positions open.

  Args:
    f: function wrapper handed to ``_argnums_partial``.
    dyn_argnums: an index, or an iterable of indices, into ``args`` naming the
      arguments that stay dynamic.
    args: all positional arguments.

  Returns:
    A pair ``(partial_f, dyn_args)``: ``f`` transformed by ``_argnums_partial``
    with the remaining arguments fixed, and the tuple of dynamic arguments in
    ``dyn_argnums`` order.
  """
  dyn_argnums = _ensure_index_tuple(dyn_argnums)
  # Dynamic slots hold the `unit` placeholder; every other argument is wrapped
  # by wrap_hashably.
  fixed_args = tuple(unit if i in dyn_argnums else wrap_hashably(arg)
                     for i, arg in enumerate(args))
  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, fixed_args), dyn_args
def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any, ...], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums.

  Args:
    f: the wrapped function to partially apply.
    static_argnums: positions in ``args`` that are held fixed.
    args: all positional arguments.
    allow_invalid: if true, entries of ``static_argnums`` that are
      ``>= len(args)`` are skipped. Otherwise such an entry propagates the
      ``IndexError`` raised by ``args[i]``.

  Returns:
    ``(f, args)`` unchanged when ``static_argnums`` is empty. Otherwise a pair
    of ``f`` transformed by ``_argnums_partial`` and the tuple of dynamic
    (non-static) arguments.

  Raises:
    ValueError: if a static argument is not hashable.
  """
  if not static_argnums:
    return f, args
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)
  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      # Non-hashable static arguments used to be wrapped and compared by
      # identity, so every new instance silently triggered a recompilation.
      # They are rejected instead. `from None` drops the redundant TypeError
      # context from the traceback.
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable."
      ) from None
    else:
      fixed_args[i] = Hashable(static_arg)  # type: ignore
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  """Rebuild the full positional argument list from fixed and dynamic parts.

  Each entry of ``fixed_args`` is either the ``unit`` placeholder (replaced by
  ``None`` and then overwritten by a dynamic argument) or a wrapper whose
  ``.val`` is the fixed value. ``dyn_args`` are written into the positions
  listed in ``dyn_argnums``. The merged list and ``kwargs`` are yielded, and
  the wrapped function's result is yielded back unchanged.
  """
  merged = []
  for fixed in fixed_args:
    merged.append(None if fixed is unit else fixed.val)
  for position, value in zip(dyn_argnums, dyn_args):
    merged[position] = value
  result = yield merged, kwargs
  yield result
2021-03-29 13:52:39 -07:00
def argnames_partial(f, dyn_argnames, kwargs):
  """Partially apply ``f``, leaving only the ``dyn_argnames`` keywords open.

  Args:
    f: function wrapper handed to ``_argnames_partial``.
    dyn_argnames: a string, or an iterable of strings, naming the keyword
      arguments that stay dynamic. Every name must be a key of ``kwargs``;
      a missing name raises ``KeyError``.
    kwargs: all keyword arguments.

  Returns:
    A pair ``(partial_f, dyn_kwargs)``: ``f`` transformed by
    ``_argnames_partial`` with the remaining keywords fixed, and the dict of
    dynamic keyword arguments.
  """
  dyn_argnames = _ensure_str_tuple(dyn_argnames)
  # This must be a dict, not a tuple of pairs: _argnames_partial iterates
  # `fixed_kwargs.val.items()`, matching what argnames_partial_except builds.
  fixed_kwargs = {k: unit if k in dyn_argnames else wrap_hashably(v)
                  for k, v in kwargs.items()}
  dyn_kwargs = {k: kwargs[k] for k in dyn_argnames}
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  """Partially apply ``f`` over static keyword arguments, checking hashability.

  Args:
    f: the wrapped function to partially apply.
    static_argnames: names of keyword arguments that are held fixed.
    kwargs: all keyword arguments.

  Returns:
    ``(f, kwargs)`` unchanged when ``static_argnames`` is empty. Otherwise a
    pair of ``f`` transformed by ``_argnames_partial`` and the dict of dynamic
    (non-static) keyword arguments.

  Raises:
    ValueError: if a static keyword argument is not hashable.
  """
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}
  fixed_kwargs: Dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      # Placeholder; the real value is supplied later through dyn_kwargs.
      fixed_kwargs[k] = unit
    else:
      try:
        hash(arg)
      except TypeError:
        # `from None` drops the redundant TypeError context from the traceback.
        raise ValueError(
            "Non-hashable static arguments are not supported, as this can lead "
            f"to unexpected cache-misses. Static argument (name {k}) of type "
            f"{type(arg)} for function {f.__name__} is non-hashable."
        ) from None
      else:
        fixed_kwargs[k] = Hashable(arg)  # type: ignore
  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  """Rebuild the full keyword arguments from fixed and dynamic parts.

  Each value in ``fixed_kwargs.val`` is either the ``unit`` placeholder
  (replaced by ``None``) or a wrapper whose ``.val`` is the fixed value.
  ``dyn_kwargs`` then override entries with the same key. ``args`` and the
  merged dict are yielded, and the wrapped function's result is yielded back
  unchanged.
  """
  merged = {}
  for name, wrapped in fixed_kwargs.val.items():
    merged[name] = None if wrapped is unit else wrapped.val
  merged.update(dyn_kwargs)
  result = yield args, merged
  yield result
Add support for buffer donation in `jit` and `pmap`. (#2936) For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... 
return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.
2020-05-31 23:00:16 +01:00
def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Return one boolean per leaf of ``args`` followed by ``kwargs``.

  Every leaf of the positional argument at index ``i`` gets ``True`` when
  ``i`` is in ``donate_argnums`` and ``False`` otherwise. Leaves of ``kwargs``
  always get ``False``.
  """
  flags = []
  for position, arg in enumerate(args):
    is_donated = position in donate_argnums
    flags += [is_donated] * tree_structure(arg).num_leaves
  flags += [False] * tree_structure(kwargs).num_leaves
  return tuple(flags)
def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Shifts donate to account for static.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.

  Raises:
    ValueError: if an index appears in both ``donate_argnums`` and
      ``static_argnums``.
  """
  if not (static_argnums or donate_argnums):
    return tuple(sorted(donate_argnums))

  static_argnums = sorted(set(static_argnums))
  donate_argnums = sorted(set(donate_argnums))
  num_static = len(static_argnums)

  # `shift` counts the static indices strictly below the current donated
  # index. Both lists are sorted, so it only ever moves forward.
  shift = 0
  rebased = []
  for donated in donate_argnums:
    while shift < num_static and static_argnums[shift] < donated:
      shift += 1
    if shift < num_static and static_argnums[shift] == donated:
      raise ValueError(f"`static_argnums` {static_argnums} and "
                       f"`donate_argnums` {donate_argnums} cannot intersect.")
    rebased.append(donated - shift)
  return tuple(rebased)
def wrap_hashably(arg):
  """Wrap ``arg`` in ``Hashable`` if ``hash(arg)`` works, else in ``WrapHashably``."""
  try:
    hash(arg)
  except TypeError:
    return WrapHashably(arg)  # e.g. ndarrays, DeviceArrays
  return Hashable(arg)
def flatten_axes(name, treedef, axis_tree, *, kws=False):
  """Expand an axis-spec tree prefix into one entry per leaf of ``treedef``.

  Args:
    name: label used at the start of the error message.
    treedef: tree definition of the value the spec applies to.
    axis_tree: axis spec; a pytree with integers and Nones at the leaves (the
      Nones are treated as leaves) that must be a tree prefix of ``treedef``.
    kws: if true, the error message reports only the first child of
      ``treedef`` and the first element of ``axis_tree``.

  Returns:
    A flat list with one axis entry (or ``None``) per leaf of ``treedef``.

  Raises:
    ValueError: if ``axis_tree`` is not a tree prefix of ``treedef``.
  """
  # TODO(mattjj,phawkins): improve this implementation
  proxy = object()  # stands in for None so that None can be a leaf
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []

  def add_leaves(axis, subtree):
    # Repeat the spec entry once per leaf underneath it in the full tree.
    axes.extend([axis] * len(tree_flatten(subtree)[0]))

  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we adapt the error
      # message to be only about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes
def _dtype(x):
  """Return ``dtypes.result_type(x)``, retrying on ``x.dtype`` after a ``ValueError``."""
  try:
    result = dtypes.result_type(x)
  except ValueError:
    result = dtypes.result_type(x.dtype)
  return result
def shaped_abstractify(x):
  """Return a shaped abstract value for ``x``.

  ``core.raise_to_shaped(core.get_aval(x))`` is tried first. If that raises
  ``TypeError``, a ``core.ShapedArray`` is built from ``np.shape(x)``,
  ``_dtype(x)`` and the optional ``weak_type`` / ``named_shape`` attributes of
  ``x`` (defaulting to ``False`` and ``{}``).
  """
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    # No abstract value available via core; build one from x's attributes.
    pass
  is_weak = getattr(x, 'weak_type', False)
  axis_names = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=is_weak,
                          named_shape=axis_names)
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  """Return ``fun`` unchanged; ``tag`` is unused by this default implementation."""
  return fun