Move jax.linear_util to jax._src.linear_util

This commit is contained in:
Jake VanderPlas 2022-12-20 14:49:27 -08:00
parent d4fa1a4dfb
commit 4a6bbde409
40 changed files with 411 additions and 367 deletions

View File

@ -149,6 +149,7 @@ from jax import dtypes as dtypes
from jax import errors as errors
from jax import image as image
from jax import lax as lax
from jax import linear_util as linear_util
from jax import nn as nn
from jax import numpy as numpy
from jax import ops as ops

View File

@ -18,7 +18,7 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
import types
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir

View File

@ -34,7 +34,7 @@ import numpy as np
from contextlib import contextmanager, ExitStack
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import stages
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,

View File

@ -27,7 +27,7 @@ from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
treedef_children, treedef_is_leaf)
from jax._src.tree_util import _replace_nones
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
from jax._src import traceback_util

View File

@ -20,7 +20,7 @@ from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, I
import jax
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src import core
from jax._src import prng
from jax._src import source_info_util

View File

@ -40,7 +40,7 @@ from jax._src import config as jax_config
from jax._src.config import FLAGS, config
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,

View File

@ -17,7 +17,7 @@ import operator
from typing import Callable, Optional
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -17,7 +17,7 @@ import inspect
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
Any)
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,

View File

@ -15,7 +15,7 @@
import functools
from typing import Any, Callable, Optional, Tuple
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe

View File

@ -22,7 +22,7 @@ from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
from jax import tree_util
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.experimental import pjit
from jax.interpreters import ad

View File

@ -32,7 +32,7 @@ import warnings
import numpy as np
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
import jax.interpreters.ad as ad

View File

@ -18,7 +18,7 @@ from functools import partial
from typing import Callable, Optional, Sequence, Set
from jax import core
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.lax import lax

View File

@ -22,7 +22,7 @@ import operator
from typing import Callable, Sequence, List, Tuple
from jax import core
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad

View File

@ -19,7 +19,7 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple,
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -21,7 +21,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar
import jax
import weakref
from jax import core
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad

View File

@ -19,7 +19,7 @@ import operator
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.core import raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -32,7 +32,7 @@ from jax._src import api_util
from jax._src import array
from jax._src import device_array
from jax._src import dispatch
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src import dtypes
from jax import tree_util
from jax._src import source_info_util

346
jax/_src/linear_util.py Normal file
View File

@ -0,0 +1,346 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for defining functions composed with transformations.
For example,
from jax._src import linear_util as lu
wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
A `WrappedFun` object represents a function `f`, together with a sequence of
nested transformations that are to be applied to the positional and keyword
arguments at call time and function return values at return time.
A transformation can take some static positional arguments that are given
at the wrapping time, and may also return some auxiliary output:
wf, aux_out_thunk = trans1(wf, static_arg)
We can call the transformed function. First, the transformation is applied
to the dynamic args and keyword args to produce new dynamic and keyword args.
Then the underlying function is called and the transformation is applied to
the results.
If there are multiple transformations, they form a stack. The arguments are
transformed first with the last applied transformation; the results are
transformed first with the first applied transformation.
res = wf.call_wrapped(dynamic_args, kwargs)
# Now `aux_out_thunk()` is the auxiliary output.
A transformation is written as a generator function that takes zero or more
static positional arguments (given when the transformation is instantiated),
along with positional and keyword arguments to be transformed.
The generator will yield twice:
@lu.transformation_with_aux
def trans1(static_arg, *dynamic_args, **kwargs):
...
# First yield: pair of transformed (args, kwargs). Get back the results.
results = yield (new_dynamic_args, new_kwargs)
...
# Second yield: pair of (transformed results, and auxiliary output)
yield new_results, auxiliary_output
`WrappedFun` objects explicitly represent the set of transformations so that
they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they compute the same function. The static and the
dynamic positional arguments for the generators, and also the auxiliary output
data must be immutable, because it will be stored in function memoization tables.
"""
from __future__ import annotations
from functools import partial
from typing import Any, Tuple, Callable
import weakref
from jax.tree_util import tree_map
from jax.config import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry
traceback_util.register_exclusion(__file__)
class StoreException(Exception): pass
class EmptyStoreValue: pass
_EMPTY_STORE_VALUE = EmptyStoreValue()
class Store:
"""Storage for a value, with checks for overwriting or reading empty store."""
__slots__ = ("_val",)
def __init__(self):
self._val = _EMPTY_STORE_VALUE
def store(self, val):
if self._val is not _EMPTY_STORE_VALUE:
raise StoreException("Store occupied")
self._val = val
def reset(self):
# This should only be called in exceptional circumstances (e.g. debugging).
self._val = _EMPTY_STORE_VALUE
@property
def val(self):
if not self:
raise StoreException("Store empty")
return self._val
def __nonzero__(self):
return self._val is not _EMPTY_STORE_VALUE
__bool__ = __nonzero__
class WrappedFun:
"""Represents a function `f` to which `transforms` are to be applied.
Args:
f: the function to be transformed.
transforms: a list of `(gen, gen_static_args)` tuples representing
transformations to apply to `f.` Here `gen` is a generator function and
`gen_static_args` is a tuple of static arguments for the generator. See
description at the start of this module for the expected behavior of the
generator.
stores: a list of out_store for the auxiliary output of the `transforms`.
params: extra parameters to pass as keyword arguments to `f`, along with the
transformed keyword arguments.
"""
__slots__ = ("f", "transforms", "stores", "params", "in_type")
def __init__(self, f, transforms, stores, params, in_type):
self.f = f
self.transforms = transforms
self.stores = stores
self.params = params
self.in_type = in_type
@property
def __name__(self):
return getattr(self.f, '__name__', '<unnamed wrapped function>')
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
"""Add another transform and its store."""
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
for self_store, other_store in zip(self.stores, stores):
if self_store is not None:
self_store.store(other_store.val)
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
ans = self.f(*args, **dict(self.params, **kwargs))
except:
# Some transformations yield from inside context managers, so we have to
# interrupt them before reraising the exception. Otherwise they will only
# get garbage-collected at some later time, running their cleanup tasks
# only after this exception is handled, which can corrupt the global
# state.
while stack:
stack.pop()[0].close()
raise
args = kwargs = None
while stack:
gen, out_store = stack.pop()
try:
ans = gen.send(ans)
except:
# As above does for the first half of the transformation, exceptions
# raised in the second half of the transformation also require us to
# clean up references here.
while stack:
stack.pop()[0].close()
raise
if out_store is not None:
ans, side = ans
out_store.store(side)
return ans
def __repr__(self):
def transform_to_str(x):
i, (gen, args) = x
return f"{i} : {fun_name(gen)} {fun_name(args)}"
transformation_stack = map(transform_to_str, enumerate(self.transforms))
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
def __hash__(self):
return hash((self.f, self.transforms, self.params, self.in_type))
def __eq__(self, other):
return (self.f == other.f and self.transforms == other.transforms and
self.params == other.params and self.in_type == other.in_type)
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""Adds one more transformation to a WrappedFun.
Args:
gen: the transformation generator function
fun: a WrappedFun on which to apply the transformation
gen_static_args: static args for the generator function
"""
return fun.wrap(gen, gen_static_args, None)
@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
"""Adds one more transformation with auxiliary output to a WrappedFun."""
out_store = Store()
out_thunk = lambda: out_store.val
return fun.wrap(gen, gen_static_args, out_store), out_thunk
def fun_name(f):
try:
return f.__name__
except:
return str(f)
def wrap_init(f, params=None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
def valid_size(d) -> bool:
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
return True
return (isinstance(d, (int, core.DBIdx, core.DArray)) and
(not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape))
assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
for d in a.shape)
# Check that all DBIdx point to positions to the left of the input on which
# they appear.
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
if isinstance(aval, core.DShapedArray) for d in aval.shape
if isinstance(d, core.DBIdx))
# Check that all implicit arguments have at least one DBIdx pointing to them.
provided = [e for _, e in in_type]
for aval, _ in in_type:
if type(aval) is core.DShapedArray:
for d in aval.shape:
if isinstance(d, core.DBIdx):
provided[d.val] = True
assert all(provided)
def cache(call: Callable):
"""Memoization decorator for functions taking a WrappedFun as first argument.
Args:
call: a Python callable that takes a WrappedFun as its first argument. The
underlying transforms and params on the WrappedFun are used as part of the
memoization cache key.
Returns:
A memoized version of ``call``.
"""
fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
config.x64_enabled, config.jax_default_device,
config._trace_context())
else:
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
config.jax_default_device, config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
ans = call(fun, *args)
cache[key] = (ans, fun.stores)
return ans
def _evict_function(f):
fun_caches.pop(f, None)
memoized_fun.cache_clear = fun_caches.clear # type: ignore
memoized_fun.evict_function = _evict_function # type: ignore
return memoized_fun
@partial(partial, tree_map)
def _copy_main_traces(x):
if isinstance(x, core.MainTrace):
return core.MainTrace(x.level, x.trace_type, **x.payload)
else:
return x
@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans
def merge_linear_aux(aux1, aux2):
try:
out1 = aux1()
except StoreException:
# store 1 was not occupied, so store 2 better be
try:
out2 = aux2()
except StoreException:
raise StoreException("neither store occupied") from None
else:
return False, out2
else:
# store 1 was occupied, so let's check store 2 is not occupied
try:
out2 = aux2()
except StoreException:
return True, out1
else:
raise StoreException("both stores occupied")

View File

@ -29,7 +29,7 @@ from jax._src.sharding import (
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
XLADeviceAssignment, SingleDeviceSharding)
from jax import core
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import stages
from jax._src import array
from jax._src.api import (_check_callable, _check_arg, FLAGS, _resolve_argnums,

View File

@ -22,7 +22,7 @@ import numpy as np
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map, safe_zip, split_list

View File

@ -24,7 +24,7 @@ from jax.core import Trace, Tracer, jaxpr_as_fun
from jax import lax
from jax import custom_derivatives as cd
from jax.interpreters import partial_eval as pe
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src.util import safe_map, wraps, split_list
from jax._src.lax import control_flow as lcf

View File

@ -17,7 +17,7 @@ from typing import Any, Callable, Tuple
import jax
from jax import core
from jax import tree_util
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.experimental import pjit
from jax._src.lib.mlir.dialects import hlo

View File

@ -25,7 +25,7 @@ import jax
from jax import lax
from jax import config
from jax import core, custom_derivatives
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import random, tree_util
from jax import numpy as jnp
from jax.experimental import maps

View File

@ -62,7 +62,6 @@ import jax
from jax import core
from jax import lax
from jax.interpreters import xla
import jax.linear_util as lu
import jax.numpy as jnp
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
@ -70,6 +69,7 @@ from jax.tree_util import (register_pytree_node, tree_structure,
from jax._src import ad_util
from jax._src import dispatch
from jax._src.lax import lax as lax_internal
from jax._src import linear_util as lu
from jax._src.util import unzip2

View File

@ -25,7 +25,7 @@ from enum import Enum
from jax import numpy as jnp
from jax import core
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg
from jax._src import dispatch

View File

@ -38,7 +38,7 @@ from jax._src.numpy.util import _promote_dtypes_inexact
from jax._src.util import safe_map, safe_zip
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves, tree_map
from jax import linear_util as lu
from jax._src import linear_util as lu
map = safe_map
zip = safe_zip

View File

@ -54,7 +54,7 @@ import numpy as np
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs

View File

@ -19,7 +19,7 @@ from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.config import config
from jax.tree_util import (tree_flatten, tree_unflatten,

View File

@ -30,7 +30,7 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero)
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)

View File

@ -29,7 +29,7 @@ import warnings
import numpy as np
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe

View File

@ -26,7 +26,7 @@ from weakref import ref
import numpy as np
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax._src import api_util
from jax._src import core

View File

@ -47,7 +47,7 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
import numpy as np
import jax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -12,335 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for defining functions composed with transformations.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
For example,
# TODO(jakevdp): deprecate these and remove this module.
from jax import linear_util as lu
wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f`
A `WrappedFun` object represents a function `f`, together with a sequence of
nested transformations that are to be applied to the positional and keyword
arguments at call time and function return values at return time.
A transformation can take some static positional arguments that are given
at the wrapping time, and may also return some auxiliary output:
wf, aux_out_thunk = trans1(wf, static_arg)
We can call the transformed function. First, the transformation is applied
to the dynamic args and keyword args to produce new dynamic and keyword args.
Then the underlying function is called and the transformation is applied to
the results.
If there are multiple transformations, they form a stack. The arguments are
transformed first with the last applied transformation; the results are
transformed first with the first applied transformation.
res = wf.call_wrapped(dynamic_args, kwargs)
# Now `aux_out_thunk()` is the auxiliary output.
A transformation is written as a generator function that takes zero or more
static positional arguments (given when the transformation is instantiated),
along with positional and keyword arguments to be transformed.
The generator will yield twice:
@lu.transformation_with_aux
def trans1(static_arg, *dynamic_args, **kwargs):
...
# First yield: pair of transformed (args, kwargs). Get back the results.
results = yield (new_dynamic_args, new_kwargs)
...
# Second yield: pair of (transformed results, and auxiliary output)
yield new_results, auxiliary_output
`WrappedFun` objects explicitly represent the set of transformations so that
they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they compute the same function. The static and the
dynamic positional arguments for the generators, and also the auxiliary output
data must be immutable, because it will be stored in function memoization tables.
"""
from __future__ import annotations
from functools import partial
from typing import Any, Tuple, Callable
import weakref
from jax.tree_util import tree_map
from jax.config import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry
traceback_util.register_exclusion(__file__)
class StoreException(Exception): pass
class EmptyStoreValue: pass
_EMPTY_STORE_VALUE = EmptyStoreValue()
class Store:
"""Storage for a value, with checks for overwriting or reading empty store."""
__slots__ = ("_val",)
def __init__(self):
self._val = _EMPTY_STORE_VALUE
def store(self, val):
if self._val is not _EMPTY_STORE_VALUE:
raise StoreException("Store occupied")
self._val = val
def reset(self):
# This should only be called in exceptional circumstances (e.g. debugging).
self._val = _EMPTY_STORE_VALUE
@property
def val(self):
if not self:
raise StoreException("Store empty")
return self._val
def __nonzero__(self):
return self._val is not _EMPTY_STORE_VALUE
__bool__ = __nonzero__
class WrappedFun:
"""Represents a function `f` to which `transforms` are to be applied.
Args:
f: the function to be transformed.
transforms: a list of `(gen, gen_static_args)` tuples representing
transformations to apply to `f.` Here `gen` is a generator function and
`gen_static_args` is a tuple of static arguments for the generator. See
description at the start of this module for the expected behavior of the
generator.
stores: a list of out_store for the auxiliary output of the `transforms`.
params: extra parameters to pass as keyword arguments to `f`, along with the
transformed keyword arguments.
"""
__slots__ = ("f", "transforms", "stores", "params", "in_type")
def __init__(self, f, transforms, stores, params, in_type):
self.f = f
self.transforms = transforms
self.stores = stores
self.params = params
self.in_type = in_type
@property
def __name__(self):
return getattr(self.f, '__name__', '<unnamed wrapped function>')
def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
"""Add another transform and its store."""
return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
(out_store,) + self.stores, self.params, None)
def populate_stores(self, stores):
"""Copy the values from the `stores` into `self.stores`."""
for self_store, other_store in zip(self.stores, stores):
if self_store is not None:
self_store.store(other_store.val)
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
ans = self.f(*args, **dict(self.params, **kwargs))
except:
# Some transformations yield from inside context managers, so we have to
# interrupt them before reraising the exception. Otherwise they will only
# get garbage-collected at some later time, running their cleanup tasks
# only after this exception is handled, which can corrupt the global
# state.
while stack:
stack.pop()[0].close()
raise
args = kwargs = None
while stack:
gen, out_store = stack.pop()
try:
ans = gen.send(ans)
except:
# As above does for the first half of the transformation, exceptions
# raised in the second half of the transformation also require us to
# clean up references here.
while stack:
stack.pop()[0].close()
raise
if out_store is not None:
ans, side = ans
out_store.store(side)
return ans
def __repr__(self):
def transform_to_str(x):
i, (gen, args) = x
return f"{i} : {fun_name(gen)} {fun_name(args)}"
transformation_stack = map(transform_to_str, enumerate(self.transforms))
return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'
def __hash__(self):
return hash((self.f, self.transforms, self.params, self.in_type))
def __eq__(self, other):
return (self.f == other.f and self.transforms == other.transforms and
self.params == other.params and self.in_type == other.in_type)
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""Adds one more transformation to a WrappedFun.
Args:
gen: the transformation generator function
fun: a WrappedFun on which to apply the transformation
gen_static_args: static args for the generator function
"""
return fun.wrap(gen, gen_static_args, None)
@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
"""Adds one more transformation with auxiliary output to a WrappedFun."""
out_store = Store()
out_thunk = lambda: out_store.val
return fun.wrap(gen, gen_static_args, out_store), out_thunk
def fun_name(f):
try:
return f.__name__
except:
return str(f)
def wrap_init(f, params=None) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)
def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
def valid_size(d) -> bool:
if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
return True
return (isinstance(d, (int, core.DBIdx, core.DArray)) and
(not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape))
assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
for d in a.shape)
# Check that all DBIdx point to positions to the left of the input on which
# they appear.
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
if isinstance(aval, core.DShapedArray) for d in aval.shape
if isinstance(d, core.DBIdx))
# Check that all implicit arguments have at least one DBIdx pointing to them.
provided = [e for _, e in in_type]
for aval, _ in in_type:
if type(aval) is core.DShapedArray:
for d in aval.shape:
if isinstance(d, core.DBIdx):
provided[d.val] = True
assert all(provided)
def cache(call: Callable):
"""Memoization decorator for functions taking a WrappedFun as first argument.
Args:
call: a Python callable that takes a WrappedFun as its first argument. The
underlying transforms and params on the WrappedFun are used as part of the
memoization cache key.
Returns:
A memoized version of ``call``.
"""
fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args,
config.x64_enabled, config.jax_default_device,
config._trace_context())
else:
key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled,
config.jax_default_device, config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
ans = call(fun, *args)
cache[key] = (ans, fun.stores)
return ans
def _evict_function(f):
fun_caches.pop(f, None)
memoized_fun.cache_clear = fun_caches.clear # type: ignore
memoized_fun.evict_function = _evict_function # type: ignore
return memoized_fun
@partial(partial, tree_map)
def _copy_main_traces(x):
if isinstance(x, core.MainTrace):
return core.MainTrace(x.level, x.trace_type, **x.payload)
else:
return x
@transformation
def hashable_partial(x, *args):
ans = yield (x,) + args, {}
yield ans
def merge_linear_aux(aux1, aux2):
try:
out1 = aux1()
except StoreException:
# store 1 was not occupied, so store 2 better be
try:
out2 = aux2()
except StoreException:
raise StoreException("neither store occupied") from None
else:
return False, out2
else:
# store 1 was occupied, so let's check store 2 is not occupied
try:
out2 = aux2()
except StoreException:
return True, out1
else:
raise StoreException("both stores occupied")
from jax._src.linear_util import (
EmptyStoreValue as EmptyStoreValue,
Store as Store,
StoreException as StoreException,
WrappedFun as WrappedFun,
_EMPTY_STORE_VALUE as _EMPTY_STORE_VALUE,
_check_input_type as _check_input_type,
_copy_main_traces as _copy_main_traces,
annotate as annotate,
annotations as annotations,
cache as cache,
config as config,
core as core,
curry as curry,
fun_name as fun_name,
hashable_partial as hashable_partial,
merge_linear_aux as merge_linear_aux,
traceback_util as traceback_util,
transformation as transformation,
transformation_with_aux as transformation_with_aux,
tree_map as tree_map,
wrap_init as wrap_init,
)

View File

@ -35,6 +35,7 @@ per-file-ignores =
jax/dtypes.py:F401
jax/errors.py:F401
jax/flatten_util.py:F401
jax/linear_util.py:F401
jax/prng.py:F401
jax/profiler.py:F401
jax/random.py:F401

View File

@ -65,7 +65,7 @@ from jax._src import prng
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
from jax._src import linear_util as lu
import jax._src.util as jax_util
from jax._src.ad_checkpoint import saved_residuals
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
@ -4746,7 +4746,7 @@ class RematTest(jtu.JaxTestCase):
@jax_util.curry
def call(f, *args):
return jax.core.call(
jax.linear_util.wrap_init(lambda *args: [f(*args)]),
lu.wrap_init(lambda *args: [f(*args)]),
*args, name='foo')[0]
f = call(add_one)

View File

@ -27,7 +27,7 @@ import jax
from jax import core
from jax import lax
from jax import numpy as jnp
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray, DBIdx
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,

View File

@ -21,7 +21,7 @@ import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import ad
from jax.experimental import maps

View File

@ -18,7 +18,7 @@ import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax._src import test_util as jtu
from jax._src.lib import xla_client

View File

@ -22,7 +22,7 @@ import numpy as np
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import test_util as jtu

View File

@ -14,7 +14,7 @@
from absl.testing import absltest
from jax import linear_util as lu
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax.config import config