mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
JEX: move jax.linear_util to jax.extend.linear_util
This commit is contained in:
parent
437d7be735
commit
ca39457ea9
@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Internal deprecations/removals:
|
||||
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
|
||||
{mod}`jax.extend.random`.
|
||||
* The internal submodule path `jax.linear_util` has been deprecated. Use
|
||||
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
|
||||
|
||||
## jaxlib 0.4.16
|
||||
|
||||
@ -73,6 +75,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* The utility `jax.interpreters.xla.register_collective_primitive` has been
|
||||
removed. This utility did nothing useful in recent JAX releases and calls
|
||||
to it can be safely removed.
|
||||
* The internal submodule path `jax.linear_util` has been deprecated. Use
|
||||
{mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
|
||||
|
||||
## jaxlib 0.4.15 (Aug 30 2023)
|
||||
|
||||
|
@ -3,6 +3,23 @@
|
||||
|
||||
.. automodule:: jax.extend
|
||||
|
||||
``jax.extend.linear_util``
|
||||
--------------------------
|
||||
|
||||
.. automodule:: jax.extend.linear_util
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
StoreException
|
||||
WrappedFun
|
||||
cache
|
||||
merge_linear_aux
|
||||
transformation
|
||||
transformation_with_aux
|
||||
wrap_init
|
||||
|
||||
|
||||
``jax.extend.random``
|
||||
---------------------
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
(jax-extend-jep)=
|
||||
# `jax.extend`: a module for extensions
|
||||
|
||||
[@froystig](https://github.com/froystig),
|
||||
|
@ -161,7 +161,7 @@ from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import image as image
|
||||
from jax import lax as lax
|
||||
from jax import linear_util as linear_util
|
||||
from jax import linear_util as _deprecated_linear_util
|
||||
from jax import monitoring as monitoring
|
||||
from jax import nn as nn
|
||||
from jax import numpy as numpy
|
||||
@ -210,12 +210,18 @@ _deprecations = {
|
||||
"tree_unflatten": (
|
||||
"jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.",
|
||||
_deprecated_tree_unflatten
|
||||
)
|
||||
),
|
||||
# Added Aug 29 2023
|
||||
"linear_util": (
|
||||
"jax.linear_util is deprecated: use jax.extend.linear_util.",
|
||||
_deprecated_linear_util,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
from jax._src import abstract_arrays as abstract_arrays
|
||||
from jax import abstract_arrays as abstract_arrays
|
||||
from jax import linear_util as linear_util
|
||||
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
|
||||
from jax._src.tree_util import tree_flatten as tree_flatten
|
||||
from jax._src.tree_util import tree_leaves as tree_leaves
|
||||
|
@ -235,6 +235,7 @@ class WrappedFun:
|
||||
@curry
|
||||
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
|
||||
"""Adds one more transformation to a WrappedFun.
|
||||
|
||||
Args:
|
||||
gen: the transformation generator function
|
||||
fun: a WrappedFun on which to apply the transformation
|
||||
|
@ -22,7 +22,6 @@ from typing import Any, Callable, Dict, Sequence, Tuple
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
@ -31,6 +30,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src import ad_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.util import (
|
||||
split_list, safe_map, safe_zip, weakref_lru_cache,
|
||||
|
@ -181,7 +181,7 @@ def curry(f):
|
||||
>>> curry(f)(2, 3, 4, 5)()
|
||||
26
|
||||
"""
|
||||
return partial(partial, f)
|
||||
return wraps(f)(partial(partial, f))
|
||||
|
||||
def toposort(end_nodes):
|
||||
if not end_nodes: return []
|
||||
|
@ -29,5 +29,6 @@ Breaking changes will be announced via the
|
||||
"""
|
||||
|
||||
from jax.extend import (
|
||||
linear_util as linear_util,
|
||||
random as random,
|
||||
)
|
||||
|
26
jax/extend/linear_util.py
Normal file
26
jax/extend/linear_util.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.linear_util import (
|
||||
StoreException as StoreException,
|
||||
WrappedFun as WrappedFun,
|
||||
cache as cache,
|
||||
merge_linear_aux as merge_linear_aux,
|
||||
transformation as transformation,
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
wrap_init as wrap_init,
|
||||
)
|
@ -17,12 +17,53 @@
|
||||
|
||||
# TODO(jakevdp): deprecate these and remove this module.
|
||||
|
||||
from jax._src.linear_util import (
|
||||
StoreException as StoreException,
|
||||
WrappedFun as WrappedFun,
|
||||
cache as cache,
|
||||
merge_linear_aux as merge_linear_aux,
|
||||
transformation as transformation,
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
wrap_init as wrap_init,
|
||||
)
|
||||
from jax._src import linear_util as _lu
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Added August 29, 2023:
|
||||
"StoreException": (
|
||||
"jax.linear_util.StoreException is deprecated. Use jax.extend.linear_util.StoreException instead.",
|
||||
_lu.StoreException,
|
||||
),
|
||||
"WrappedFun": (
|
||||
"jax.linear_util.WrappedFun is deprecated. Use jax.extend.linear_util.WrappedFun instead.",
|
||||
_lu.WrappedFun,
|
||||
),
|
||||
"cache": (
|
||||
"jax.linear_util.cache is deprecated. Use jax.extend.linear_util.cache instead.",
|
||||
_lu.cache,
|
||||
),
|
||||
"merge_linear_aux": (
|
||||
"jax.linear_util.merge_linear_aux is deprecated. Use jax.extend.linear_util.merge_linear_aux instead.",
|
||||
_lu.merge_linear_aux
|
||||
),
|
||||
"transformation": (
|
||||
"jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.",
|
||||
_lu.transformation
|
||||
),
|
||||
"transformation_with_aux": (
|
||||
"jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.",
|
||||
_lu.transformation_with_aux
|
||||
),
|
||||
"wrap_init": (
|
||||
"jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.",
|
||||
_lu.wrap_init
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
StoreException = _lu.StoreException
|
||||
WrappedFun = _lu.WrappedFun
|
||||
cache = _lu.cache
|
||||
merge_linear_aux = _lu.merge_linear_aux
|
||||
transformation = _lu.transformation
|
||||
transformation_with_aux = _lu.transformation_with_aux
|
||||
wrap_init = _lu.wrap_init
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
del _lu
|
||||
|
@ -17,6 +17,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
@ -35,6 +36,15 @@ class ExtendTest(jtu.JaxTestCase):
|
||||
self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl)
|
||||
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
|
||||
|
||||
# Assume these are tested elsewhere, only check equivalence
|
||||
self.assertIs(jex.linear_util.StoreException, linear_util.StoreException)
|
||||
self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)
|
||||
self.assertIs(jex.linear_util.cache, linear_util.cache)
|
||||
self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux)
|
||||
self.assertIs(jex.linear_util.transformation, linear_util.transformation)
|
||||
self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux)
|
||||
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)
|
||||
|
||||
|
||||
class RandomTest(jtu.JaxTestCase):
|
||||
def test_wrap_key_default(self):
|
||||
|
@ -24,8 +24,8 @@ from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax import random
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import state
|
||||
from jax._src.lax.control_flow.for_loop import for_loop
|
||||
|
Loading…
x
Reference in New Issue
Block a user