From 4a6bbde40919ffcc55f35e84090c7e68ce02accc Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 20 Dec 2022 14:49:27 -0800 Subject: [PATCH] Move jax.linear_util to jax._src.linear_util --- jax/__init__.py | 1 + jax/_src/ad_checkpoint.py | 2 +- jax/_src/api.py | 2 +- jax/_src/api_util.py | 2 +- jax/_src/checkify.py | 2 +- jax/_src/core.py | 2 +- jax/_src/custom_batching.py | 2 +- jax/_src/custom_derivatives.py | 2 +- jax/_src/custom_transpose.py | 2 +- jax/_src/debugging.py | 2 +- jax/_src/dispatch.py | 2 +- jax/_src/lax/control_flow/common.py | 2 +- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/control_flow/solves.py | 2 +- jax/_src/lax/lax.py | 2 +- jax/_src/linear_util.py | 346 +++++++++++++++++++++ jax/_src/pjit.py | 2 +- jax/_src/state/discharge.py | 2 +- jax/experimental/callback.py | 2 +- jax/experimental/custom_partitioning.py | 2 +- jax/experimental/jax2tf/jax2tf.py | 2 +- jax/experimental/jet.py | 2 +- jax/experimental/maps.py | 2 +- jax/experimental/ode.py | 2 +- jax/experimental/sparse/transform.py | 2 +- jax/interpreters/ad.py | 2 +- jax/interpreters/batching.py | 2 +- jax/interpreters/mlir.py | 2 +- jax/interpreters/partial_eval.py | 2 +- jax/interpreters/pxla.py | 2 +- jax/linear_util.py | 356 ++-------------------- setup.cfg | 1 + tests/api_test.py | 4 +- tests/core_test.py | 2 +- tests/jaxpr_effects_test.py | 2 +- tests/name_stack_test.py | 2 +- tests/state_test.py | 2 +- tests/util_test.py | 2 +- 40 files changed, 411 insertions(+), 367 deletions(-) create mode 100644 jax/_src/linear_util.py diff --git a/jax/__init__.py b/jax/__init__.py index 23ada77e9..5a9f4972e 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -149,6 +149,7 @@ from jax import dtypes as dtypes from jax import errors as errors from jax import image as image from jax import lax as lax +from jax import linear_util as linear_util from jax import nn as nn from jax import numpy as numpy from jax import ops as ops diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2dd87ae7..a738e2d0e 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -18,7 +18,7 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any, import types import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir diff --git a/jax/_src/api.py b/jax/_src/api.py index 8d0df5285..8c5be369a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -34,7 +34,7 @@ import numpy as np from contextlib import contextmanager, ExitStack import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import stages from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index b440ea8e5..6b7b86d00 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -27,7 +27,7 @@ from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure, treedef_children, treedef_is_leaf) from jax._src.tree_util import _replace_nones -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable from jax._src import traceback_util diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 61927930a..53b337952 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -20,7 +20,7 @@ from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, I import jax from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src import core from jax._src import prng from jax._src import source_info_util diff --git a/jax/_src/core.py b/jax/_src/core.py index 0dd0a004f..22aef15a3 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -40,7 +40,7 @@ from jax._src import config as jax_config from jax._src.config import FLAGS, config from jax.errors import (ConcretizationTypeError, TracerArrayConversionError, TracerIntegerConversionError, UnexpectedTracerError) -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert, diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 691280f09..5a75cbffe 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -17,7 +17,7 @@ import operator from typing import Callable, Optional import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import tree_util from jax.interpreters import ad from jax.interpreters import batching diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index cdaa65409..213855bde 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,7 +17,7 @@ import inspect from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set, Any) -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.custom_transpose import custom_transpose from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index f1aab5472..5f18dd5c4 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -15,7 +15,7 @@ import functools from typing import Any, Callable, Optional, Tuple -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.interpreters import ad from jax.interpreters import mlir from jax.interpreters import partial_eval as pe diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 9da711650..87cfc8860 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -22,7 +22,7 @@ from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union from jax import tree_util from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.experimental import pjit from jax.interpreters import ad diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 47251b94a..4ddf34c74 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -32,7 +32,7 @@ import warnings import numpy as np import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.errors import UnexpectedTracerError from jax.monitoring import record_event_duration_secs import jax.interpreters.ad as ad diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 3063a298d..f5f183f30 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -18,7 +18,7 @@ from functools import partial from typing import Callable, Optional, Sequence, Set from jax import core -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.api_util import flatten_fun_nokwargs from jax.interpreters import partial_eval as pe from jax._src.lax import lax diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index df7498116..5bba512bf 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -22,7 +22,7 @@ import operator from typing import Callable, Sequence, List, Tuple from jax import core -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.core import ConcreteArray, raise_to_shaped from jax.interpreters import ad diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 51faac8dc..e7b85ea00 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -19,7 +19,7 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.api_util import flatten_fun_nokwargs from jax.interpreters import ad from jax.interpreters import batching diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6dcddf14b..15e07733e 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -21,7 +21,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar import jax import weakref from jax import core -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.core import ConcreteArray, ShapedArray, raise_to_shaped from jax.interpreters import ad diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 30a185027..a8eb155f1 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -19,7 +19,7 @@ import operator import jax from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.core import raise_to_shaped from jax.interpreters import ad from jax.interpreters import batching diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9504c4a01..49123fafe 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -32,7 +32,7 @@ from jax._src import api_util from jax._src import array from jax._src import device_array from jax._src import dispatch -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src import dtypes from jax import tree_util from jax._src import source_info_util diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py new file mode 100644 index 000000000..6f755e724 --- /dev/null +++ b/jax/_src/linear_util.py @@ -0,0 +1,346 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for defining functions composed with transformations. + +For example, + + from jax._src import linear_util as lu + + wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f` + +A `WrappedFun` object represents a function `f`, together with a sequence of +nested transformations that are to be applied to the positional and keyword +arguments at call time and function return values at return time. +A transformation can take some static positional arguments that are given +at the wrapping time, and may also return some auxiliary output: + + wf, aux_out_thunk = trans1(wf, static_arg) + +We can call the transformed function. First, the transformation is applied +to the dynamic args and keyword args to produce new dynamic and keyword args. +Then the underlying function is called and the transformation is applied to +the results. +If there are multiple transformations, they form a stack. The arguments are +transformed first with the last applied transformation; the results are +transformed first with the first applied transformation. + + res = wf.call_wrapped(dynamic_args, kwargs) + # Now `aux_out_thunk()` is the auxiliary output. + +A transformation is written as a generator function that takes zero or more +static positional arguments (given when the transformation is instantiated), +along with positional and keyword arguments to be transformed. +The generator will yield twice: + + @lu.transformation_with_aux + def trans1(static_arg, *dynamic_args, **kwargs): + ... + # First yield: pair of transformed (args, kwargs). Get back the results. + results = yield (new_dynamic_args, new_kwargs) + ... + # Second yield: pair of (transformed results, and auxiliary output) + yield new_results, auxiliary_output + + +`WrappedFun` objects explicitly represent the set of transformations so that +they can be used as dictionary keys for memoization. `WrappedFun` objects +compare as equal only if they compute the same function. The static and the +dynamic positional arguments for the generators, and also the auxiliary output +data must be immutable, because it will be stored in function memoization tables. +""" +from __future__ import annotations + +from functools import partial +from typing import Any, Tuple, Callable +import weakref + +from jax.tree_util import tree_map +from jax.config import config +from jax._src import core +from jax._src import traceback_util +from jax._src.util import curry + +traceback_util.register_exclusion(__file__) + + +class StoreException(Exception): pass + + +class EmptyStoreValue: pass +_EMPTY_STORE_VALUE = EmptyStoreValue() + +class Store: + """Storage for a value, with checks for overwriting or reading empty store.""" + __slots__ = ("_val",) + + def __init__(self): + self._val = _EMPTY_STORE_VALUE + + def store(self, val): + if self._val is not _EMPTY_STORE_VALUE: + raise StoreException("Store occupied") + self._val = val + + def reset(self): + # This should only be called in exceptional circumstances (e.g. debugging). + self._val = _EMPTY_STORE_VALUE + + @property + def val(self): + if not self: + raise StoreException("Store empty") + return self._val + + def __nonzero__(self): + return self._val is not _EMPTY_STORE_VALUE + + __bool__ = __nonzero__ + + +class WrappedFun: + """Represents a function `f` to which `transforms` are to be applied. + + Args: + f: the function to be transformed. + transforms: a list of `(gen, gen_static_args)` tuples representing + transformations to apply to `f.` Here `gen` is a generator function and + `gen_static_args` is a tuple of static arguments for the generator. See + description at the start of this module for the expected behavior of the + generator. + stores: a list of out_store for the auxiliary output of the `transforms`. + params: extra parameters to pass as keyword arguments to `f`, along with the + transformed keyword arguments. + """ + __slots__ = ("f", "transforms", "stores", "params", "in_type") + + def __init__(self, f, transforms, stores, params, in_type): + self.f = f + self.transforms = transforms + self.stores = stores + self.params = params + self.in_type = in_type + + @property + def __name__(self): + return getattr(self.f, '__name__', '') + + def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: + """Add another transform and its store.""" + return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None) + + def populate_stores(self, stores): + """Copy the values from the `stores` into `self.stores`.""" + for self_store, other_store in zip(self.stores, stores): + if self_store is not None: + self_store.store(other_store.val) + + def call_wrapped(self, *args, **kwargs): + """Calls the underlying function, applying the transforms. + + The positional `args` and keyword `kwargs` are passed to the first + transformation generator. + """ + stack = [] + for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): + gen = gen(*(gen_static_args + tuple(args)), **kwargs) + args, kwargs = next(gen) + stack.append((gen, out_store)) + gen = gen_static_args = out_store = None + + try: + ans = self.f(*args, **dict(self.params, **kwargs)) + except: + # Some transformations yield from inside context managers, so we have to + # interrupt them before reraising the exception. Otherwise they will only + # get garbage-collected at some later time, running their cleanup tasks + # only after this exception is handled, which can corrupt the global + # state. + while stack: + stack.pop()[0].close() + raise + + args = kwargs = None + while stack: + gen, out_store = stack.pop() + try: + ans = gen.send(ans) + except: + # As above does for the first half of the transformation, exceptions + # raised in the second half of the transformation also require us to + # clean up references here. + while stack: + stack.pop()[0].close() + raise + if out_store is not None: + ans, side = ans + out_store.store(side) + + return ans + + def __repr__(self): + def transform_to_str(x): + i, (gen, args) = x + return f"{i} : {fun_name(gen)} {fun_name(args)}" + transformation_stack = map(transform_to_str, enumerate(self.transforms)) + return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n' + + def __hash__(self): + return hash((self.f, self.transforms, self.params, self.in_type)) + + def __eq__(self, other): + return (self.f == other.f and self.transforms == other.transforms and + self.params == other.params and self.in_type == other.in_type) + +@curry +def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + """Adds one more transformation to a WrappedFun. + Args: + gen: the transformation generator function + fun: a WrappedFun on which to apply the transformation + gen_static_args: static args for the generator function + """ + return fun.wrap(gen, gen_static_args, None) + +@curry +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]: + """Adds one more transformation with auxiliary output to a WrappedFun.""" + out_store = Store() + out_thunk = lambda: out_store.val + return fun.wrap(gen, gen_static_args, out_store), out_thunk + +def fun_name(f): + try: + return f.__name__ + except: + return str(f) + +def wrap_init(f, params=None) -> WrappedFun: + """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" + params = () if params is None else tuple(sorted(params.items())) + return WrappedFun(f, (), (), params, None) + + +def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun: + assert f.in_type is None + if in_type is None: + return f + _check_input_type(in_type) + return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type) + +def _check_input_type(in_type: core.InputType) -> None: + # Check that in_type is syntactically well-formed + assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) + assert all(isinstance(a, core.AbstractValue) and type(b) is bool + and not isinstance(a, core.ConcreteArray) for a, b in in_type) + + def valid_size(d) -> bool: + if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: + return True + return (isinstance(d, (int, core.DBIdx, core.DArray)) and + (not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape)) + assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray + for d in a.shape) + + # Check that all DBIdx point to positions to the left of the input on which + # they appear. + assert all(d.val < i for i, (aval, _) in enumerate(in_type) + if isinstance(aval, core.DShapedArray) for d in aval.shape + if isinstance(d, core.DBIdx)) + + # Check that all implicit arguments have at least one DBIdx pointing to them. + provided = [e for _, e in in_type] + for aval, _ in in_type: + if type(aval) is core.DShapedArray: + for d in aval.shape: + if isinstance(d, core.DBIdx): + provided[d.val] = True + assert all(provided) + + +def cache(call: Callable): + """Memoization decorator for functions taking a WrappedFun as first argument. + + Args: + call: a Python callable that takes a WrappedFun as its first argument. The + underlying transforms and params on the WrappedFun are used as part of the + memoization cache key. + + Returns: + A memoized version of ``call``. + """ + fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + + def memoized_fun(fun: WrappedFun, *args): + cache = fun_caches.setdefault(fun.f, {}) + if config.jax_check_tracer_leaks: + key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, + config.x64_enabled, config.jax_default_device, + config._trace_context()) + else: + key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled, + config.jax_default_device, config._trace_context()) + result = cache.get(key, None) + if result is not None: + ans, stores = result + fun.populate_stores(stores) + else: + ans = call(fun, *args) + cache[key] = (ans, fun.stores) + + return ans + + def _evict_function(f): + fun_caches.pop(f, None) + + memoized_fun.cache_clear = fun_caches.clear # type: ignore + memoized_fun.evict_function = _evict_function # type: ignore + + return memoized_fun + +@partial(partial, tree_map) +def _copy_main_traces(x): + if isinstance(x, core.MainTrace): + return core.MainTrace(x.level, x.trace_type, **x.payload) + else: + return x + + +@transformation +def hashable_partial(x, *args): + ans = yield (x,) + args, {} + yield ans + + +def merge_linear_aux(aux1, aux2): + try: + out1 = aux1() + except StoreException: + # store 1 was not occupied, so store 2 better be + try: + out2 = aux2() + except StoreException: + raise StoreException("neither store occupied") from None + else: + return False, out2 + else: + # store 1 was occupied, so let's check store 2 is not occupied + try: + out2 = aux2() + except StoreException: + return True, out1 + else: + raise StoreException("both stores occupied") diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 2fc722e90..0a6589729 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -29,7 +29,7 @@ from jax._src.sharding import ( NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding, XLADeviceAssignment, SingleDeviceSharding) from jax import core -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import stages from jax._src import array from jax._src.api import (_check_callable, _check_arg, FLAGS, _resolve_argnums, diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 79e87607f..7d44aa3ab 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -22,7 +22,7 @@ import numpy as np from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.interpreters import partial_eval as pe from jax._src.util import safe_map, safe_zip, split_list diff --git a/jax/experimental/callback.py b/jax/experimental/callback.py index ccbc9f348..78945e001 100644 --- a/jax/experimental/callback.py +++ b/jax/experimental/callback.py @@ -24,7 +24,7 @@ from jax.core import Trace, Tracer, jaxpr_as_fun from jax import lax from jax import custom_derivatives as cd from jax.interpreters import partial_eval as pe -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src.util import safe_map, wraps, split_list from jax._src.lax import control_flow as lcf diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 124583907..0e7a9dec5 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Tuple import jax from jax import core from jax import tree_util -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.experimental import pjit from jax._src.lib.mlir.dialects import hlo diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index dac1ec88f..09804abf8 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -25,7 +25,7 @@ import jax from jax import lax from jax import config from jax import core, custom_derivatives -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import random, tree_util from jax import numpy as jnp from jax.experimental import maps diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 3c4efbe0e..7a9dba09a 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -62,7 +62,6 @@ import jax from jax import core from jax import lax from jax.interpreters import xla -import jax.linear_util as lu import jax.numpy as jnp from jax.tree_util import (register_pytree_node, tree_structure, treedef_is_leaf, tree_flatten, tree_unflatten) @@ -70,6 +69,7 @@ from jax.tree_util import (register_pytree_node, tree_structure, from jax._src import ad_util from jax._src import dispatch from jax._src.lax import lax as lax_internal +from jax._src import linear_util as lu from jax._src.util import unzip2 diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 49b30f458..490ea0e3d 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -25,7 +25,7 @@ from enum import Enum from jax import numpy as jnp from jax import core -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import stages from jax._src.api import _check_callable, _check_arg from jax._src import dispatch diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index 66184b07b..7cea34e88 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -38,7 +38,7 @@ from jax._src.numpy.util import _promote_dtypes_inexact from jax._src.util import safe_map, safe_zip from jax.flatten_util import ravel_pytree from jax.tree_util import tree_leaves, tree_map -from jax import linear_util as lu +from jax._src import linear_util as lu map = safe_map zip = safe_zip diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 0d1c0a933..bf97b4c81 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -54,7 +54,7 @@ import numpy as np from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse import jax.numpy as jnp from jax._src.api_util import flatten_fun_nokwargs diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 177fa1b9f..513b81d51 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -19,7 +19,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.interpreters import partial_eval as pe from jax.config import config from jax.tree_util import (tree_flatten, tree_unflatten, diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 264515049..11591b4f0 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -30,7 +30,7 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p, Zero) -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 72d41ab5f..3f273fc09 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -29,7 +29,7 @@ import warnings import numpy as np -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.interpreters import ad from jax.interpreters import partial_eval as pe diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 2227f93cb..8ddfd92b8 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -26,7 +26,7 @@ from weakref import ref import numpy as np -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax._src import api_util from jax._src import core diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 8106d7238..85af60964 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -47,7 +47,7 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, import numpy as np import jax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.errors import JAXTypeError from jax.interpreters import ad from jax.interpreters import batching diff --git a/jax/linear_util.py b/jax/linear_util.py index d6afac622..a17b396a7 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -12,335 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Utilities for defining functions composed with transformations. +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 -For example, +# TODO(jakevdp): deprecate these and remove this module. - from jax import linear_util as lu - - wf = lu.wrap_init(f) # Produce a WrappedFun for applying transformations on `f` - -A `WrappedFun` object represents a function `f`, together with a sequence of -nested transformations that are to be applied to the positional and keyword -arguments at call time and function return values at return time. -A transformation can take some static positional arguments that are given -at the wrapping time, and may also return some auxiliary output: - - wf, aux_out_thunk = trans1(wf, static_arg) - -We can call the transformed function. First, the transformation is applied -to the dynamic args and keyword args to produce new dynamic and keyword args. -Then the underlying function is called and the transformation is applied to -the results. -If there are multiple transformations, they form a stack. The arguments are -transformed first with the last applied transformation; the results are -transformed first with the first applied transformation. - - res = wf.call_wrapped(dynamic_args, kwargs) - # Now `aux_out_thunk()` is the auxiliary output. - -A transformation is written as a generator function that takes zero or more -static positional arguments (given when the transformation is instantiated), -along with positional and keyword arguments to be transformed. -The generator will yield twice: - - @lu.transformation_with_aux - def trans1(static_arg, *dynamic_args, **kwargs): - ... - # First yield: pair of transformed (args, kwargs). Get back the results. - results = yield (new_dynamic_args, new_kwargs) - ... - # Second yield: pair of (transformed results, and auxiliary output) - yield new_results, auxiliary_output - - -`WrappedFun` objects explicitly represent the set of transformations so that -they can be used as dictionary keys for memoization. `WrappedFun` objects -compare as equal only if they compute the same function. The static and the -dynamic positional arguments for the generators, and also the auxiliary output -data must be immutable, because it will be stored in function memoization tables. -""" -from __future__ import annotations - -from functools import partial -from typing import Any, Tuple, Callable -import weakref - -from jax.tree_util import tree_map -from jax.config import config -from jax._src import core -from jax._src import traceback_util -from jax._src.util import curry - -traceback_util.register_exclusion(__file__) - - -class StoreException(Exception): pass - - -class EmptyStoreValue: pass -_EMPTY_STORE_VALUE = EmptyStoreValue() - -class Store: - """Storage for a value, with checks for overwriting or reading empty store.""" - __slots__ = ("_val",) - - def __init__(self): - self._val = _EMPTY_STORE_VALUE - - def store(self, val): - if self._val is not _EMPTY_STORE_VALUE: - raise StoreException("Store occupied") - self._val = val - - def reset(self): - # This should only be called in exceptional circumstances (e.g. debugging). - self._val = _EMPTY_STORE_VALUE - - @property - def val(self): - if not self: - raise StoreException("Store empty") - return self._val - - def __nonzero__(self): - return self._val is not _EMPTY_STORE_VALUE - - __bool__ = __nonzero__ - - -class WrappedFun: - """Represents a function `f` to which `transforms` are to be applied. - - Args: - f: the function to be transformed. - transforms: a list of `(gen, gen_static_args)` tuples representing - transformations to apply to `f.` Here `gen` is a generator function and - `gen_static_args` is a tuple of static arguments for the generator. See - description at the start of this module for the expected behavior of the - generator. - stores: a list of out_store for the auxiliary output of the `transforms`. - params: extra parameters to pass as keyword arguments to `f`, along with the - transformed keyword arguments. - """ - __slots__ = ("f", "transforms", "stores", "params", "in_type") - - def __init__(self, f, transforms, stores, params, in_type): - self.f = f - self.transforms = transforms - self.stores = stores - self.params = params - self.in_type = in_type - - @property - def __name__(self): - return getattr(self.f, '__name__', '') - - def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: - """Add another transform and its store.""" - return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None) - - def populate_stores(self, stores): - """Copy the values from the `stores` into `self.stores`.""" - for self_store, other_store in zip(self.stores, stores): - if self_store is not None: - self_store.store(other_store.val) - - def call_wrapped(self, *args, **kwargs): - """Calls the underlying function, applying the transforms. - - The positional `args` and keyword `kwargs` are passed to the first - transformation generator. - """ - stack = [] - for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): - gen = gen(*(gen_static_args + tuple(args)), **kwargs) - args, kwargs = next(gen) - stack.append((gen, out_store)) - gen = gen_static_args = out_store = None - - try: - ans = self.f(*args, **dict(self.params, **kwargs)) - except: - # Some transformations yield from inside context managers, so we have to - # interrupt them before reraising the exception. Otherwise they will only - # get garbage-collected at some later time, running their cleanup tasks - # only after this exception is handled, which can corrupt the global - # state. - while stack: - stack.pop()[0].close() - raise - - args = kwargs = None - while stack: - gen, out_store = stack.pop() - try: - ans = gen.send(ans) - except: - # As above does for the first half of the transformation, exceptions - # raised in the second half of the transformation also require us to - # clean up references here. - while stack: - stack.pop()[0].close() - raise - if out_store is not None: - ans, side = ans - out_store.store(side) - - return ans - - def __repr__(self): - def transform_to_str(x): - i, (gen, args) = x - return f"{i} : {fun_name(gen)} {fun_name(args)}" - transformation_stack = map(transform_to_str, enumerate(self.transforms)) - return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n' - - def __hash__(self): - return hash((self.f, self.transforms, self.params, self.in_type)) - - def __eq__(self, other): - return (self.f == other.f and self.transforms == other.transforms and - self.params == other.params and self.in_type == other.in_type) - -@curry -def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: - """Adds one more transformation to a WrappedFun. - Args: - gen: the transformation generator function - fun: a WrappedFun on which to apply the transformation - gen_static_args: static args for the generator function - """ - return fun.wrap(gen, gen_static_args, None) - -@curry -def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]: - """Adds one more transformation with auxiliary output to a WrappedFun.""" - out_store = Store() - out_thunk = lambda: out_store.val - return fun.wrap(gen, gen_static_args, out_store), out_thunk - -def fun_name(f): - try: - return f.__name__ - except: - return str(f) - -def wrap_init(f, params=None) -> WrappedFun: - """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" - params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None) - - -def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun: - assert f.in_type is None - if in_type is None: - return f - _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type) - -def _check_input_type(in_type: core.InputType) -> None: - # Check that in_type is syntactically well-formed - assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) - assert all(isinstance(a, core.AbstractValue) and type(b) is bool - and not isinstance(a, core.ConcreteArray) for a, b in in_type) - - def valid_size(d) -> bool: - if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: - return True - return (isinstance(d, (int, core.DBIdx, core.DArray)) and - (not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape)) - assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray - for d in a.shape) - - # Check that all DBIdx point to positions to the left of the input on which - # they appear. - assert all(d.val < i for i, (aval, _) in enumerate(in_type) - if isinstance(aval, core.DShapedArray) for d in aval.shape - if isinstance(d, core.DBIdx)) - - # Check that all implicit arguments have at least one DBIdx pointing to them. - provided = [e for _, e in in_type] - for aval, _ in in_type: - if type(aval) is core.DShapedArray: - for d in aval.shape: - if isinstance(d, core.DBIdx): - provided[d.val] = True - assert all(provided) - - -def cache(call: Callable): - """Memoization decorator for functions taking a WrappedFun as first argument. - - Args: - call: a Python callable that takes a WrappedFun as its first argument. The - underlying transforms and params on the WrappedFun are used as part of the - memoization cache key. - - Returns: - A memoized version of ``call``. - """ - fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() - - def memoized_fun(fun: WrappedFun, *args): - cache = fun_caches.setdefault(fun.f, {}) - if config.jax_check_tracer_leaks: - key = (_copy_main_traces(fun.transforms), fun.params, fun.in_type, args, - config.x64_enabled, config.jax_default_device, - config._trace_context()) - else: - key = (fun.transforms, fun.params, fun.in_type, args, config.x64_enabled, - config.jax_default_device, config._trace_context()) - result = cache.get(key, None) - if result is not None: - ans, stores = result - fun.populate_stores(stores) - else: - ans = call(fun, *args) - cache[key] = (ans, fun.stores) - - return ans - - def _evict_function(f): - fun_caches.pop(f, None) - - memoized_fun.cache_clear = fun_caches.clear # type: ignore - memoized_fun.evict_function = _evict_function # type: ignore - - return memoized_fun - -@partial(partial, tree_map) -def _copy_main_traces(x): - if isinstance(x, core.MainTrace): - return core.MainTrace(x.level, x.trace_type, **x.payload) - else: - return x - - -@transformation -def hashable_partial(x, *args): - ans = yield (x,) + args, {} - yield ans - - -def merge_linear_aux(aux1, aux2): - try: - out1 = aux1() - except StoreException: - # store 1 was not occupied, so store 2 better be - try: - out2 = aux2() - except StoreException: - raise StoreException("neither store occupied") from None - else: - return False, out2 - else: - # store 1 was occupied, so let's check store 2 is not occupied - try: - out2 = aux2() - except StoreException: - return True, out1 - else: - raise StoreException("both stores occupied") +from jax._src.linear_util import ( + EmptyStoreValue as EmptyStoreValue, + Store as Store, + StoreException as StoreException, + WrappedFun as WrappedFun, + _EMPTY_STORE_VALUE as _EMPTY_STORE_VALUE, + _check_input_type as _check_input_type, + _copy_main_traces as _copy_main_traces, + annotate as annotate, + annotations as annotations, + cache as cache, + config as config, + core as core, + curry as curry, + fun_name as fun_name, + hashable_partial as hashable_partial, + merge_linear_aux as merge_linear_aux, + traceback_util as traceback_util, + transformation as transformation, + transformation_with_aux as transformation_with_aux, + tree_map as tree_map, + wrap_init as wrap_init, +) diff --git a/setup.cfg b/setup.cfg index 0c3a255fa..6358f6092 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ per-file-ignores = jax/dtypes.py:F401 jax/errors.py:F401 jax/flatten_util.py:F401 + jax/linear_util.py:F401 jax/prng.py:F401 jax/profiler.py:F401 jax/random.py:F401 diff --git a/tests/api_test.py b/tests/api_test.py index c91d229ae..ca318404e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -65,7 +65,7 @@ from jax._src import prng from jax._src.lib import xla_client from jax._src import test_util as jtu from jax import tree_util -from jax import linear_util as lu +from jax._src import linear_util as lu import jax._src.util as jax_util from jax._src.ad_checkpoint import saved_residuals from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name @@ -4746,7 +4746,7 @@ class RematTest(jtu.JaxTestCase): @jax_util.curry def call(f, *args): return jax.core.call( - jax.linear_util.wrap_init(lambda *args: [f(*args)]), + lu.wrap_init(lambda *args: [f(*args)]), *args, name='foo')[0] f = call(add_one) diff --git a/tests/core_test.py b/tests/core_test.py index 9f15c91dd..5f35c9c3e 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -27,7 +27,7 @@ import jax from jax import core from jax import lax from jax import numpy as jnp -from jax import linear_util as lu +from jax._src import linear_util as lu from jax import jvp, linearize, vjp, jit, make_jaxpr from jax.core import UnshapedArray, ShapedArray, DBIdx from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce, diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index dd698e6ea..50f4a637a 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -21,7 +21,7 @@ import jax import jax.numpy as jnp from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.interpreters import ad from jax.experimental import maps diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py index f3fceaa54..8b704b2ae 100644 --- a/tests/name_stack_test.py +++ b/tests/name_stack_test.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax._src import test_util as jtu from jax._src.lib import xla_client diff --git a/tests/state_test.py b/tests/state_test.py index 5ab78f167..03af656e8 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -22,7 +22,7 @@ import numpy as np import jax from jax import core from jax import lax -from jax import linear_util as lu +from jax._src import linear_util as lu from jax.config import config from jax.interpreters import partial_eval as pe from jax._src import test_util as jtu diff --git a/tests/util_test.py b/tests/util_test.py index 0b338f331..5329a3591 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -14,7 +14,7 @@ from absl.testing import absltest -from jax import linear_util as lu +from jax._src import linear_util as lu from jax._src import test_util as jtu from jax.config import config