Merge pull request #21834 from jakevdp:jit-warning

PiperOrigin-RevId: 643050911
This commit is contained in:
jax authors 2024-06-13 10:46:47 -07:00
commit a123470810
5 changed files with 66 additions and 24 deletions

View File

@ -54,7 +54,6 @@ filterwarnings = [
"error",
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
"default:backend and device argument on jit is deprecated.*:DeprecationWarning",
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
# TODO(jakevdp): remove when array_api_tests stabilize
# start array_api_tests-related warnings

View File

@ -229,7 +229,9 @@ class JitTest(jtu.BufferDonationTestCase):
def test_jit_device(self):
device = jax.devices()[-1]
x = jit(lambda x: x, device=device)(3.)
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
x = jit(lambda x: x, device=device)(3.)
_check_instance(self, x)
self.assertEqual(x.devices(), {device})

View File

@ -244,14 +244,16 @@ class HostCallbackImportsTest(jtu.JaxTestCase):
class HostCallbackTapTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
# skipping here skips teardown, so do this before super().setUp().
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
if xla_bridge.using_pjrt_c_api():
raise SkipTest("host_callback not implemented in PJRT C API")
super().setUp()
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="backend and device argument"))
testing_stream.reset()
testing_stream._test_method_name = self._testMethodName
self.old_flags = os.getenv("XLA_FLAGS", "")
@ -2040,13 +2042,16 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"""Tests for hcb.call"""
def setUp(self):
super().setUp()
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
# skipping here skips teardown, so do this before super().setUp().
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
if xla_bridge.using_pjrt_c_api():
raise SkipTest("host_callback not implemented in PJRT C API")
super().setUp()
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
self.enter_context(jtu.ignore_warning(
category=DeprecationWarning, message="backend and device argument"))
testing_stream.reset()
testing_stream._test_method_name = self._testMethodName

View File

@ -105,6 +105,8 @@ class MultiBackendTest(jtu.JaxTestCase):
self.assertRaises(ValueError, lambda: fun(x, y))
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu'])
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def testGpuMultiBackendOpByOpReturn(self, backend):
if backend not in ('cpu', jtu.device_under_test()):
raise SkipTest("Backend is not CPU or the device under test")
@ -119,6 +121,8 @@ class MultiBackendTest(jtu.JaxTestCase):
self.assertEqual(list(w.devices())[0].platform, backend)
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def testJitCpu(self):
@partial(jax.jit, backend='cpu')
def get_arr(scale):
@ -135,6 +139,8 @@ class MultiBackendTest(jtu.JaxTestCase):
self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def test_closed_over_values_device_placement(self):
# see https://github.com/google/jax/issues/1431
def f(): return jnp.add(3., 4.)
@ -144,6 +150,8 @@ class MultiBackendTest(jtu.JaxTestCase):
{jax.devices('cpu')[0]})
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def test_jit_on_nondefault_backend(self):
cpus = jax.devices("cpu")
self.assertNotEmpty(cpus)

View File

@ -2464,6 +2464,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"):
my_nested_pjit(committed_inp, committed_inp, committed_inp)
@jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument")
def test_jit_device_with_sharding_constraint_error(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -2889,7 +2891,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
f = pjit(mul, device=jax.devices()[1])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
f = pjit(mul, device=jax.devices()[1])
x = jnp.arange(8).reshape(4, 2)
f_out = f(x)
f_out2 = f(f_out)
@ -2904,7 +2908,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
h = pjit(mul, device=jax.devices()[-1])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
h = pjit(mul, device=jax.devices()[-1])
h_out = h(y)
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
_check(h_out, jax.devices()[-1], y)
@ -2924,7 +2930,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = pjit(lambda x: x * 2)(y)
expected_device = jax.devices()[2]
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
self.assertEqual(final_out.devices(), {expected_device})
self.assertLen(final_out.sharding.device_set, 1)
@ -2938,7 +2946,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(out, expected_out)
x = jnp.arange(8)
g = pjit(lambda x: x, backend='tpu')
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
g = pjit(lambda x: x, backend='tpu')
g_out = g(x)
_check(g_out, jax.devices()[0], x)
@ -2951,8 +2961,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.skipTest('Test requires more >1 device.')
# Add a constant captured by the nested pjit to make things more complicated
h = jnp.arange(4.)
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
def test_pjit_device_backend_axis_resources_error(self):
@ -2961,13 +2973,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
ValueError,
'If backend or device is specified on jit, then '
'in_shardings should not be specified.'):
pjit(lambda x: x, in_shardings=s, backend='cpu')
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x, in_shardings=s, backend='cpu')
with self.assertRaisesRegex(
ValueError,
'If backend or device is specified on jit, then '
'out_shardings should not be specified.'):
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
def test_check_arg_error(self):
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
@ -2982,7 +2998,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_pjit_device_backend_both_error(self):
with self.assertRaisesRegex(
ValueError, "can't specify both a device and a backend for jit"):
pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
def test_pjit_mesh_with_device_or_backend_error(self):
mesh = jtu.create_global_mesh((1,), ('x',))
@ -2991,7 +3009,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
ValueError,
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit."):
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
def test_pjit_inline(self):
@partial(pjit, inline=False)
@ -3208,7 +3228,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
self.assertEqual(count[0], 10)
pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
@ -3217,7 +3239,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
pf(inp)
self.assertEqual(count[0], 1)
pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf1(inp)
@ -3632,7 +3656,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertNotEqual(x.sharding, y.sharding)
def test_vmap_pjit_single_device(self):
jf = pjit(lambda x: x, device=jax.devices()[0])
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
jf = pjit(lambda x: x, device=jax.devices()[0])
out = jax.vmap(jf)(jnp.ones((3,))) # doesn't crash
self.assertIsInstance(out.sharding, SingleDeviceSharding)
@ -3649,8 +3675,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out.devices(), {jax.devices()[0]})
self.assertArraysEqual(out, np_inp)
out2 = jax.jit(identity, device=jax.devices()[0])(
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
out2 = jax.jit(identity, device=jax.devices()[0])(
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
self.assertEqual(out2.devices(), {jax.devices()[0]})
self.assertArraysEqual(out2, np_inp)