[internal] add deprecation test utilities

This commit is contained in:
Jake VanderPlas 2025-01-09 15:11:20 -08:00
parent 5d0ee43222
commit 1ee015674f
7 changed files with 53 additions and 68 deletions

View File

@ -44,6 +44,7 @@ from jax import lax
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src import lib as _jaxlib
@ -1310,6 +1311,16 @@ class JaxTestCase(parameterized.TestCase):
def rng(self):
return self._rng
def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str):
"""Assert warning or error, depending on deprecation state.
For use with functions that call :func:`jax._src.deprecations.warn`.
"""
if deprecations.is_accelerated(deprecation_id):
return self.assertRaisesRegex(ValueError, message)
else:
return self.assertWarnsRegex(DeprecationWarning, message)
def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='',
allow_object_dtype=False, verbose=True):
"""Assert that x and y arrays are exactly equal."""

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from absl.testing import absltest
from jax._src import deprecations
from jax._src import test_util as jtu
@ -20,6 +22,14 @@ from jax._src.internal_test_util import deprecation_module as m
class DeprecationTest(absltest.TestCase):
@contextlib.contextmanager
def deprecation_context(self, deprecation_id):
deprecations.register(deprecation_id)
try:
yield
finally:
deprecations.unregister(deprecation_id)
def testModuleDeprecation(self):
with test_warning_util.raise_on_warnings():
self.assertEqual(m.x, 42)
@ -36,13 +46,10 @@ class DeprecationTest(absltest.TestCase):
def testNamedDeprecation(self):
some_unique_id = "some-unique-id"
try:
deprecations.register(some_unique_id)
with self.deprecation_context(some_unique_id):
self.assertFalse(deprecations.is_accelerated(some_unique_id))
deprecations.accelerate(some_unique_id)
self.assertTrue(deprecations.is_accelerated(some_unique_id))
finally:
deprecations.unregister(some_unique_id)
msg = f"deprecation_id={some_unique_id!r} not registered"
with self.assertRaisesRegex(ValueError, msg):
@ -52,6 +59,19 @@ class DeprecationTest(absltest.TestCase):
with self.assertRaisesRegex(ValueError, msg):
deprecations.unregister(some_unique_id)
def testNamedDeprecationWarns(self):
deprecation_id = "some-unique-id"
deprecation_message = "This API is deprecated."
with self.deprecation_context(deprecation_id):
self.assertFalse(deprecations.is_accelerated(deprecation_id))
with self.assertWarnsRegex(DeprecationWarning, deprecation_message):
deprecations.warn(deprecation_id, deprecation_message, stacklevel=1)
deprecations.accelerate(deprecation_id)
self.assertTrue(deprecations.is_accelerated(deprecation_id))
with self.assertRaisesRegex(ValueError, deprecation_message):
deprecations.warn(deprecation_id, deprecation_message, stacklevel=1)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -28,7 +28,6 @@ import jax.sharding as shd
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
@ -210,12 +209,8 @@ class FfiTest(jtu.JaxTestCase):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg):
jax.jit(fun).lower(jnp.ones(5))
def test_input_output_aliases(self):
def fun(x):

View File

@ -27,7 +27,6 @@ import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
@ -454,12 +453,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where",
f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where, initial=jnp.array(0, dtype=dtype))
@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
@ -468,12 +463,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
x = jnp.zeros((10,), dtype)
where = jnp.ones(10, dtype=int)
func = getattr(jnp, rec.name)
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
with self.assertDeprecationWarnsOrRaises("jax-numpy-reduction-non-boolean-where",
f"jnp.{rec.name}: where must be None or a boolean array"):
func(x, where=where)
@parameterized.parameters(itertools.chain.from_iterable(
@ -756,13 +747,8 @@ class JaxNumpyReducerTests(jtu.JaxTestCase):
)
def testQuantileDeprecatedArgs(self, op):
func = getattr(jnp, op)
msg = f"The interpolation= argument to '{op}' is deprecated. "
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-quantile-interpolation"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(msg):
with self.assertDeprecationWarnsOrRaises("jax-numpy-quantile-interpolation",
f"The interpolation= argument to '{op}' is deprecated. "):
func(jnp.arange(4), 0.5, interpolation='linear')
@unittest.skipIf(not config.enable_x64.value, "test requires X64")

View File

@ -47,7 +47,6 @@ from jax.test_util import check_grads
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
@ -1061,13 +1060,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.clip(x, max=jnp.array([-1+5j]))
def testClipDeprecatedArgs(self):
msg = "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-clip-args"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors(msg):
with self.assertDeprecationWarnsOrRaises("jax-numpy-clip-args",
"Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"):
jnp.clip(jnp.arange(4), a_min=2, a_max=3)
def testHypotComplexInputError(self):
@ -4186,13 +4180,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testAstypeComplexDowncast(self):
x = jnp.array(2.0+1.5j, dtype='complex64')
msg = "Casting from complex to real dtypes.*"
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-astype-complex-to-real"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors():
with self.assertDeprecationWarnsOrRaises("jax-numpy-astype-complex-to-real",
"Casting from complex to real dtypes.*"):
x.astype('float32')
@parameterized.parameters('int2', 'int4')

View File

@ -32,7 +32,6 @@ from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src import config
from jax._src import deprecations
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -1177,14 +1176,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3)
def testPinvDeprecatedArgs(self):
msg = "The rcond argument for linalg.pinv is deprecated."
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-linalg-pinv-rcond"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
x = jnp.ones((3, 3))
with assert_warns_or_errors(msg):
with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-pinv-rcond",
"The rcond argument for linalg.pinv is deprecated."):
jnp.linalg.pinv(x, rcond=1E-2)
def testPinvGradIssue2792(self):
@ -1230,14 +1224,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
check_dtypes=False, rtol=1e-3)
def testMatrixRankDeprecatedArgs(self):
msg = "The tol argument for linalg.matrix_rank is deprecated."
def assert_warns_or_errors(msg=msg):
if deprecations.is_accelerated("jax-numpy-linalg-matrix_rank-tol"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
x = jnp.ones((3, 3))
with assert_warns_or_errors(msg):
with self.assertDeprecationWarnsOrRaises("jax-numpy-linalg-matrix_rank-tol",
"The tol argument for linalg.matrix_rank is deprecated."):
jnp.linalg.matrix_rank(x, tol=1E-2)
@jtu.sample_product(

View File

@ -27,7 +27,6 @@ import scipy.stats
from jax._src import ad_checkpoint
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
@ -531,12 +530,8 @@ class NNFunctionsTest(jtu.JaxTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
def testOneHotNonInteger(self):
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-nn-one-hot-float-input"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"):
with self.assertDeprecationWarnsOrRaises("jax-nn-one-hot-float-input",
"jax.nn.one_hot input should be integer-typed"):
nn.one_hot(jnp.array([1.0]), 3)
def testTanhExists(self):