mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7 PiperOrigin-RevId: 608676645
This commit is contained in:
parent
f1ea67117e
commit
57e59eb6c3
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -140,7 +140,7 @@ jobs:
|
|||||||
PY_COLORS: 1
|
PY_COLORS: 1
|
||||||
run: |
|
run: |
|
||||||
pytest -n auto --tb=short docs
|
pytest -n auto --tb=short docs
|
||||||
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
|
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
|
||||||
|
|
||||||
|
|
||||||
documentation_render:
|
documentation_render:
|
||||||
|
@ -25,6 +25,14 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless
|
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless
|
||||||
of the size of the array. Previously a deprecation warning was raised in the case of
|
of the size of the array. Previously a deprecation warning was raised in the case of
|
||||||
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
|
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
|
||||||
|
* The previously deprecated configuration APIs have been removed
|
||||||
|
following a standard 3 months deprecation cycle (see {ref}`api-compatibility`).
|
||||||
|
These include
|
||||||
|
* the `jax.config.config` object and
|
||||||
|
* the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
|
||||||
|
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
|
||||||
|
To configure JAX use `import jax` and then reference the config object
|
||||||
|
via `jax.config`.
|
||||||
|
|
||||||
## jaxlib 0.4.25
|
## jaxlib 0.4.25
|
||||||
|
|
||||||
|
@ -33,12 +33,6 @@ except Exception as exc:
|
|||||||
del _warn
|
del _warn
|
||||||
del _cloud_tpu_init
|
del _cloud_tpu_init
|
||||||
|
|
||||||
# Confusingly there are two things named "config": the module and the class.
|
|
||||||
# We want the exported object to be the class, so we first import the module
|
|
||||||
# to make sure a later import doesn't overwrite the class.
|
|
||||||
from jax import config as _config_module
|
|
||||||
del _config_module
|
|
||||||
|
|
||||||
# Force early import, allowing use of `jax.core` after importing `jax`.
|
# Force early import, allowing use of `jax.core` after importing `jax`.
|
||||||
import jax.core as _core
|
import jax.core as _core
|
||||||
del _core
|
del _core
|
||||||
|
@ -23,7 +23,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
|
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
|
||||||
import warnings
|
|
||||||
|
|
||||||
from jax._src import lib
|
from jax._src import lib
|
||||||
from jax._src.lib import jax_jit
|
from jax._src.lib import jax_jit
|
||||||
@ -70,23 +69,6 @@ UPGRADE_BOOL_HELP = (
|
|||||||
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
||||||
|
|
||||||
|
|
||||||
_CONFIG_DEPRECATIONS = {
|
|
||||||
# Added October 26, 2023:
|
|
||||||
"check_exists",
|
|
||||||
"DEFINE_bool",
|
|
||||||
"DEFINE_integer",
|
|
||||||
"DEFINE_float",
|
|
||||||
"DEFINE_string",
|
|
||||||
"DEFINE_enum",
|
|
||||||
"define_bool_state",
|
|
||||||
"define_enum_state",
|
|
||||||
"define_int_state",
|
|
||||||
"define_float_state",
|
|
||||||
"define_string_state",
|
|
||||||
"define_string_or_object_state",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||||
|
|
||||||
@ -100,20 +82,6 @@ class Config:
|
|||||||
self.use_absl = False
|
self.use_absl = False
|
||||||
self._contextmanager_flags = set()
|
self._contextmanager_flags = set()
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
fn = None
|
|
||||||
if name in _CONFIG_DEPRECATIONS:
|
|
||||||
fn = globals().get(name, None)
|
|
||||||
if fn is None:
|
|
||||||
raise AttributeError(
|
|
||||||
f"'{type(self).__name__!r} object has no attribute {name!r}")
|
|
||||||
message = (
|
|
||||||
f"jax.config.{name} is deprecated. Please use other libraries "
|
|
||||||
"for configuration instead."
|
|
||||||
)
|
|
||||||
warnings.warn(message, DeprecationWarning, stacklevel=2)
|
|
||||||
return fn
|
|
||||||
|
|
||||||
def update(self, name, val):
|
def update(self, name, val):
|
||||||
if name not in self._value_holders:
|
if name not in self._value_holders:
|
||||||
raise AttributeError(f"Unrecognized config option: {name}")
|
raise AttributeError(f"Unrecognized config option: {name}")
|
||||||
|
@ -12,23 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from jax._src.config import config as _deprecated_config # noqa: F401
|
import warnings
|
||||||
|
|
||||||
# Deprecations
|
# Added February 16, 2024.
|
||||||
|
warnings.warn(
|
||||||
_deprecations = {
|
"Importing the jax.config submodule via `import jax.config` is deprecated."
|
||||||
# Added October 27, 2023
|
" To configure JAX use `import jax` and then reference the config object"
|
||||||
"config": (
|
" via `jax.config`.",
|
||||||
"Accessing jax.config via the jax.config submodule is deprecated.",
|
DeprecationWarning,
|
||||||
_deprecated_config),
|
stacklevel=2,
|
||||||
}
|
)
|
||||||
|
del warnings
|
||||||
import typing
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
config = _deprecated_config
|
|
||||||
else:
|
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
|
||||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
|
||||||
del _deprecation_getattr
|
|
||||||
del typing
|
|
||||||
del _deprecated_config
|
|
||||||
|
@ -73,14 +73,6 @@ jax_test(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "config_test",
|
|
||||||
srcs = ["config_test.py"],
|
|
||||||
deps = [
|
|
||||||
"//jax",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
jax_test(
|
jax_test(
|
||||||
name = "core_test",
|
name = "core_test",
|
||||||
srcs = ["core_test.py"],
|
srcs = ["core_test.py"],
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
# Copyright 2023 The JAX Authors.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from jax import config
|
|
||||||
|
|
||||||
class ConfigTest(unittest.TestCase):
|
|
||||||
def test_deprecations(self):
|
|
||||||
for name in ["DEFINE_bool", "define_bool_state"]:
|
|
||||||
with (
|
|
||||||
self.subTest(name),
|
|
||||||
self.assertWarnsRegex(
|
|
||||||
DeprecationWarning,
|
|
||||||
"other libraries for configuration"),
|
|
||||||
):
|
|
||||||
getattr(config, name)
|
|
||||||
|
|
||||||
def test_missing_attribute(self):
|
|
||||||
with self.assertRaises(AttributeError):
|
|
||||||
config.missing_attribute
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
Loading…
x
Reference in New Issue
Block a user