mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[internal] add deprecation test utilities
This commit is contained in:
parent
5d0ee43222
commit
1ee015674f
@ -44,6 +44,7 @@ from jax import lax
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import lib as _jaxlib
|
||||
@ -1310,6 +1311,16 @@ class JaxTestCase(parameterized.TestCase):
|
||||
def rng(self):
|
||||
return self._rng
|
||||
|
||||
def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
|
||||
"""Assert warning or error, depending on deprecation state.
|
||||
|
||||
For use with functions that call :func:`jax._src.deprecations.warn`.
|
||||
"""
|
||||
if deprecations.is_accelerated(deprecation_id):
|
||||
return self.assertRaisesRegex(ValueError, message)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, message)
|
||||
|
||||
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
|
||||
allow_object_dtype=False, verbose=True):
|
||||
"""Assert that x and y arrays are exactly equal."""
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
@ -20,6 +22,14 @@ from jax._src.internal_test_util import deprecation_module as m
|
||||
|
||||
class DeprecationTest(absltest.TestCase):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def deprecation_context(self, deprecation_id):
|
||||
deprecations.register(deprecation_id)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deprecations.unregister(deprecation_id)
|
||||
|
||||
def testModuleDeprecation(self):
|
||||
with test_warning_util.raise_on_warnings():
|
||||
self.assertEqual(m.x, 42)
|
||||
@ -36,13 +46,10 @@ class DeprecationTest(absltest.TestCase):
|
||||
|
||||
def testNamedDeprecation(self):
|
||||
some_unique_id = "some-unique-id"
|
||||
try:
|
||||
deprecations.register(some_unique_id)
|
||||
with self.deprecation_context(some_unique_id):
|
||||
self.assertFalse(deprecations.is_accelerated(some_unique_id))
|
||||
deprecations.accelerate(some_unique_id)
|
||||
self.assertTrue(deprecations.is_accelerated(some_unique_id))
|
||||
finally:
|
||||
deprecations.unregister(some_unique_id)
|
||||
|
||||
msg = f"deprecation_id={some_unique_id!r} not registered"
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
@ -52,6 +59,19 @@ class DeprecationTest(absltest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
deprecations.unregister(some_unique_id)
|
||||
|
||||
def testNamedDeprecationWarns(self):
|
||||
deprecation_id = "some-unique-id"
|
||||
deprecation_message = "This API is deprecated."
|
||||
with self.deprecation_context(deprecation_id):
|
||||
self.assertFalse(deprecations.is_accelerated(deprecation_id))
|
||||
with self.assertWarnsRegex(DeprecationWarning, deprecation_message):
|
||||
deprecations.warn(deprecation_id, deprecation_message, stacklevel=1)
|
||||
|
||||
deprecations.accelerate(deprecation_id)
|
||||
self.assertTrue(deprecations.is_accelerated(deprecation_id))
|
||||
with self.assertRaisesRegex(ValueError, deprecation_message):
|
||||
deprecations.warn(deprecation_id, deprecation_message, stacklevel=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
@ -28,7 +28,6 @@ import jax.sharding as shd
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
@ -210,12 +209,8 @@ class FfiTest(jtu.JaxTestCase):
|
||||
def fun(x):
|
||||
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
if deprecations.is_accelerated("jax-ffi-call-args"):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def test_input_output_aliases(self):
|
||||
def fun(x):
|
||||
|
@ -27,7 +27,6 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import NumpyComplexWarning
|
||||
@ -454,12 +453,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
x = jnp.zeros((10,), dtype)
|
||||
where = jnp.ones(10, dtype=int)
|
||||
func = getattr(jnp, rec.name)
|
||||
def assert_warns_or_errors(msg):
|
||||
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where",
|
||||
f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
func(x, where=where, initial=jnp.array(0, dtype=dtype))
|
||||
|
||||
@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
|
||||
@ -468,12 +463,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
x = jnp.zeros((10,), dtype)
|
||||
where = jnp.ones(10, dtype=int)
|
||||
func = getattr(jnp, rec.name)
|
||||
def assert_warns_or_errors(msg):
|
||||
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where",
|
||||
f"jnp.{rec.name}: where must be None or a boolean array"):
|
||||
func(x, where=where)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
@ -756,13 +747,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
|
||||
)
|
||||
def testQuantileDeprecatedArgs(self, op):
|
||||
func = getattr(jnp, op)
|
||||
msg = f"The interpolation= argument to '{op}' is deprecated. "
|
||||
def assert_warns_or_errors(msg=msg):
|
||||
if deprecations.is_accelerated("jax-numpy-quantile-interpolation"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(msg):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-quantile-interpolation",
|
||||
f"The interpolation= argument to '{op}' is deprecated. "):
|
||||
func(jnp.arange(4), 0.5, interpolation='linear')
|
||||
|
||||
@unittest.skipIf(not config.enable_x64.value, "test requires X64")
|
||||
|
@ -47,7 +47,6 @@ from jax.test_util import check_grads
|
||||
from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -1061,13 +1060,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
jnp.clip(x, max=jnp.array([-1+5j]))
|
||||
|
||||
def testClipDeprecatedArgs(self):
|
||||
msg = "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"
|
||||
def assert_warns_or_errors(msg=msg):
|
||||
if deprecations.is_accelerated("jax-numpy-clip-args"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors(msg):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-clip-args",
|
||||
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"):
|
||||
jnp.clip(jnp.arange(4), a_min=2, a_max=3)
|
||||
|
||||
def testHypotComplexInputError(self):
|
||||
@ -4186,13 +4180,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def testAstypeComplexDowncast(self):
|
||||
x = jnp.array(2.0+1.5j, dtype='complex64')
|
||||
msg = "Casting from complex to real dtypes.*"
|
||||
def assert_warns_or_errors(msg=msg):
|
||||
if deprecations.is_accelerated("jax-numpy-astype-complex-to-real"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors():
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-astype-complex-to-real",
|
||||
"Casting from complex to real dtypes.*"):
|
||||
x.astype('float32')
|
||||
|
||||
@parameterized.parameters('int2', 'int4')
|
||||
|
@ -32,7 +32,6 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import scipy as jsp
|
||||
from jax._src import config
|
||||
from jax._src import deprecations
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -1177,14 +1176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3)
|
||||
|
||||
def testPinvDeprecatedArgs(self):
|
||||
msg = "The rcond argument for linalg.pinv is deprecated."
|
||||
def assert_warns_or_errors(msg=msg):
|
||||
if deprecations.is_accelerated("jax-numpy-linalg-pinv-rcond"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
x = jnp.ones((3, 3))
|
||||
with assert_warns_or_errors(msg):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-pinv-rcond",
|
||||
"The rcond argument for linalg.pinv is deprecated."):
|
||||
jnp.linalg.pinv(x, rcond=1E-2)
|
||||
|
||||
def testPinvGradIssue2792(self):
|
||||
@ -1230,14 +1224,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
check_dtypes=False, rtol=1e-3)
|
||||
|
||||
def testMatrixRankDeprecatedArgs(self):
|
||||
msg = "The tol argument for linalg.matrix_rank is deprecated."
|
||||
def assert_warns_or_errors(msg=msg):
|
||||
if deprecations.is_accelerated("jax-numpy-linalg-matrix_rank-tol"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
x = jnp.ones((3, 3))
|
||||
with assert_warns_or_errors(msg):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-matrix_rank-tol",
|
||||
"The tol argument for linalg.matrix_rank is deprecated."):
|
||||
jnp.linalg.matrix_rank(x, tol=1E-2)
|
||||
|
||||
@jtu.sample_product(
|
||||
|
@ -27,7 +27,6 @@ import scipy.stats
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax.test_util import check_grads
|
||||
@ -531,12 +530,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
|
||||
def testOneHotNonInteger(self):
|
||||
def assert_warns_or_errors(msg):
|
||||
if deprecations.is_accelerated("jax-nn-one-hot-float-input"):
|
||||
return self.assertRaisesRegex(ValueError, msg)
|
||||
else:
|
||||
return self.assertWarnsRegex(DeprecationWarning, msg)
|
||||
with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"):
|
||||
with self.assertDeprecationWarnsOrRaises("jax-nn-one-hot-float-input",
|
||||
"jax.nn.one_hot input should be integer-typed"):
|
||||
nn.one_hot(jnp.array([1.0]), 3)
|
||||
|
||||
def testTanhExists(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user