mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move jax.linear_util to jax._src.linear_util
This commit is contained in:
parent
d4fa1a4dfb
commit
4a6bbde409
@ -149,6 +149,7 @@ from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import image as image
|
||||
from jax import lax as lax
|
||||
from jax import linear_util as linear_util
|
||||
from jax import nn as nn
|
||||
from jax import numpy as numpy
|
||||
from jax import ops as ops
|
||||
|
@ -18,7 +18,7 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
import types
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
@ -34,7 +34,7 @@ import numpy as np
|
||||
from contextlib import contextmanager, ExitStack
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
|
@ -27,7 +27,7 @@ from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
|
||||
treedef_children, treedef_is_leaf)
|
||||
from jax._src.tree_util import _replace_nones
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
|
||||
|
||||
from jax._src import traceback_util
|
||||
|
@ -20,7 +20,7 @@ from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, I
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import core
|
||||
from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
|
@ -40,7 +40,7 @@ from jax._src import config as jax_config
|
||||
from jax._src.config import FLAGS, config
|
||||
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
|
||||
|
@ -17,7 +17,7 @@ import operator
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -17,7 +17,7 @@ import inspect
|
||||
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
|
||||
Any)
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_is_leaf, treedef_tuple,
|
||||
|
@ -15,7 +15,7 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
@ -22,7 +22,7 @@ from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
|
@ -32,7 +32,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
import jax.interpreters.ad as ad
|
||||
|
@ -18,7 +18,7 @@ from functools import partial
|
||||
from typing import Callable, Optional, Sequence, Set
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax
|
||||
|
@ -22,7 +22,7 @@ import operator
|
||||
from typing import Callable, Sequence, List, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.core import ConcreteArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
|
@ -19,7 +19,7 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple,
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
|
||||
import jax
|
||||
import weakref
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
|
@ -19,7 +19,7 @@ import operator
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.core import raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -32,7 +32,7 @@ from jax._src import api_util
|
||||
from jax._src import array
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from jax._src import source_info_util
|
||||
|
346
jax/_src/linear_util.py
Normal file
346
jax/_src/linear_util.py
Normal file
@ -0,0 +1,346 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utilities for defining functions composed with transformations.
|
||||
|
||||
For example,
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
|
||||
|
||||
A `WrappedFun` object represents a function `f`, together with a sequence of
|
||||
nested transformations that are to be applied to the positional and keyword
|
||||
arguments at call time and function return values at return time.
|
||||
A transformation can take some static positional arguments that are given
|
||||
at the wrapping time, and may also return some auxiliary output:
|
||||
|
||||
wf, aux_out_thunk = trans1(wf, static_arg)
|
||||
|
||||
We can call the transformed function. First, the transformation is applied
|
||||
to the dynamic args and keyword args to produce new dynamic and keyword args.
|
||||
Then the underlying function is called and the transformation is applied to
|
||||
the results.
|
||||
If there are multiple transformations, they form a stack. The arguments are
|
||||
transformed first with the last applied transformation; the results are
|
||||
transformed first with the first applied transformation.
|
||||
|
||||
res = wf.call_wrapped(dynamic_args, kwargs)
|
||||
# Now `aux_out_thunk()` is the auxiliary output.
|
||||
|
||||
A transformation is written as a generator function that takes zero or more
|
||||
static positional arguments (given when the transformation is instantiated),
|
||||
along with positional and keyword arguments to be transformed.
|
||||
The generator will yield twice:
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def trans1(static_arg, *dynamic_args, **kwargs):
|
||||
...
|
||||
# First yield: pair of transformed (args, kwargs). Get back the results.
|
||||
results = yield (new_dynamic_args, new_kwargs)
|
||||
...
|
||||
# Second yield: pair of (transformed results, and auxiliary output)
|
||||
yield new_results, auxiliary_output
|
||||
|
||||
|
||||
`WrappedFun` objects explicitly represent the set of transformations so that
|
||||
they can be used as dictionary keys for memoization. `WrappedFun` objects
|
||||
compare as equal only if they compute the same function. The static and the
|
||||
dynamic positional arguments for the generators, and also the auxiliary output
|
||||
data must be immutable, because it will be stored in function memoization tables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Tuple, Callable
|
||||
import weakref
|
||||
|
||||
from jax.tree_util import tree_map
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
from jax._src.util import curry
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
class StoreException(Exception): pass
|
||||
|
||||
|
||||
class EmptyStoreValue: pass
|
||||
_EMPTY_STORE_VALUE = EmptyStoreValue()
|
||||
|
||||
class Store:
|
||||
"""Storage for a value, with checks for overwriting or reading empty store."""
|
||||
__slots__ = ("_val",)
|
||||
|
||||
def __init__(self):
|
||||
self._val = _EMPTY_STORE_VALUE
|
||||
|
||||
def store(self, val):
|
||||
if self._val is not _EMPTY_STORE_VALUE:
|
||||
raise StoreException("Store occupied")
|
||||
self._val = val
|
||||
|
||||
def reset(self):
|
||||
# This should only be called in exceptional circumstances (e.g. debugging).
|
||||
self._val = _EMPTY_STORE_VALUE
|
||||
|
||||
@property
|
||||
def val(self):
|
||||
if not self:
|
||||
raise StoreException("Store empty")
|
||||
return self._val
|
||||
|
||||
def __nonzero__(self):
|
||||
return self._val is not _EMPTY_STORE_VALUE
|
||||
|
||||
__bool__ = __nonzero__
|
||||
|
||||
|
||||
class WrappedFun:
|
||||
"""Represents a function `f` to which `transforms` are to be applied.
|
||||
|
||||
Args:
|
||||
f: the function to be transformed.
|
||||
transforms: a list of `(gen, gen_static_args)` tuples representing
|
||||
transformations to apply to `f.` Here `gen` is a generator function and
|
||||
`gen_static_args` is a tuple of static arguments for the generator. See
|
||||
description at the start of this module for the expected behavior of the
|
||||
generator.
|
||||
stores: a list of out_store for the auxiliary output of the `transforms`.
|
||||
params: extra parameters to pass as keyword arguments to `f`, along with the
|
||||
transformed keyword arguments.
|
||||
"""
|
||||
__slots__ = ("f", "transforms", "stores", "params", "in_type")
|
||||
|
||||
def __init__(self, f, transforms, stores, params, in_type):
|
||||
self.f = f
|
||||
self.transforms = transforms
|
||||
self.stores = stores
|
||||
self.params = params
|
||||
self.in_type = in_type
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return getattr(self.f, '__name__', '<unnamed wrapped function>')
|
||||
|
||||
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
|
||||
"""Add another transform and its store."""
|
||||
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params, None)
|
||||
|
||||
def populate_stores(self, stores):
|
||||
"""Copy the values from the `stores` into `self.stores`."""
|
||||
for self_store, other_store in zip(self.stores, stores):
|
||||
if self_store is not None:
|
||||
self_store.store(other_store.val)
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
"""Calls the underlying function, applying the transforms.
|
||||
|
||||
The positional `args` and keyword `kwargs` are passed to the first
|
||||
transformation generator.
|
||||
"""
|
||||
stack = []
|
||||
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
|
||||
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
|
||||
args, kwargs = next(gen)
|
||||
stack.append((gen, out_store))
|
||||
gen = gen_static_args = out_store = None
|
||||
|
||||
try:
|
||||
ans = self.f(*args, **dict(self.params, **kwargs))
|
||||
except:
|
||||
# Some transformations yield from inside context managers, so we have to
|
||||
# interrupt them before reraising the exception. Otherwise they will only
|
||||
# get garbage-collected at some later time, running their cleanup tasks
|
||||
# only after this exception is handled, which can corrupt the global
|
||||
# state.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
|
||||
args = kwargs = None
|
||||
while stack:
|
||||
gen, out_store = stack.pop()
|
||||
try:
|
||||
ans = gen.send(ans)
|
||||
except:
|
||||
# As above does for the first half of the transformation, exceptions
|
||||
# raised in the second half of the transformation also require us to
|
||||
# clean up references here.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
if out_store is not None:
|
||||
ans, side = ans
|
||||
out_store.store(side)
|
||||
|
||||
return ans
|
||||
|
||||
def __repr__(self):
|
||||
def transform_to_str(x):
|
||||
i, (gen, args) = x
|
||||
return f"{i} : {fun_name(gen)} {fun_name(args)}"
|
||||
transformation_stack = map(transform_to_str, enumerate(self.transforms))
|
||||
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.f, self.transforms, self.params, self.in_type))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.f == other.f and self.transforms == other.transforms and
|
||||
self.params == other.params and self.in_type == other.in_type)
|
||||
|
||||
@curry
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
"""Adds one more transformation to a WrappedFun.
|
||||
Args:
|
||||
gen: the transformation generator function
|
||||
fun: a WrappedFun on which to apply the transformation
|
||||
gen_static_args: static args for the generator function
|
||||
"""
|
||||
return fun.wrap(gen, gen_static_args, None)
|
||||
|
||||
@curry
|
||||
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
|
||||
"""Adds one more transformation with auxiliary output to a WrappedFun."""
|
||||
out_store = Store()
|
||||
out_thunk = lambda: out_store.val
|
||||
return fun.wrap(gen, gen_static_args, out_store), out_thunk
|
||||
|
||||
def fun_name(f):
|
||||
try:
|
||||
return f.__name__
|
||||
except:
|
||||
return str(f)
|
||||
|
||||
def wrap_init(f, params=None) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
return WrappedFun(f, (), (), params, None)
|
||||
|
||||
|
||||
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
|
||||
assert f.in_type is None
|
||||
if in_type is None:
|
||||
return f
|
||||
_check_input_type(in_type)
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
|
||||
|
||||
def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
|
||||
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
|
||||
|
||||
def valid_size(d) -> bool:
|
||||
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
|
||||
return True
|
||||
return (isinstance(d, (int, core.DBIdx, core.DArray)) and
|
||||
(not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape))
|
||||
assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
|
||||
for d in a.shape)
|
||||
|
||||
# Check that all DBIdx point to positions to the left of the input on which
|
||||
# they appear.
|
||||
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
|
||||
if isinstance(aval, core.DShapedArray) for d in aval.shape
|
||||
if isinstance(d, core.DBIdx))
|
||||
|
||||
# Check that all implicit arguments have at least one DBIdx pointing to them.
|
||||
provided = [e for _, e in in_type]
|
||||
for aval, _ in in_type:
|
||||
if type(aval) is core.DShapedArray:
|
||||
for d in aval.shape:
|
||||
if isinstance(d, core.DBIdx):
|
||||
provided[d.val] = True
|
||||
assert all(provided)
|
||||
|
||||
|
||||
def cache(call: Callable):
|
||||
"""Memoization decorator for functions taking a WrappedFun as first argument.
|
||||
|
||||
Args:
|
||||
call: a Python callable that takes a WrappedFun as its first argument. The
|
||||
underlying transforms and params on the WrappedFun are used as part of the
|
||||
memoization cache key.
|
||||
|
||||
Returns:
|
||||
A memoized version of ``call``.
|
||||
"""
|
||||
fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, {})
|
||||
if config.jax_check_tracer_leaks:
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
|
||||
config.x64_enabled, config.jax_default_device,
|
||||
config._trace_context())
|
||||
else:
|
||||
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
|
||||
config.jax_default_device, config._trace_context())
|
||||
result = cache.get(key, None)
|
||||
if result is not None:
|
||||
ans, stores = result
|
||||
fun.populate_stores(stores)
|
||||
else:
|
||||
ans = call(fun, *args)
|
||||
cache[key] = (ans, fun.stores)
|
||||
|
||||
return ans
|
||||
|
||||
def _evict_function(f):
|
||||
fun_caches.pop(f, None)
|
||||
|
||||
memoized_fun.cache_clear = fun_caches.clear # type: ignore
|
||||
memoized_fun.evict_function = _evict_function # type: ignore
|
||||
|
||||
return memoized_fun
|
||||
|
||||
@partial(partial, tree_map)
|
||||
def _copy_main_traces(x):
|
||||
if isinstance(x, core.MainTrace):
|
||||
return core.MainTrace(x.level, x.trace_type, **x.payload)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
@transformation
|
||||
def hashable_partial(x, *args):
|
||||
ans = yield (x,) + args, {}
|
||||
yield ans
|
||||
|
||||
|
||||
def merge_linear_aux(aux1, aux2):
|
||||
try:
|
||||
out1 = aux1()
|
||||
except StoreException:
|
||||
# store 1 was not occupied, so store 2 better be
|
||||
try:
|
||||
out2 = aux2()
|
||||
except StoreException:
|
||||
raise StoreException("neither store occupied") from None
|
||||
else:
|
||||
return False, out2
|
||||
else:
|
||||
# store 1 was occupied, so let's check store 2 is not occupied
|
||||
try:
|
||||
out2 = aux2()
|
||||
except StoreException:
|
||||
return True, out1
|
||||
else:
|
||||
raise StoreException("both stores occupied")
|
@ -29,7 +29,7 @@ from jax._src.sharding import (
|
||||
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding)
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax._src import array
|
||||
from jax._src.api import (_check_callable, _check_arg, FLAGS, _resolve_argnums,
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
|
||||
|
@ -24,7 +24,7 @@ from jax.core import Trace, Tracer, jaxpr_as_fun
|
||||
from jax import lax
|
||||
from jax import custom_derivatives as cd
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.util import safe_map, wraps, split_list
|
||||
from jax._src.lax import control_flow as lcf
|
||||
|
||||
|
@ -17,7 +17,7 @@ from typing import Any, Callable, Tuple
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.experimental import pjit
|
||||
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -25,7 +25,7 @@ import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax import core, custom_derivatives
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import random, tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import maps
|
||||
|
@ -62,7 +62,6 @@ import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import xla
|
||||
import jax.linear_util as lu
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten)
|
||||
@ -70,6 +69,7 @@ from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
from jax._src import ad_util
|
||||
from jax._src import dispatch
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.util import unzip2
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ from enum import Enum
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax._src.api import _check_callable, _check_arg
|
||||
from jax._src import dispatch
|
||||
|
@ -38,7 +38,7 @@ from jax._src.numpy.util import _promote_dtypes_inexact
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax.flatten_util import ravel_pytree
|
||||
from jax.tree_util import tree_leaves, tree_map
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
@ -54,7 +54,7 @@ import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
|
||||
import jax.numpy as jnp
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
|
@ -19,7 +19,7 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
|
@ -30,7 +30,7 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize, weakref_lru_cache)
|
||||
|
@ -29,7 +29,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
@ -26,7 +26,7 @@ from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
|
@ -47,7 +47,7 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -12,335 +12,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utilities for defining functions composed with transformations.
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
For example,
|
||||
# TODO(jakevdp): deprecate these and remove this module.
|
||||
|
||||
from jax import linear_util as lu
|
||||
|
||||
wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
|
||||
|
||||
A `WrappedFun` object represents a function `f`, together with a sequence of
|
||||
nested transformations that are to be applied to the positional and keyword
|
||||
arguments at call time and function return values at return time.
|
||||
A transformation can take some static positional arguments that are given
|
||||
at the wrapping time, and may also return some auxiliary output:
|
||||
|
||||
wf, aux_out_thunk = trans1(wf, static_arg)
|
||||
|
||||
We can call the transformed function. First, the transformation is applied
|
||||
to the dynamic args and keyword args to produce new dynamic and keyword args.
|
||||
Then the underlying function is called and the transformation is applied to
|
||||
the results.
|
||||
If there are multiple transformations, they form a stack. The arguments are
|
||||
transformed first with the last applied transformation; the results are
|
||||
transformed first with the first applied transformation.
|
||||
|
||||
res = wf.call_wrapped(dynamic_args, kwargs)
|
||||
# Now `aux_out_thunk()` is the auxiliary output.
|
||||
|
||||
A transformation is written as a generator function that takes zero or more
|
||||
static positional arguments (given when the transformation is instantiated),
|
||||
along with positional and keyword arguments to be transformed.
|
||||
The generator will yield twice:
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def trans1(static_arg, *dynamic_args, **kwargs):
|
||||
...
|
||||
# First yield: pair of transformed (args, kwargs). Get back the results.
|
||||
results = yield (new_dynamic_args, new_kwargs)
|
||||
...
|
||||
# Second yield: pair of (transformed results, and auxiliary output)
|
||||
yield new_results, auxiliary_output
|
||||
|
||||
|
||||
`WrappedFun` objects explicitly represent the set of transformations so that
|
||||
they can be used as dictionary keys for memoization. `WrappedFun` objects
|
||||
compare as equal only if they compute the same function. The static and the
|
||||
dynamic positional arguments for the generators, and also the auxiliary output
|
||||
data must be immutable, because it will be stored in function memoization tables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Tuple, Callable
|
||||
import weakref
|
||||
|
||||
from jax.tree_util import tree_map
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
from jax._src.util import curry
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
class StoreException(Exception): pass
|
||||
|
||||
|
||||
class EmptyStoreValue: pass
|
||||
_EMPTY_STORE_VALUE = EmptyStoreValue()
|
||||
|
||||
class Store:
|
||||
"""Storage for a value, with checks for overwriting or reading empty store."""
|
||||
__slots__ = ("_val",)
|
||||
|
||||
def __init__(self):
|
||||
self._val = _EMPTY_STORE_VALUE
|
||||
|
||||
def store(self, val):
|
||||
if self._val is not _EMPTY_STORE_VALUE:
|
||||
raise StoreException("Store occupied")
|
||||
self._val = val
|
||||
|
||||
def reset(self):
|
||||
# This should only be called in exceptional circumstances (e.g. debugging).
|
||||
self._val = _EMPTY_STORE_VALUE
|
||||
|
||||
@property
|
||||
def val(self):
|
||||
if not self:
|
||||
raise StoreException("Store empty")
|
||||
return self._val
|
||||
|
||||
def __nonzero__(self):
|
||||
return self._val is not _EMPTY_STORE_VALUE
|
||||
|
||||
__bool__ = __nonzero__
|
||||
|
||||
|
||||
class WrappedFun:
|
||||
"""Represents a function `f` to which `transforms` are to be applied.
|
||||
|
||||
Args:
|
||||
f: the function to be transformed.
|
||||
transforms: a list of `(gen, gen_static_args)` tuples representing
|
||||
transformations to apply to `f.` Here `gen` is a generator function and
|
||||
`gen_static_args` is a tuple of static arguments for the generator. See
|
||||
description at the start of this module for the expected behavior of the
|
||||
generator.
|
||||
stores: a list of out_store for the auxiliary output of the `transforms`.
|
||||
params: extra parameters to pass as keyword arguments to `f`, along with the
|
||||
transformed keyword arguments.
|
||||
"""
|
||||
__slots__ = ("f", "transforms", "stores", "params", "in_type")
|
||||
|
||||
def __init__(self, f, transforms, stores, params, in_type):
|
||||
self.f = f
|
||||
self.transforms = transforms
|
||||
self.stores = stores
|
||||
self.params = params
|
||||
self.in_type = in_type
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return getattr(self.f, '__name__', '<unnamed wrapped function>')
|
||||
|
||||
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
|
||||
"""Add another transform and its store."""
|
||||
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
|
||||
(out_store,) + self.stores, self.params, None)
|
||||
|
||||
def populate_stores(self, stores):
|
||||
"""Copy the values from the `stores` into `self.stores`."""
|
||||
for self_store, other_store in zip(self.stores, stores):
|
||||
if self_store is not None:
|
||||
self_store.store(other_store.val)
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
"""Calls the underlying function, applying the transforms.
|
||||
|
||||
The positional `args` and keyword `kwargs` are passed to the first
|
||||
transformation generator.
|
||||
"""
|
||||
stack = []
|
||||
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
|
||||
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
|
||||
args, kwargs = next(gen)
|
||||
stack.append((gen, out_store))
|
||||
gen = gen_static_args = out_store = None
|
||||
|
||||
try:
|
||||
ans = self.f(*args, **dict(self.params, **kwargs))
|
||||
except:
|
||||
# Some transformations yield from inside context managers, so we have to
|
||||
# interrupt them before reraising the exception. Otherwise they will only
|
||||
# get garbage-collected at some later time, running their cleanup tasks
|
||||
# only after this exception is handled, which can corrupt the global
|
||||
# state.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
|
||||
args = kwargs = None
|
||||
while stack:
|
||||
gen, out_store = stack.pop()
|
||||
try:
|
||||
ans = gen.send(ans)
|
||||
except:
|
||||
# As above does for the first half of the transformation, exceptions
|
||||
# raised in the second half of the transformation also require us to
|
||||
# clean up references here.
|
||||
while stack:
|
||||
stack.pop()[0].close()
|
||||
raise
|
||||
if out_store is not None:
|
||||
ans, side = ans
|
||||
out_store.store(side)
|
||||
|
||||
return ans
|
||||
|
||||
def __repr__(self):
|
||||
def transform_to_str(x):
|
||||
i, (gen, args) = x
|
||||
return f"{i} : {fun_name(gen)} {fun_name(args)}"
|
||||
transformation_stack = map(transform_to_str, enumerate(self.transforms))
|
||||
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.f, self.transforms, self.params, self.in_type))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.f == other.f and self.transforms == other.transforms and
|
||||
self.params == other.params and self.in_type == other.in_type)
|
||||
|
||||
@curry
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
"""Adds one more transformation to a WrappedFun.
|
||||
Args:
|
||||
gen: the transformation generator function
|
||||
fun: a WrappedFun on which to apply the transformation
|
||||
gen_static_args: static args for the generator function
|
||||
"""
|
||||
return fun.wrap(gen, gen_static_args, None)
|
||||
|
||||
@curry
|
||||
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
|
||||
"""Adds one more transformation with auxiliary output to a WrappedFun."""
|
||||
out_store = Store()
|
||||
out_thunk = lambda: out_store.val
|
||||
return fun.wrap(gen, gen_static_args, out_store), out_thunk
|
||||
|
||||
def fun_name(f):
|
||||
try:
|
||||
return f.__name__
|
||||
except:
|
||||
return str(f)
|
||||
|
||||
def wrap_init(f, params=None) -> WrappedFun:
|
||||
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
|
||||
params = () if params is None else tuple(sorted(params.items()))
|
||||
return WrappedFun(f, (), (), params, None)
|
||||
|
||||
|
||||
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
|
||||
assert f.in_type is None
|
||||
if in_type is None:
|
||||
return f
|
||||
_check_input_type(in_type)
|
||||
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
|
||||
|
||||
def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
|
||||
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
|
||||
|
||||
def valid_size(d) -> bool:
|
||||
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
|
||||
return True
|
||||
return (isinstance(d, (int, core.DBIdx, core.DArray)) and
|
||||
(not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape))
|
||||
assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
|
||||
for d in a.shape)
|
||||
|
||||
# Check that all DBIdx point to positions to the left of the input on which
|
||||
# they appear.
|
||||
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
|
||||
if isinstance(aval, core.DShapedArray) for d in aval.shape
|
||||
if isinstance(d, core.DBIdx))
|
||||
|
||||
# Check that all implicit arguments have at least one DBIdx pointing to them.
|
||||
provided = [e for _, e in in_type]
|
||||
for aval, _ in in_type:
|
||||
if type(aval) is core.DShapedArray:
|
||||
for d in aval.shape:
|
||||
if isinstance(d, core.DBIdx):
|
||||
provided[d.val] = True
|
||||
assert all(provided)
|
||||
|
||||
|
||||
def cache(call: Callable):
|
||||
"""Memoization decorator for functions taking a WrappedFun as first argument.
|
||||
|
||||
Args:
|
||||
call: a Python callable that takes a WrappedFun as its first argument. The
|
||||
underlying transforms and params on the WrappedFun are used as part of the
|
||||
memoization cache key.
|
||||
|
||||
Returns:
|
||||
A memoized version of ``call``.
|
||||
"""
|
||||
fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
||||
|
||||
def memoized_fun(fun: WrappedFun, *args):
|
||||
cache = fun_caches.setdefault(fun.f, {})
|
||||
if config.jax_check_tracer_leaks:
|
||||
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
|
||||
config.x64_enabled, config.jax_default_device,
|
||||
config._trace_context())
|
||||
else:
|
||||
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
|
||||
config.jax_default_device, config._trace_context())
|
||||
result = cache.get(key, None)
|
||||
if result is not None:
|
||||
ans, stores = result
|
||||
fun.populate_stores(stores)
|
||||
else:
|
||||
ans = call(fun, *args)
|
||||
cache[key] = (ans, fun.stores)
|
||||
|
||||
return ans
|
||||
|
||||
def _evict_function(f):
|
||||
fun_caches.pop(f, None)
|
||||
|
||||
memoized_fun.cache_clear = fun_caches.clear # type: ignore
|
||||
memoized_fun.evict_function = _evict_function # type: ignore
|
||||
|
||||
return memoized_fun
|
||||
|
||||
@partial(partial, tree_map)
|
||||
def _copy_main_traces(x):
|
||||
if isinstance(x, core.MainTrace):
|
||||
return core.MainTrace(x.level, x.trace_type, **x.payload)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
@transformation
|
||||
def hashable_partial(x, *args):
|
||||
ans = yield (x,) + args, {}
|
||||
yield ans
|
||||
|
||||
|
||||
def merge_linear_aux(aux1, aux2):
|
||||
try:
|
||||
out1 = aux1()
|
||||
except StoreException:
|
||||
# store 1 was not occupied, so store 2 better be
|
||||
try:
|
||||
out2 = aux2()
|
||||
except StoreException:
|
||||
raise StoreException("neither store occupied") from None
|
||||
else:
|
||||
return False, out2
|
||||
else:
|
||||
# store 1 was occupied, so let's check store 2 is not occupied
|
||||
try:
|
||||
out2 = aux2()
|
||||
except StoreException:
|
||||
return True, out1
|
||||
else:
|
||||
raise StoreException("both stores occupied")
|
||||
from jax._src.linear_util import (
|
||||
EmptyStoreValue as EmptyStoreValue,
|
||||
Store as Store,
|
||||
StoreException as StoreException,
|
||||
WrappedFun as WrappedFun,
|
||||
_EMPTY_STORE_VALUE as _EMPTY_STORE_VALUE,
|
||||
_check_input_type as _check_input_type,
|
||||
_copy_main_traces as _copy_main_traces,
|
||||
annotate as annotate,
|
||||
annotations as annotations,
|
||||
cache as cache,
|
||||
config as config,
|
||||
core as core,
|
||||
curry as curry,
|
||||
fun_name as fun_name,
|
||||
hashable_partial as hashable_partial,
|
||||
merge_linear_aux as merge_linear_aux,
|
||||
traceback_util as traceback_util,
|
||||
transformation as transformation,
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
tree_map as tree_map,
|
||||
wrap_init as wrap_init,
|
||||
)
|
||||
|
@ -35,6 +35,7 @@ per-file-ignores =
|
||||
jax/dtypes.py:F401
|
||||
jax/errors.py:F401
|
||||
jax/flatten_util.py:F401
|
||||
jax/linear_util.py:F401
|
||||
jax/prng.py:F401
|
||||
jax/profiler.py:F401
|
||||
jax/random.py:F401
|
||||
|
@ -65,7 +65,7 @@ from jax._src import prng
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
import jax._src.util as jax_util
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
|
||||
@ -4746,7 +4746,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
@jax_util.curry
|
||||
def call(f, *args):
|
||||
return jax.core.call(
|
||||
jax.linear_util.wrap_init(lambda *args: [f(*args)]),
|
||||
lu.wrap_init(lambda *args: [f(*args)]),
|
||||
*args, name='foo')[0]
|
||||
|
||||
f = call(add_one)
|
||||
|
@ -27,7 +27,7 @@ import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
|
@ -21,7 +21,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.experimental import maps
|
||||
|
@ -18,7 +18,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
|
Loading…
x
Reference in New Issue
Block a user