From 57e59eb6c309acfdfb49a4ce76c996fc5c2a016c Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 20 Feb 2024 11:24:40 -0800 Subject: [PATCH] Removed deprecated jax.config methods and jax.config.config Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7 PiperOrigin-RevId: 608676645 --- .github/workflows/ci-build.yaml | 2 +- CHANGELOG.md | 8 ++++++++ jax/__init__.py | 6 ------ jax/_src/config.py | 32 ----------------------------- jax/config.py | 29 +++++++++----------------- tests/BUILD | 8 -------- tests/config_test.py | 36 --------------------------------- 7 files changed, 19 insertions(+), 102 deletions(-) delete mode 100644 tests/config_test.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index e6c7455a2..bcdbcd7f9 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -140,7 +140,7 @@ jobs: PY_COLORS: 1 run: | pytest -n auto --tb=short docs - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas documentation_render: diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e9656db0..bf3e5360b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,14 @@ Remember to align the itemized text with the first line of an item within a list * Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy. + * The previously deprecated configuration APIs have been removed + following a standard 3 months deprecation cycle (see {ref}`api-compatibility`). + These include + * the `jax.config.config` object and + * the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`. + * Importing the `jax.config` submodule via `import jax.config` is deprecated. + To configure JAX use `import jax` and then reference the config object + via `jax.config`. ## jaxlib 0.4.25 diff --git a/jax/__init__.py b/jax/__init__.py index ece7c7611..68279407f 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -33,12 +33,6 @@ except Exception as exc: del _warn del _cloud_tpu_init -# Confusingly there are two things named "config": the module and the class. -# We want the exported object to be the class, so we first import the module -# to make sure a later import doesn't overwrite the class. -from jax import config as _config_module -del _config_module - # Force early import, allowing use of `jax.core` after importing `jax`. import jax.core as _core del _core diff --git a/jax/_src/config.py b/jax/_src/config.py index a9d7050d2..497085f92 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -23,7 +23,6 @@ import os import sys import threading from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast -import warnings from jax._src import lib from jax._src.lib import jax_jit @@ -70,23 +69,6 @@ UPGRADE_BOOL_HELP = ( UPGRADE_BOOL_EXTRA_DESC = " (transient)" -_CONFIG_DEPRECATIONS = { - # Added October 26, 2023: - "check_exists", - "DEFINE_bool", - "DEFINE_integer", - "DEFINE_float", - "DEFINE_string", - "DEFINE_enum", - "define_bool_state", - "define_enum_state", - "define_int_state", - "define_float_state", - "define_string_state", - "define_string_or_object_state", -} - - class Config: _HAS_DYNAMIC_ATTRIBUTES = True @@ -100,20 +82,6 @@ class Config: self.use_absl = False self._contextmanager_flags = set() - def __getattr__(self, name): - fn = None - if name in _CONFIG_DEPRECATIONS: - fn = globals().get(name, None) - if fn is None: - raise AttributeError( - f"'{type(self).__name__!r} object has no attribute {name!r}") - message = ( - f"jax.config.{name} is deprecated. Please use other libraries " - "for configuration instead." - ) - warnings.warn(message, DeprecationWarning, stacklevel=2) - return fn - def update(self, name, val): if name not in self._value_holders: raise AttributeError(f"Unrecognized config option: {name}") diff --git a/jax/config.py b/jax/config.py index 9435308d1..763fe6c0f 100644 --- a/jax/config.py +++ b/jax/config.py @@ -12,23 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.config import config as _deprecated_config # noqa: F401 +import warnings -# Deprecations - -_deprecations = { - # Added October 27, 2023 - "config": ( - "Accessing jax.config via the jax.config submodule is deprecated.", - _deprecated_config), -} - -import typing -if typing.TYPE_CHECKING: - config = _deprecated_config -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing -del _deprecated_config +# Added February 16, 2024. +warnings.warn( + "Importing the jax.config submodule via `import jax.config` is deprecated." + " To configure JAX use `import jax` and then reference the config object" + " via `jax.config`.", + DeprecationWarning, + stacklevel=2, +) +del warnings diff --git a/tests/BUILD b/tests/BUILD index dde59fd5d..5c6b35e5a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -73,14 +73,6 @@ jax_test( }, ) -py_test( - name = "config_test", - srcs = ["config_test.py"], - deps = [ - "//jax", - ], -) - jax_test( name = "core_test", srcs = ["core_test.py"], diff --git a/tests/config_test.py b/tests/config_test.py deleted file mode 100644 index b801f1d7c..000000000 --- a/tests/config_test.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from jax import config - -class ConfigTest(unittest.TestCase): - def test_deprecations(self): - for name in ["DEFINE_bool", "define_bool_state"]: - with ( - self.subTest(name), - self.assertWarnsRegex( - DeprecationWarning, - "other libraries for configuration"), - ): - getattr(config, name) - - def test_missing_attribute(self): - with self.assertRaises(AttributeError): - config.missing_attribute - - -if __name__ == '__main__': - unittest.main()