Deprecate remaining functionality in jax.test_util

This commit is contained in:
Jake VanderPlas 2022-04-21 12:12:40 -07:00
parent bef5e02816
commit d9508304e4
2 changed files with 38 additions and 20 deletions

View File

@ -25,6 +25,14 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers.
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
`_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`,
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
Many of these utilities still exist in `jax._src.test_util`, but these are not public APIs and
as such may be changed or removed without notice.
## jaxlib 0.3.8 (Unreleased)
* [GitHub
@ -56,7 +64,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
pod. Fixes [#10218](https://github.com/google/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278`
for an alternative API.
for an alternative API.
## jax 0.3.5 (April 7, 2022)
* [GitHub

View File

@ -12,30 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from jax._src.public_test_util import (
from jax._src.public_test_util import ( # noqa: F401
check_grads as check_grads,
check_jvp as check_jvp,
check_vjp as check_vjp,
)
# Conditional imports of private test utilities; these require their own BUILD target.
# TODO(jakevdp): remove these imports once downstream dependencies are cleaned.
# TODO(jakevdp): remove everything below once downstream callers are fixed.
# Unconditionally import private test_util because it contains flag definitions.
# In bazel, jax._src.test_util requires its own BUILD target so it may not be present.
# pytype: disable=import-error
try:
from jax._src.test_util import ( # pytype: disable=import-error
cases_from_list,
check_close,
check_eq,
device_under_test,
format_shape_dtype_string,
rand_uniform,
skip_on_devices,
with_config as with_config,
xla_bridge,
_default_tolerance,
DeprecatedJaxTestCase as JaxTestCase,
DeprecatedJaxTestLoader as JaxTestLoader,
DeprecatedBufferDonationTestCase as BufferDonationTestCase,
)
import jax._src.test_util as _private_test_util
except ImportError:
pass
else:
del _private_test_util
# Use module-level getattr to add warnings to imports of deprecated names.
# pylint: disable=import-outside-toplevel
def __getattr__(attr):
try:
from jax._src import test_util
except ImportError:
raise AttributeError(f"module {__name__} has no attribute {attr}")
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
'with_config', 'xla_bridge', '_default_tolerance']:
import warnings
warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
return getattr(test_util, attr)
elif attr in ['JaxTestCase', 'JaxTestLoader', 'BufferDonationTestCase']:
# Do the TestCase imports separately, since they were previously deprecated via a different
# mechanism & we don't want to annoy projects who may have temporarily filtered a specific warning.
return getattr(test_util, 'Deprecated' + attr)
else:
raise AttributeError(f"module {__name__} has no attribute {attr}")