From ca39457ea9f143e779e8aaff4fefb9115fbd7a64 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 30 Aug 2023 15:14:47 -0700 Subject: [PATCH] JEX: move jax.linear_util to jax.extend.linear_util --- CHANGELOG.md | 4 +++ docs/jax.extend.rst | 17 ++++++++++ docs/jep/15856-jex.md | 1 + jax/__init__.py | 12 +++++-- jax/_src/linear_util.py | 1 + jax/_src/pallas/pallas_call.py | 2 +- jax/_src/util.py | 2 +- jax/extend/__init__.py | 1 + jax/extend/linear_util.py | 26 +++++++++++++++ jax/linear_util.py | 59 ++++++++++++++++++++++++++++------ tests/extend_test.py | 10 ++++++ tests/pallas/pallas_test.py | 2 +- 12 files changed, 122 insertions(+), 15 deletions(-) create mode 100644 jax/extend/linear_util.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 39bc0a6b8..384911148 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list * Internal deprecations/removals: * The internal submodule `jax.prng` is now deprecated. Its contents are available at {mod}`jax.extend.random`. + * The internal submodule path `jax.linear_util` has been deprecated. Use + {mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`) ## jaxlib 0.4.16 @@ -73,6 +75,8 @@ Remember to align the itemized text with the first line of an item within a list * The utility `jax.interpreters.xla.register_collective_primitive` has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed. + * The internal submodule path `jax.linear_util` has been deprecated. Use + {mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`) ## jaxlib 0.4.15 (Aug 30 2023) diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index a846cf93a..f737e51ef 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -3,6 +3,23 @@ .. automodule:: jax.extend +``jax.extend.linear_util`` +-------------------------- + +.. automodule:: jax.extend.linear_util + +.. autosummary:: + :toctree: _autosummary + + StoreException + WrappedFun + cache + merge_linear_aux + transformation + transformation_with_aux + wrap_init + + ``jax.extend.random`` --------------------- diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index 6f97cc1e0..bec060001 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -1,3 +1,4 @@ +(jax-extend-jep)= # `jax.extend`: a module for extensions [@froystig](https://github.com/froystig), diff --git a/jax/__init__.py b/jax/__init__.py index 8d61800f7..327b53244 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -161,7 +161,7 @@ from jax import dtypes as dtypes from jax import errors as errors from jax import image as image from jax import lax as lax -from jax import linear_util as linear_util +from jax import linear_util as _deprecated_linear_util from jax import monitoring as monitoring from jax import nn as nn from jax import numpy as numpy @@ -210,12 +210,18 @@ _deprecations = { "tree_unflatten": ( "jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.", _deprecated_tree_unflatten - ) + ), + # Added Aug 29 2023 + "linear_util": ( + "jax.linear_util is deprecated: use jax.extend.linear_util.", + _deprecated_linear_util, + ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src import abstract_arrays as abstract_arrays + from jax import abstract_arrays as abstract_arrays + from jax import linear_util as linear_util from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index c41ddfe01..916895f88 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -235,6 +235,7 @@ class WrappedFun: @curry def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. + Args: gen: the transformation generator function fun: a WrappedFun on which to apply the transformation diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 385b6db88..20b7ead7b 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -22,7 +22,6 @@ from typing import Any, Callable, Dict, Sequence, Tuple import jax from jax import api_util -from jax import linear_util as lu from jax import tree_util from jax import lax from jax.interpreters import ad @@ -31,6 +30,7 @@ from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax._src import ad_util from jax._src import core as jax_core +from jax._src import linear_util as lu from jax._src.state import discharge as state_discharge from jax._src.util import ( split_list, safe_map, safe_zip, weakref_lru_cache, diff --git a/jax/_src/util.py b/jax/_src/util.py index 4603933d4..95d03451d 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -181,7 +181,7 @@ def curry(f): >>> curry(f)(2, 3, 4, 5)() 26 """ - return partial(partial, f) + return wraps(f)(partial(partial, f)) def toposort(end_nodes): if not end_nodes: return [] diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index 994108b14..bc2bb8201 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -29,5 +29,6 @@ Breaking changes will be announced via the """ from jax.extend import ( + linear_util as linear_util, random as random, ) diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py new file mode 100644 index 000000000..1706f8c8c --- /dev/null +++ b/jax/extend/linear_util.py @@ -0,0 +1,26 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.linear_util import ( + StoreException as StoreException, + WrappedFun as WrappedFun, + cache as cache, + merge_linear_aux as merge_linear_aux, + transformation as transformation, + transformation_with_aux as transformation_with_aux, + wrap_init as wrap_init, +) diff --git a/jax/linear_util.py b/jax/linear_util.py index 4f92e6ef6..ea5b751da 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -17,12 +17,53 @@ # TODO(jakevdp): deprecate these and remove this module. -from jax._src.linear_util import ( - StoreException as StoreException, - WrappedFun as WrappedFun, - cache as cache, - merge_linear_aux as merge_linear_aux, - transformation as transformation, - transformation_with_aux as transformation_with_aux, - wrap_init as wrap_init, -) +from jax._src import linear_util as _lu + + +_deprecations = { + # Added August 29, 2023: + "StoreException": ( + "jax.linear_util.StoreException is deprecated. Use jax.extend.linear_util.StoreException instead.", + _lu.StoreException, + ), + "WrappedFun": ( + "jax.linear_util.WrappedFun is deprecated. Use jax.extend.linear_util.WrappedFun instead.", + _lu.WrappedFun, + ), + "cache": ( + "jax.linear_util.cache is deprecated. Use jax.extend.linear_util.cache instead.", + _lu.cache, + ), + "merge_linear_aux": ( + "jax.linear_util.merge_linear_aux is deprecated. Use jax.extend.linear_util.merge_linear_aux instead.", + _lu.merge_linear_aux + ), + "transformation": ( + "jax.linear_util.transformation is deprecated. Use jax.extend.linear_util.transformation instead.", + _lu.transformation + ), + "transformation_with_aux": ( + "jax.linear_util.transformation_with_aux is deprecated. Use jax.extend.linear_util.transformation_with_aux instead.", + _lu.transformation_with_aux + ), + "wrap_init": ( + "jax.linear_util.wrap_init is deprecated. Use jax.extend.linear_util.wrap_init instead.", + _lu.wrap_init + ), +} + +import typing +if typing.TYPE_CHECKING: + StoreException = _lu.StoreException + WrappedFun = _lu.WrappedFun + cache = _lu.cache + merge_linear_aux = _lu.merge_linear_aux + transformation = _lu.transformation + transformation_with_aux = _lu.transformation_with_aux + wrap_init = _lu.wrap_init +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _lu diff --git a/tests/extend_test.py b/tests/extend_test.py index 78cf4685f..31c3e76be 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest import jax import jax.extend as jex +from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu @@ -35,6 +36,15 @@ class ExtendTest(jtu.JaxTestCase): self.assertIs(jex.random.rbg_prng_impl, prng.rbg_prng_impl) self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl) + # Assume these are tested elsewhere, only check equivalence + self.assertIs(jex.linear_util.StoreException, linear_util.StoreException) + self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun) + self.assertIs(jex.linear_util.cache, linear_util.cache) + self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux) + self.assertIs(jex.linear_util.transformation, linear_util.transformation) + self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux) + self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init) + class RandomTest(jtu.JaxTestCase): def test_wrap_key_default(self): diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 00360b264..452f0361c 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -24,8 +24,8 @@ from absl.testing import parameterized import jax from jax import lax -from jax import linear_util as lu from jax import random +from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import state from jax._src.lax.control_flow.for_loop import for_loop