mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #21834 from jakevdp:jit-warning
PiperOrigin-RevId: 643050911
This commit is contained in:
commit
a123470810
@ -54,7 +54,6 @@ filterwarnings = [
|
||||
"error",
|
||||
"default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
|
||||
"default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
|
||||
"default:backend and device argument on jit is deprecated.*:DeprecationWarning",
|
||||
"default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
|
||||
# TODO(jakevdp): remove when array_api_tests stabilize
|
||||
# start array_api_tests-related warnings
|
||||
|
@ -229,7 +229,9 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
def test_jit_device(self):
|
||||
device = jax.devices()[-1]
|
||||
x = jit(lambda x: x, device=device)(3.)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
x = jit(lambda x: x, device=device)(3.)
|
||||
_check_instance(self, x)
|
||||
self.assertEqual(x.devices(), {device})
|
||||
|
||||
|
@ -244,14 +244,16 @@ class HostCallbackImportsTest(jtu.JaxTestCase):
|
||||
class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
|
||||
# skipping here skips teardown, so do this before super().setUp().
|
||||
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
||||
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest("host_callback not implemented in PJRT C API")
|
||||
|
||||
super().setUp()
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="backend and device argument"))
|
||||
testing_stream.reset()
|
||||
testing_stream._test_method_name = self._testMethodName
|
||||
self.old_flags = os.getenv("XLA_FLAGS", "")
|
||||
@ -2040,13 +2042,16 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"""Tests for hcb.call"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
|
||||
# skipping here skips teardown, so do this before super().setUp().
|
||||
if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1:
|
||||
raise SkipTest("host_callback broken on multi-GPU platforms (#6447)")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise SkipTest("host_callback not implemented in PJRT C API")
|
||||
super().setUp()
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="The host_callback APIs are deprecated"))
|
||||
self.enter_context(jtu.ignore_warning(
|
||||
category=DeprecationWarning, message="backend and device argument"))
|
||||
|
||||
testing_stream.reset()
|
||||
testing_stream._test_method_name = self._testMethodName
|
||||
|
@ -105,6 +105,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
self.assertRaises(ValueError, lambda: fun(x, y))
|
||||
|
||||
@jtu.sample_product(backend=['cpu', 'gpu', 'tpu'])
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument")
|
||||
def testGpuMultiBackendOpByOpReturn(self, backend):
|
||||
if backend not in ('cpu', jtu.device_under_test()):
|
||||
raise SkipTest("Backend is not CPU or the device under test")
|
||||
@ -119,6 +121,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
self.assertEqual(list(w.devices())[0].platform, backend)
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument")
|
||||
def testJitCpu(self):
|
||||
@partial(jax.jit, backend='cpu')
|
||||
def get_arr(scale):
|
||||
@ -135,6 +139,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
self.assertEqual(c.devices(), {jax.devices('cpu')[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test can only fail with non-cpu backends
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument")
|
||||
def test_closed_over_values_device_placement(self):
|
||||
# see https://github.com/google/jax/issues/1431
|
||||
def f(): return jnp.add(3., 4.)
|
||||
@ -144,6 +150,8 @@ class MultiBackendTest(jtu.JaxTestCase):
|
||||
{jax.devices('cpu')[0]})
|
||||
|
||||
@jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument")
|
||||
def test_jit_on_nondefault_backend(self):
|
||||
cpus = jax.devices("cpu")
|
||||
self.assertNotEmpty(cpus)
|
||||
|
@ -2464,6 +2464,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"):
|
||||
my_nested_pjit(committed_inp, committed_inp, committed_inp)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument")
|
||||
def test_jit_device_with_sharding_constraint_error(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@ -2889,7 +2891,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
f = pjit(mul, device=jax.devices()[1])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
f = pjit(mul, device=jax.devices()[1])
|
||||
x = jnp.arange(8).reshape(4, 2)
|
||||
f_out = f(x)
|
||||
f_out2 = f(f_out)
|
||||
@ -2904,7 +2908,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
|
||||
h = pjit(mul, device=jax.devices()[-1])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
h = pjit(mul, device=jax.devices()[-1])
|
||||
h_out = h(y)
|
||||
cache_info3 = pjit_lib._pjit_lower_cached.cache_info()
|
||||
_check(h_out, jax.devices()[-1], y)
|
||||
@ -2924,7 +2930,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = pjit(lambda x: x * 2)(y)
|
||||
|
||||
expected_device = jax.devices()[2]
|
||||
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
final_out = pjit(lambda x: x * 3, device=expected_device)(out)
|
||||
|
||||
self.assertEqual(final_out.devices(), {expected_device})
|
||||
self.assertLen(final_out.sharding.device_set, 1)
|
||||
@ -2938,7 +2946,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, expected_out)
|
||||
|
||||
x = jnp.arange(8)
|
||||
g = pjit(lambda x: x, backend='tpu')
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
g = pjit(lambda x: x, backend='tpu')
|
||||
g_out = g(x)
|
||||
_check(g_out, jax.devices()[0], x)
|
||||
|
||||
@ -2951,8 +2961,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.skipTest('Test requires more >1 device.')
|
||||
# Add a constant captured by the nested pjit to make things more complicated
|
||||
h = jnp.arange(4.)
|
||||
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
|
||||
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
f = pjit(lambda x: x.sum(1) * h.sum(), device=jax.devices()[1])
|
||||
g = pjit(lambda x: f(jnp.sin(x * 4 + 2)), device=jax.devices()[1])
|
||||
jtu.check_grads(g, (jnp.arange(16.).reshape((4, 4)) / 100,), order=2)
|
||||
|
||||
def test_pjit_device_backend_axis_resources_error(self):
|
||||
@ -2961,13 +2973,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
'If backend or device is specified on jit, then '
|
||||
'in_shardings should not be specified.'):
|
||||
pjit(lambda x: x, in_shardings=s, backend='cpu')
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x, in_shardings=s, backend='cpu')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'If backend or device is specified on jit, then '
|
||||
'out_shardings should not be specified.'):
|
||||
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x, out_shardings=s, device=jax.devices()[0])
|
||||
|
||||
def test_check_arg_error(self):
|
||||
sds = jax.ShapeDtypeStruct((4, 2), np.int32)
|
||||
@ -2982,7 +2998,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
def test_pjit_device_backend_both_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "can't specify both a device and a backend for jit"):
|
||||
pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x, device=jax.devices()[0], backend='cpu')
|
||||
|
||||
def test_pjit_mesh_with_device_or_backend_error(self):
|
||||
mesh = jtu.create_global_mesh((1,), ('x',))
|
||||
@ -2991,7 +3009,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
"Mesh context manager should not be used with jit when backend or "
|
||||
"device is also specified as an argument to jit."):
|
||||
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x, device=jax.devices()[0])(jnp.arange(8))
|
||||
|
||||
def test_pjit_inline(self):
|
||||
@partial(pjit, inline=False)
|
||||
@ -3208,7 +3228,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
|
||||
self.assertEqual(count[0], 10)
|
||||
|
||||
pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
|
||||
@ -3217,7 +3239,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
pf(inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pf1(inp)
|
||||
@ -3632,7 +3656,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(x.sharding, y.sharding)
|
||||
|
||||
def test_vmap_pjit_single_device(self):
|
||||
jf = pjit(lambda x: x, device=jax.devices()[0])
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
jf = pjit(lambda x: x, device=jax.devices()[0])
|
||||
out = jax.vmap(jf)(jnp.ones((3,))) # doesn't crash
|
||||
self.assertIsInstance(out.sharding, SingleDeviceSharding)
|
||||
|
||||
@ -3649,8 +3675,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.devices(), {jax.devices()[0]})
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
|
||||
out2 = jax.jit(identity, device=jax.devices()[0])(
|
||||
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
out2 = jax.jit(identity, device=jax.devices()[0])(
|
||||
jax.device_put(np_inp, NamedSharding(mesh, P('x'))))
|
||||
self.assertEqual(out2.devices(), {jax.devices()[0]})
|
||||
self.assertArraysEqual(out2, np_inp)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user