JEX: move jax.linear_util to jax.extend.linear_util

This commit is contained in:
Jake VanderPlas 2023-08-30 15:14:47 -07:00
parent 437d7be735
commit ca39457ea9
12 changed files with 122 additions and 15 deletions

View File

@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
* Internal deprecations/removals: * Internal deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at * The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`. {mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
## jaxlib 0.4.16 ## jaxlib 0.4.16
@ -73,6 +75,8 @@ Remember to align the itemized text with the first line of an item within a list
* The utility `jax.interpreters.xla.register_collective_primitive` has been * The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed. to it can be safely removed.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
## jaxlib 0.4.15 (Aug 30 2023) ## jaxlib 0.4.15 (Aug 30 2023)

View File

@ -3,6 +3,23 @@
.. automodule:: jax.extend .. automodule:: jax.extend
``jax.extend.linear_util``
--------------------------
.. automodule:: jax.extend.linear_util
.. autosummary::
:toctree: _autosummary
StoreException
WrappedFun
cache
merge_linear_aux
transformation
transformation_with_aux
wrap_init
``jax.extend.random`` ``jax.extend.random``
--------------------- ---------------------

View File

@ -1,3 +1,4 @@
(jax-extend-jep)=
# `jax.extend`: a module for extensions # `jax.extend`: a module for extensions
[@froystig](https://github.com/froystig), [@froystig](https://github.com/froystig),

View File

@ -161,7 +161,7 @@ from jax import dtypes as dtypes
from jax import errors as errors from jax import errors as errors
from jax import image as image from jax import image as image
from jax import lax as lax from jax import lax as lax
from jax import linear_util as linear_util from jax import linear_util as _deprecated_linear_util
from jax import monitoring as monitoring from jax import monitoring as monitoring
from jax import nn as nn from jax import nn as nn
from jax import numpy as numpy from jax import numpy as numpy
@ -210,12 +210,18 @@ _deprecations = {
"tree_unflatten": ( "tree_unflatten": (
"jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.", "jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.",
_deprecated_tree_unflatten _deprecated_tree_unflatten
) ),
# Added Aug 29 2023
"linear_util": (
"jax.linear_util is deprecated: use jax.extend.linear_util.",
_deprecated_linear_util,
),
} }
import typing as _typing import typing as _typing
if _typing.TYPE_CHECKING: if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays from jax import abstract_arrays as abstract_arrays
from jax import linear_util as linear_util
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves from jax._src.tree_util import tree_leaves as tree_leaves

View File

@ -235,6 +235,7 @@ class WrappedFun:
@curry @curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""Adds one more transformation to a WrappedFun. """Adds one more transformation to a WrappedFun.
Args: Args:
gen: the transformation generator function gen: the transformation generator function
fun: a WrappedFun on which to apply the transformation fun: a WrappedFun on which to apply the transformation

View File

@ -22,7 +22,6 @@ from typing import Any, Callable, Dict, Sequence, Tuple
import jax import jax
from jax import api_util from jax import api_util
from jax import linear_util as lu
from jax import tree_util from jax import tree_util
from jax import lax from jax import lax
from jax.interpreters import ad from jax.interpreters import ad
@ -31,6 +30,7 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla from jax.interpreters import xla
from jax._src import ad_util from jax._src import ad_util
from jax._src import core as jax_core from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.state import discharge as state_discharge from jax._src.state import discharge as state_discharge
from jax._src.util import ( from jax._src.util import (
split_list, safe_map, safe_zip, weakref_lru_cache, split_list, safe_map, safe_zip, weakref_lru_cache,

View File

@ -181,7 +181,7 @@ def curry(f):
>>> curry(f)(2, 3, 4, 5)() >>> curry(f)(2, 3, 4, 5)()
26 26
""" """
return partial(partial, f) return wraps(f)(partial(partial, f))
def toposort(end_nodes): def toposort(end_nodes):
if not end_nodes: return [] if not end_nodes: return []

View File

@ -29,5 +29,6 @@ Breaking changes will be announced via the
""" """
from jax.extend import ( from jax.extend import (
linear_util as linear_util,
random as random, random as random,
) )

26
jax/extend/linear_util.py Normal file
View File

@ -0,0 +1,26 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.linear_util import (
StoreException as StoreException,
WrappedFun as WrappedFun,
cache as cache,
merge_linear_aux as merge_linear_aux,
transformation as transformation,
transformation_with_aux as transformation_with_aux,
wrap_init as wrap_init,
)

View File

@ -17,12 +17,53 @@
# TODO(jakevdp): deprecate these and remove this module. # TODO(jakevdp): deprecate these and remove this module.
from jax._src.linear_util import ( from jax._src import linear_util as _lu
StoreException as StoreException,
WrappedFun as WrappedFun,
cache as cache, _deprecations = {
merge_linear_aux as merge_linear_aux, # Added August 29, 2023:
transformation as transformation, "StoreException": (
transformation_with_aux as transformation_with_aux, "jax.linear_util.StoreException is deprecated. Use jax.extend.linear_util.StoreException instead.",
wrap_init as wrap_init, _lu.StoreException,
) ),
"WrappedFun": (
"jax.linear_util.WrappedFun is deprecated. Use jax.extend.linear_util.WrappedFun instead.",
_lu.WrappedFun,
),
"cache": (
"jax.linear_util.cache is deprecated. Use jax.extend.linear_util.cache instead.",
_lu.cache,
),
"merge_linear_aux": (
"jax.linear_util.merge_linear_aux is deprecated. Use jax.extend.linear_util.merge_linear_aux instead.",
_lu.merge_linear_aux
),
"transformation": (
"jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.",
_lu.transformation
),
"transformation_with_aux": (
"jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.",
_lu.transformation_with_aux
),
"wrap_init": (
"jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.",
_lu.wrap_init
),
}
import typing
if typing.TYPE_CHECKING:
StoreException = _lu.StoreException
WrappedFun = _lu.WrappedFun
cache = _lu.cache
merge_linear_aux = _lu.merge_linear_aux
transformation = _lu.transformation
transformation_with_aux = _lu.transformation_with_aux
wrap_init = _lu.wrap_init
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _lu

View File

@ -17,6 +17,7 @@ from absl.testing import absltest
import jax import jax
import jax.extend as jex import jax.extend as jex
from jax._src import linear_util
from jax._src import prng from jax._src import prng
from jax._src import test_util as jtu from jax._src import test_util as jtu
@ -35,6 +36,15 @@ class ExtendTest(jtu.JaxTestCase):
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl) self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl) self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
# Assume these are tested elsewhere, only check equivalence
self.assertIs(jex.linear_util.StoreException, linear_util.StoreException)
self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)
self.assertIs(jex.linear_util.cache, linear_util.cache)
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
class RandomTest(jtu.JaxTestCase): class RandomTest(jtu.JaxTestCase):
def test_wrap_key_default(self): def test_wrap_key_default(self):

View File

@ -24,8 +24,8 @@ from absl.testing import parameterized
import jax import jax
from jax import lax from jax import lax
from jax import linear_util as lu
from jax import random from jax import random
from jax._src import linear_util as lu
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import state from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop from jax._src.lax.control_flow.for_loop import for_loop