mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate remaining functionality in jax.test_util
This commit is contained in:
parent
bef5e02816
commit
d9508304e4
10
CHANGELOG.md
10
CHANGELOG.md
@ -25,6 +25,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
are not of an integer type, matching the behavior of
|
||||
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
|
||||
cast to integers.
|
||||
* Deprecations
|
||||
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
|
||||
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
|
||||
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
|
||||
`_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`,
|
||||
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
|
||||
Many of these utilities still exist in `jax._src.test_util`, but these are not public APIs and
|
||||
as such may be changed or removed without notice.
|
||||
|
||||
## jaxlib 0.3.8 (Unreleased)
|
||||
* [GitHub
|
||||
@ -56,7 +64,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
|
||||
* Deprecations:
|
||||
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
|
||||
for an alternative API.
|
||||
for an alternative API.
|
||||
|
||||
## jax 0.3.5 (April 7, 2022)
|
||||
* [GitHub
|
||||
|
@ -12,30 +12,40 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.public_test_util import (
|
||||
from jax._src.public_test_util import ( # noqa: F401
|
||||
check_grads as check_grads,
|
||||
check_jvp as check_jvp,
|
||||
check_vjp as check_vjp,
|
||||
)
|
||||
|
||||
# Conditional imports of private test utilities; these require their own BUILD target.
|
||||
# TODO(jakevdp): remove these imports once downstream dependencies are cleaned.
|
||||
# TODO(jakevdp): remove everything below once downstream callers are fixed.
|
||||
|
||||
# Unconditionally import private test_util because it contains flag definitions.
|
||||
# In bazel, jax._src.test_util requires its own BUILD target so it may not be present.
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
from jax._src.test_util import ( # pytype: disable=import-error
|
||||
cases_from_list,
|
||||
check_close,
|
||||
check_eq,
|
||||
device_under_test,
|
||||
format_shape_dtype_string,
|
||||
rand_uniform,
|
||||
skip_on_devices,
|
||||
with_config as with_config,
|
||||
xla_bridge,
|
||||
_default_tolerance,
|
||||
DeprecatedJaxTestCase as JaxTestCase,
|
||||
DeprecatedJaxTestLoader as JaxTestLoader,
|
||||
DeprecatedBufferDonationTestCase as BufferDonationTestCase,
|
||||
)
|
||||
import jax._src.test_util as _private_test_util
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
del _private_test_util
|
||||
|
||||
# Use module-level getattr to add warnings to imports of deprecated names.
|
||||
# pylint: disable=import-outside-toplevel
|
||||
def __getattr__(attr):
|
||||
try:
|
||||
from jax._src import test_util
|
||||
except ImportError:
|
||||
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
|
||||
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
|
||||
'with_config', 'xla_bridge', '_default_tolerance']:
|
||||
import warnings
|
||||
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
|
||||
return getattr(test_util, attr)
|
||||
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
|
||||
# Do the TestCase imports separately, since they were previously deprecated via a different
|
||||
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
|
||||
return getattr(test_util, 'Deprecated' + attr)
|
||||
else:
|
||||
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user