JEX: move jax.linear_util to jax.extend.linear_util

This commit is contained in:
Jake VanderPlas 2023-08-30 15:14:47 -07:00
parent 437d7be735
commit ca39457ea9
12 changed files with 122 additions and 15 deletions

View File

@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
* Internal deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
## jaxlib 0.4.16
@ -73,6 +75,8 @@ Remember to align the itemized text with the first line of an item within a list
* The utility `jax.interpreters.xla.register_collective_primitive` has been
removed. This utility did nothing useful in recent JAX releases and calls
to it can be safely removed.
* The internal submodule path `jax.linear_util` has been deprecated. Use
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
## jaxlib 0.4.15 (Aug 30 2023)

View File

@ -3,6 +3,23 @@
.. automodule:: jax.extend
``jax.extend.linear_util``
--------------------------
.. automodule:: jax.extend.linear_util
.. autosummary::
:toctree: _autosummary
StoreException
WrappedFun
cache
merge_linear_aux
transformation
transformation_with_aux
wrap_init
``jax.extend.random``
---------------------

View File

@ -1,3 +1,4 @@
(jax-extend-jep)=
# `jax.extend`: a module for extensions
[@froystig](https://github.com/froystig),

View File

@ -161,7 +161,7 @@ from jax import dtypes as dtypes
from jax import errors as errors
from jax import image as image
from jax import lax as lax
from jax import linear_util as linear_util
from jax import linear_util as _deprecated_linear_util
from jax import monitoring as monitoring
from jax import nn as nn
from jax import numpy as numpy
@ -210,12 +210,18 @@ _deprecations = {
"tree_unflatten": (
"jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.",
_deprecated_tree_unflatten
)
),
# Added Aug 29 2023
"linear_util": (
"jax.linear_util is deprecated: use jax.extend.linear_util.",
_deprecated_linear_util,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax import abstract_arrays as abstract_arrays
from jax import linear_util as linear_util
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
from jax._src.tree_util import tree_flatten as tree_flatten
from jax._src.tree_util import tree_leaves as tree_leaves

View File

@ -235,6 +235,7 @@ class WrappedFun:
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
"""Adds one more transformation to a WrappedFun.
Args:
gen: the transformation generator function
fun: a WrappedFun on which to apply the transformation

View File

@ -22,7 +22,6 @@ from typing import Any, Callable, Dict, Sequence, Tuple
import jax
from jax import api_util
from jax import linear_util as lu
from jax import tree_util
from jax import lax
from jax.interpreters import ad
@ -31,6 +30,7 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.state import discharge as state_discharge
from jax._src.util import (
split_list, safe_map, safe_zip, weakref_lru_cache,

View File

@ -181,7 +181,7 @@ def curry(f):
>>> curry(f)(2, 3, 4, 5)()
26
"""
return partial(partial, f)
return wraps(f)(partial(partial, f))
def toposort(end_nodes):
if not end_nodes: return []

View File

@ -29,5 +29,6 @@ Breaking changes will be announced via the
"""
from jax.extend import (
linear_util as linear_util,
random as random,
)

26
jax/extend/linear_util.py Normal file
View File

@ -0,0 +1,26 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.linear_util import (
StoreException as StoreException,
WrappedFun as WrappedFun,
cache as cache,
merge_linear_aux as merge_linear_aux,
transformation as transformation,
transformation_with_aux as transformation_with_aux,
wrap_init as wrap_init,
)

View File

@ -17,12 +17,53 @@
# TODO(jakevdp): deprecate these and remove this module.
from jax._src.linear_util import (
StoreException as StoreException,
WrappedFun as WrappedFun,
cache as cache,
merge_linear_aux as merge_linear_aux,
transformation as transformation,
transformation_with_aux as transformation_with_aux,
wrap_init as wrap_init,
)
from jax._src import linear_util as _lu
_deprecations = {
# Added August 29, 2023:
"StoreException": (
"jax.linear_util.StoreException is deprecated. Use jax.extend.linear_util.StoreException instead.",
_lu.StoreException,
),
"WrappedFun": (
"jax.linear_util.WrappedFun is deprecated. Use jax.extend.linear_util.WrappedFun instead.",
_lu.WrappedFun,
),
"cache": (
"jax.linear_util.cache is deprecated. Use jax.extend.linear_util.cache instead.",
_lu.cache,
),
"merge_linear_aux": (
"jax.linear_util.merge_linear_aux is deprecated. Use jax.extend.linear_util.merge_linear_aux instead.",
_lu.merge_linear_aux
),
"transformation": (
"jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.",
_lu.transformation
),
"transformation_with_aux": (
"jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.",
_lu.transformation_with_aux
),
"wrap_init": (
"jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.",
_lu.wrap_init
),
}
import typing
if typing.TYPE_CHECKING:
StoreException = _lu.StoreException
WrappedFun = _lu.WrappedFun
cache = _lu.cache
merge_linear_aux = _lu.merge_linear_aux
transformation = _lu.transformation
transformation_with_aux = _lu.transformation_with_aux
wrap_init = _lu.wrap_init
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _lu

View File

@ -17,6 +17,7 @@ from absl.testing import absltest
import jax
import jax.extend as jex
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
@ -35,6 +36,15 @@ class ExtendTest(jtu.JaxTestCase):
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
# Assume these are tested elsewhere, only check equivalence
self.assertIs(jex.linear_util.StoreException, linear_util.StoreException)
self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)
self.assertIs(jex.linear_util.cache, linear_util.cache)
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
class RandomTest(jtu.JaxTestCase):
def test_wrap_key_default(self):

View File

@ -24,8 +24,8 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import linear_util as lu
from jax import random
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop