Removed deprecated jax.config methods and jax.config.config

Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
This commit is contained in:
Sergei Lebedev 2024-02-20 11:24:40 -08:00 committed by jax authors
parent f1ea67117e
commit 57e59eb6c3
7 changed files with 19 additions and 102 deletions

View File

@ -140,7 +140,7 @@ jobs:
PY_COLORS: 1 PY_COLORS: 1
run: | run: |
pytest -n auto --tb=short docs pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
documentation_render: documentation_render:

View File

@ -25,6 +25,14 @@ Remember to align the itemized text with the first line of an item within a list
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless * Conversion of a non-scalar array to a Python scalar now raises an error, regardless
of the size of the array. Previously a deprecation warning was raised in the case of of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy. non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
* The previously deprecated configuration APIs have been removed
following a standard 3 months deprecation cycle (see {ref}`api-compatibility`).
These include
* the `jax.config.config` object and
* the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
To configure JAX use `import jax` and then reference the config object
via `jax.config`.
## jaxlib 0.4.25 ## jaxlib 0.4.25

View File

@ -33,12 +33,6 @@ except Exception as exc:
del _warn del _warn
del _cloud_tpu_init del _cloud_tpu_init
# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
from jax import config as _config_module
del _config_module
# Force early import, allowing use of `jax.core` after importing `jax`. # Force early import, allowing use of `jax.core` after importing `jax`.
import jax.core as _core import jax.core as _core
del _core del _core

View File

@ -23,7 +23,6 @@ import os
import sys import sys
import threading import threading
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
import warnings
from jax._src import lib from jax._src import lib
from jax._src.lib import jax_jit from jax._src.lib import jax_jit
@ -70,23 +69,6 @@ UPGRADE_BOOL_HELP = (
UPGRADE_BOOL_EXTRA_DESC = " (transient)" UPGRADE_BOOL_EXTRA_DESC = " (transient)"
_CONFIG_DEPRECATIONS = {
# Added October 26, 2023:
"check_exists",
"DEFINE_bool",
"DEFINE_integer",
"DEFINE_float",
"DEFINE_string",
"DEFINE_enum",
"define_bool_state",
"define_enum_state",
"define_int_state",
"define_float_state",
"define_string_state",
"define_string_or_object_state",
}
class Config: class Config:
_HAS_DYNAMIC_ATTRIBUTES = True _HAS_DYNAMIC_ATTRIBUTES = True
@ -100,20 +82,6 @@ class Config:
self.use_absl = False self.use_absl = False
self._contextmanager_flags = set() self._contextmanager_flags = set()
def __getattr__(self, name):
fn = None
if name in _CONFIG_DEPRECATIONS:
fn = globals().get(name, None)
if fn is None:
raise AttributeError(
f"'{type(self).__name__!r} object has no attribute {name!r}")
message = (
f"jax.config.{name} is deprecated. Please use other libraries "
"for configuration instead."
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
return fn
def update(self, name, val): def update(self, name, val):
if name not in self._value_holders: if name not in self._value_holders:
raise AttributeError(f"Unrecognized config option: {name}") raise AttributeError(f"Unrecognized config option: {name}")

View File

@ -12,23 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from jax._src.config import config as _deprecated_config # noqa: F401 import warnings
# Deprecations # Added February 16, 2024.
warnings.warn(
_deprecations = { "Importing the jax.config submodule via `import jax.config` is deprecated."
# Added October 27, 2023 " To configure JAX use `import jax` and then reference the config object"
"config": ( " via `jax.config`.",
"Accessing jax.config via the jax.config submodule is deprecated.", DeprecationWarning,
_deprecated_config), stacklevel=2,
} )
del warnings
import typing
if typing.TYPE_CHECKING:
config = _deprecated_config
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _deprecated_config

View File

@ -73,14 +73,6 @@ jax_test(
}, },
) )
py_test(
name = "config_test",
srcs = ["config_test.py"],
deps = [
"//jax",
],
)
jax_test( jax_test(
name = "core_test", name = "core_test",
srcs = ["core_test.py"], srcs = ["core_test.py"],

View File

@ -1,36 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from jax import config
class ConfigTest(unittest.TestCase):
def test_deprecations(self):
for name in ["DEFINE_bool", "define_bool_state"]:
with (
self.subTest(name),
self.assertWarnsRegex(
DeprecationWarning,
"other libraries for configuration"),
):
getattr(config, name)
def test_missing_attribute(self):
with self.assertRaises(AttributeError):
config.missing_attribute
if __name__ == '__main__':
unittest.main()