mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
JEX: move jax.linear_util to jax.extend.linear_util
This commit is contained in:
parent
437d7be735
commit
ca39457ea9
@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Internal deprecations/removals:
|
* Internal deprecations/removals:
|
||||||
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
|
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
|
||||||
{mod}`jax.extend.random`.
|
{mod}`jax.extend.random`.
|
||||||
|
* The internal submodule path `jax.linear_util` has been deprecated. Use
|
||||||
|
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
|
||||||
|
|
||||||
## jaxlib 0.4.16
|
## jaxlib 0.4.16
|
||||||
|
|
||||||
@ -73,6 +75,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* The utility `jax.interpreters.xla.register_collective_primitive` has been
|
* The utility `jax.interpreters.xla.register_collective_primitive` has been
|
||||||
removed. This utility did nothing useful in recent JAX releases and calls
|
removed. This utility did nothing useful in recent JAX releases and calls
|
||||||
to it can be safely removed.
|
to it can be safely removed.
|
||||||
|
* The internal submodule path `jax.linear_util` has been deprecated. Use
|
||||||
|
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
|
||||||
|
|
||||||
## jaxlib 0.4.15 (Aug 30 2023)
|
## jaxlib 0.4.15 (Aug 30 2023)
|
||||||
|
|
||||||
|
@ -3,6 +3,23 @@
|
|||||||
|
|
||||||
.. automodule:: jax.extend
|
.. automodule:: jax.extend
|
||||||
|
|
||||||
|
``jax.extend.linear_util``
|
||||||
|
--------------------------
|
||||||
|
|
||||||
|
.. automodule:: jax.extend.linear_util
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
StoreException
|
||||||
|
WrappedFun
|
||||||
|
cache
|
||||||
|
merge_linear_aux
|
||||||
|
transformation
|
||||||
|
transformation_with_aux
|
||||||
|
wrap_init
|
||||||
|
|
||||||
|
|
||||||
``jax.extend.random``
|
``jax.extend.random``
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
(jax-extend-jep)=
|
||||||
# `jax.extend`: a module for extensions
|
# `jax.extend`: a module for extensions
|
||||||
|
|
||||||
[@froystig](https://github.com/froystig),
|
[@froystig](https://github.com/froystig),
|
||||||
|
@ -161,7 +161,7 @@ from jax import dtypes as dtypes
|
|||||||
from jax import errors as errors
|
from jax import errors as errors
|
||||||
from jax import image as image
|
from jax import image as image
|
||||||
from jax import lax as lax
|
from jax import lax as lax
|
||||||
from jax import linear_util as linear_util
|
from jax import linear_util as _deprecated_linear_util
|
||||||
from jax import monitoring as monitoring
|
from jax import monitoring as monitoring
|
||||||
from jax import nn as nn
|
from jax import nn as nn
|
||||||
from jax import numpy as numpy
|
from jax import numpy as numpy
|
||||||
@ -210,12 +210,18 @@ _deprecations = {
|
|||||||
"tree_unflatten": (
|
"tree_unflatten": (
|
||||||
"jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.",
|
"jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.",
|
||||||
_deprecated_tree_unflatten
|
_deprecated_tree_unflatten
|
||||||
)
|
),
|
||||||
|
# Added Aug 29 2023
|
||||||
|
"linear_util": (
|
||||||
|
"jax.linear_util is deprecated: use jax.extend.linear_util.",
|
||||||
|
_deprecated_linear_util,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
import typing as _typing
|
import typing as _typing
|
||||||
if _typing.TYPE_CHECKING:
|
if _typing.TYPE_CHECKING:
|
||||||
from jax._src import abstract_arrays as abstract_arrays
|
from jax import abstract_arrays as abstract_arrays
|
||||||
|
from jax import linear_util as linear_util
|
||||||
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
|
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
|
||||||
from jax._src.tree_util import tree_flatten as tree_flatten
|
from jax._src.tree_util import tree_flatten as tree_flatten
|
||||||
from jax._src.tree_util import tree_leaves as tree_leaves
|
from jax._src.tree_util import tree_leaves as tree_leaves
|
||||||
|
@ -235,6 +235,7 @@ class WrappedFun:
|
|||||||
@curry
|
@curry
|
||||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||||
"""Adds one more transformation to a WrappedFun.
|
"""Adds one more transformation to a WrappedFun.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gen: the transformation generator function
|
gen: the transformation generator function
|
||||||
fun: a WrappedFun on which to apply the transformation
|
fun: a WrappedFun on which to apply the transformation
|
||||||
|
@ -22,7 +22,6 @@ from typing import Any, Callable, Dict, Sequence, Tuple
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import api_util
|
from jax import api_util
|
||||||
from jax import linear_util as lu
|
|
||||||
from jax import tree_util
|
from jax import tree_util
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.interpreters import ad
|
from jax.interpreters import ad
|
||||||
@ -31,6 +30,7 @@ from jax.interpreters import partial_eval as pe
|
|||||||
from jax.interpreters import xla
|
from jax.interpreters import xla
|
||||||
from jax._src import ad_util
|
from jax._src import ad_util
|
||||||
from jax._src import core as jax_core
|
from jax._src import core as jax_core
|
||||||
|
from jax._src import linear_util as lu
|
||||||
from jax._src.state import discharge as state_discharge
|
from jax._src.state import discharge as state_discharge
|
||||||
from jax._src.util import (
|
from jax._src.util import (
|
||||||
split_list, safe_map, safe_zip, weakref_lru_cache,
|
split_list, safe_map, safe_zip, weakref_lru_cache,
|
||||||
|
@ -181,7 +181,7 @@ def curry(f):
|
|||||||
>>> curry(f)(2, 3, 4, 5)()
|
>>> curry(f)(2, 3, 4, 5)()
|
||||||
26
|
26
|
||||||
"""
|
"""
|
||||||
return partial(partial, f)
|
return wraps(f)(partial(partial, f))
|
||||||
|
|
||||||
def toposort(end_nodes):
|
def toposort(end_nodes):
|
||||||
if not end_nodes: return []
|
if not end_nodes: return []
|
||||||
|
@ -29,5 +29,6 @@ Breaking changes will be announced via the
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from jax.extend import (
|
from jax.extend import (
|
||||||
|
linear_util as linear_util,
|
||||||
random as random,
|
random as random,
|
||||||
)
|
)
|
||||||
|
26
jax/extend/linear_util.py
Normal file
26
jax/extend/linear_util.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# Copyright 2023 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Note: import <name> as <name> is required for names to be exported.
|
||||||
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||||
|
|
||||||
|
from jax._src.linear_util import (
|
||||||
|
StoreException as StoreException,
|
||||||
|
WrappedFun as WrappedFun,
|
||||||
|
cache as cache,
|
||||||
|
merge_linear_aux as merge_linear_aux,
|
||||||
|
transformation as transformation,
|
||||||
|
transformation_with_aux as transformation_with_aux,
|
||||||
|
wrap_init as wrap_init,
|
||||||
|
)
|
@ -17,12 +17,53 @@
|
|||||||
|
|
||||||
# TODO(jakevdp): deprecate these and remove this module.
|
# TODO(jakevdp): deprecate these and remove this module.
|
||||||
|
|
||||||
from jax._src.linear_util import (
|
from jax._src import linear_util as _lu
|
||||||
StoreException as StoreException,
|
|
||||||
WrappedFun as WrappedFun,
|
|
||||||
cache as cache,
|
_deprecations = {
|
||||||
merge_linear_aux as merge_linear_aux,
|
# Added August 29, 2023:
|
||||||
transformation as transformation,
|
"StoreException": (
|
||||||
transformation_with_aux as transformation_with_aux,
|
"jax.linear_util.StoreException is deprecated. Use jax.extend.linear_util.StoreException instead.",
|
||||||
wrap_init as wrap_init,
|
_lu.StoreException,
|
||||||
)
|
),
|
||||||
|
"WrappedFun": (
|
||||||
|
"jax.linear_util.WrappedFun is deprecated. Use jax.extend.linear_util.WrappedFun instead.",
|
||||||
|
_lu.WrappedFun,
|
||||||
|
),
|
||||||
|
"cache": (
|
||||||
|
"jax.linear_util.cache is deprecated. Use jax.extend.linear_util.cache instead.",
|
||||||
|
_lu.cache,
|
||||||
|
),
|
||||||
|
"merge_linear_aux": (
|
||||||
|
"jax.linear_util.merge_linear_aux is deprecated. Use jax.extend.linear_util.merge_linear_aux instead.",
|
||||||
|
_lu.merge_linear_aux
|
||||||
|
),
|
||||||
|
"transformation": (
|
||||||
|
"jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.",
|
||||||
|
_lu.transformation
|
||||||
|
),
|
||||||
|
"transformation_with_aux": (
|
||||||
|
"jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.",
|
||||||
|
_lu.transformation_with_aux
|
||||||
|
),
|
||||||
|
"wrap_init": (
|
||||||
|
"jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.",
|
||||||
|
_lu.wrap_init
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
import typing
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
StoreException = _lu.StoreException
|
||||||
|
WrappedFun = _lu.WrappedFun
|
||||||
|
cache = _lu.cache
|
||||||
|
merge_linear_aux = _lu.merge_linear_aux
|
||||||
|
transformation = _lu.transformation
|
||||||
|
transformation_with_aux = _lu.transformation_with_aux
|
||||||
|
wrap_init = _lu.wrap_init
|
||||||
|
else:
|
||||||
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
|
del _deprecation_getattr
|
||||||
|
del typing
|
||||||
|
del _lu
|
||||||
|
@ -17,6 +17,7 @@ from absl.testing import absltest
|
|||||||
import jax
|
import jax
|
||||||
import jax.extend as jex
|
import jax.extend as jex
|
||||||
|
|
||||||
|
from jax._src import linear_util
|
||||||
from jax._src import prng
|
from jax._src import prng
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
|
|
||||||
@ -35,6 +36,15 @@ class ExtendTest(jtu.JaxTestCase):
|
|||||||
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
|
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
|
||||||
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
|
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
|
||||||
|
|
||||||
|
# Assume these are tested elsewhere, only check equivalence
|
||||||
|
self.assertIs(jex.linear_util.StoreException, linear_util.StoreException)
|
||||||
|
self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)
|
||||||
|
self.assertIs(jex.linear_util.cache, linear_util.cache)
|
||||||
|
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
|
||||||
|
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
|
||||||
|
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
|
||||||
|
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
|
||||||
|
|
||||||
|
|
||||||
class RandomTest(jtu.JaxTestCase):
|
class RandomTest(jtu.JaxTestCase):
|
||||||
def test_wrap_key_default(self):
|
def test_wrap_key_default(self):
|
||||||
|
@ -24,8 +24,8 @@ from absl.testing import parameterized
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax import linear_util as lu
|
|
||||||
from jax import random
|
from jax import random
|
||||||
|
from jax._src import linear_util as lu
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src import state
|
from jax._src import state
|
||||||
from jax._src.lax.control_flow.for_loop import for_loop
|
from jax._src.lax.control_flow.for_loop import for_loop
|
||||||
|
Loading…
x
Reference in New Issue
Block a user