mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate remaining functionality in jax.test_util
This commit is contained in:
parent
bef5e02816
commit
d9508304e4
10
CHANGELOG.md
10
CHANGELOG.md
@ -25,6 +25,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
are not of an integer type, matching the behavior of
|
are not of an integer type, matching the behavior of
|
||||||
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
|
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
|
||||||
cast to integers.
|
cast to integers.
|
||||||
|
* Deprecations
|
||||||
|
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
||||||
|
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
||||||
|
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
|
||||||
|
`_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`,
|
||||||
|
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
|
||||||
|
Many of these utilities still exist in `jax._src.test_util`, but these are not public APIs and
|
||||||
|
as such may be changed or removed without notice.
|
||||||
|
|
||||||
## jaxlib 0.3.8 (Unreleased)
|
## jaxlib 0.3.8 (Unreleased)
|
||||||
* [GitHub
|
* [GitHub
|
||||||
@ -56,7 +64,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
|
pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
|
||||||
* Deprecations:
|
* Deprecations:
|
||||||
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
|
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
|
||||||
for an alternative API.
|
for an alternative API.
|
||||||
|
|
||||||
## jax 0.3.5 (April 7, 2022)
|
## jax 0.3.5 (April 7, 2022)
|
||||||
* [GitHub
|
* [GitHub
|
||||||
|
@ -12,30 +12,40 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# flake8: noqa: F401
|
from jax._src.public_test_util import ( # noqa: F401
|
||||||
from jax._src.public_test_util import (
|
|
||||||
check_grads as check_grads,
|
check_grads as check_grads,
|
||||||
check_jvp as check_jvp,
|
check_jvp as check_jvp,
|
||||||
check_vjp as check_vjp,
|
check_vjp as check_vjp,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conditional imports of private test utilities; these require their own BUILD target.
|
# TODO(jakevdp): remove everything below once downstream callers are fixed.
|
||||||
# TODO(jakevdp): remove these imports once downstream dependencies are cleaned.
|
|
||||||
|
# Unconditionally import private test_util because it contains flag definitions.
|
||||||
|
# In bazel, jax._src.test_util requires its own BUILD target so it may not be present.
|
||||||
|
# pytype: disable=import-error
|
||||||
try:
|
try:
|
||||||
from jax._src.test_util import ( # pytype: disable=import-error
|
import jax._src.test_util as _private_test_util
|
||||||
cases_from_list,
|
|
||||||
check_close,
|
|
||||||
check_eq,
|
|
||||||
device_under_test,
|
|
||||||
format_shape_dtype_string,
|
|
||||||
rand_uniform,
|
|
||||||
skip_on_devices,
|
|
||||||
with_config as with_config,
|
|
||||||
xla_bridge,
|
|
||||||
_default_tolerance,
|
|
||||||
DeprecatedJaxTestCase as JaxTestCase,
|
|
||||||
DeprecatedJaxTestLoader as JaxTestLoader,
|
|
||||||
DeprecatedBufferDonationTestCase as BufferDonationTestCase,
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
else:
|
||||||
|
del _private_test_util
|
||||||
|
|
||||||
|
# Use module-level getattr to add warnings to imports of deprecated names.
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
def __getattr__(attr):
|
||||||
|
try:
|
||||||
|
from jax._src import test_util
|
||||||
|
except ImportError:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||||
|
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
|
||||||
|
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
|
||||||
|
'with_config', 'xla_bridge', '_default_tolerance']:
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
|
||||||
|
return getattr(test_util, attr)
|
||||||
|
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
|
||||||
|
# Do the TestCase imports separately, since they were previously deprecated via a different
|
||||||
|
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
|
||||||
|
return getattr(test_util, 'Deprecated' + attr)
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user