From bbfba9ace8aada128db3824bd07a8ded18bbb437 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 14 Sep 2023 07:52:07 -0700 Subject: [PATCH] Remove code that disabled tests on "stream_executor" backends. These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more. Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices. PiperOrigin-RevId: 565367453 --- tests/BUILD | 2 + tests/debugger_test.py | 36 --------------- tests/debugging_primitives_test.py | 73 ------------------------------ tests/jaxpr_effects_test.py | 16 ------- tests/pjit_test.py | 11 ----- tests/python_callback_test.py | 16 ------- 6 files changed, 2 insertions(+), 152 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index c86dd6c2d..ff10b9b2b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1047,6 +1047,7 @@ jax_test( "gpu", "cpu", ], + tags = ["multiaccelerator"], ) jax_test( @@ -1061,6 +1062,7 @@ jax_test( jax_test( name = "python_callback_test", srcs = ["python_callback_test.py"], + tags = ["multiaccelerator"], deps = [ "//jax:experimental", ], diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 1ac11b04c..31800864f 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -25,7 +25,6 @@ from jax import config from jax.experimental import pjit from jax._src import debugger from jax._src import test_util as jtu -from jax._src import xla_bridge import jax.numpy as jnp import numpy as np @@ -98,9 +97,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertEqual(stdout.getvalue(), expected) def test_debugger_can_print_value_in_jit(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p x", "c"]) @jax.jit @@ -117,9 +113,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertEqual(stdout.getvalue(), expected) def test_debugger_can_print_multiple_values(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"]) @jax.jit @@ -136,9 +129,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertEqual(stdout.getvalue(), expected) def test_debugger_can_print_context(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["l", "c"]) @jax.jit @@ -161,9 +151,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_debugger_can_print_backtrace(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["bt", "c"]) @jax.jit @@ -180,9 +167,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_debugger_can_work_with_multiple_stack_frames(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"]) def f(x): @@ -221,9 +205,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_can_use_multiple_breakpoints(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"]) def f(x): @@ -249,9 +230,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertEqual(stdout.getvalue(), expected) def test_debugger_works_with_vmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"]) # On TPU, the breakpoints can be reordered inside of vmap but can be fixed @@ -281,9 +259,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertEqual(stdout.getvalue(), expected) def test_debugger_works_with_pmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.local_device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") @@ -309,9 +284,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_debugger_works_with_pjit(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.default_backend() != "tpu": raise unittest.SkipTest("`pjit` doesn't work with CustomCall.") @@ -361,9 +333,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_debugger_accesses_globals(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p foo", "c"]) @jax.jit @@ -379,8 +348,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_can_limit_num_frames(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"]) def g(): @@ -425,9 +392,6 @@ class CliDebuggerTest(jtu.JaxTestCase): self.assertRegex(stdout.getvalue(), expected) def test_can_handle_dictionaries_with_unsortable_keys(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict", "p weirder_dict", "c"]) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 07516e6f1..89fe88899 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -27,7 +27,6 @@ from jax._src import ad_checkpoint from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu -from jax._src import xla_bridge import jax.numpy as jnp import numpy as np @@ -91,9 +90,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\ny: 3\n") def test_can_stage_out_debug_print(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): debug_print('x: {x}', x=x) @@ -103,9 +99,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\n") def test_can_stage_out_debug_print_with_donate_argnums(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.default_backend() not in {"gpu", "tpu"}: raise unittest.SkipTest("Donate argnums not supported.") @@ -119,9 +112,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\n") def test_can_stage_out_ordered_print(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): debug_print('x: {x}', x=x, ordered=True) @@ -131,9 +121,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\n") def test_can_stage_out_ordered_print_with_donate_argnums(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.default_backend() not in {"gpu", "tpu"}: raise unittest.SkipTest("Donate argnums not supported.") @@ -147,9 +134,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\n") def test_can_stage_out_prints_with_donate_argnums(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.default_backend() not in {"gpu", "tpu"}: raise unittest.SkipTest("Donate argnums not supported.") @@ -164,9 +148,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\nx: 2\n") def test_can_double_stage_out_ordered_print(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit @jax.jit def f(x): @@ -177,9 +158,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), "x: 2\n") def test_can_stage_out_ordered_print_with_pytree(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): struct = dict(foo=x) @@ -190,8 +168,6 @@ class DebugPrintTest(jtu.JaxTestCase): self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n") def test_debug_print_should_use_default_layout(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') data = np.array( [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14], @@ -437,9 +413,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase): self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2) def test_debug_print_in_staged_out_custom_jvp(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): @jax.custom_jvp @@ -464,9 +437,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase): self.assertEqual(output(), "goodbye: 2.0 3.0\n") def test_debug_print_in_staged_out_custom_vjp(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - @jax.jit def f(x): @jax.custom_vjp @@ -508,9 +478,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_scan(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(xs): def _body(carry, x): debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered) @@ -528,9 +495,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_for_loop(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def _body(i, x): debug_print("i: {i}", i=i, ordered=ordered) @@ -559,9 +523,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_while_loop_body(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def _cond(x): return x < 10 @@ -582,9 +543,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_while_loop_cond(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def _cond(x): debug_print("x: {x}", x=x, ordered=ordered) @@ -614,9 +572,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_in_batched_while_cond(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def _cond(x): debug_print("x: {x}", x=x, ordered=ordered) @@ -674,9 +629,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_cond(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def true_fun(x): debug_print("true: {}", x, ordered=ordered) @@ -700,9 +652,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase): @jtu.sample_product(ordered=[False, True]) def test_can_print_inside_switch(self, ordered): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def b1(x): debug_print("b1: {}", x, ordered=ordered) @@ -752,9 +701,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): f(jnp.arange(jax.local_device_count())) def test_unordered_print_works_in_pmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") @@ -777,9 +723,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n") def test_unordered_print_with_pjit(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): debug_print("{}", x, ordered=False) return x @@ -805,10 +748,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): self.assertEqual(output(), "140\n") def test_nested_pjit_debug_print(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise self.skipTest( - 'Host callback not supported for runtime type: stream_executor.') - def f(x): debug_print("{}", x) return x @@ -819,9 +758,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n") def test_unordered_print_of_pjit_of_while(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def cond(carry): i, *_ = carry @@ -848,9 +784,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): "[ 4 5 6 7 8 9 10 11]\n") def test_unordered_print_of_pjit_of_xmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): def foo(x): idx = lax.axis_index('foo') @@ -872,9 +805,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): self._assertLinesEqual(output(), "\n".join(lines)) def test_unordered_print_with_xmap(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def f(x): debug_print("{}", x, ordered=False) f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu', @@ -887,9 +817,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase): self._assertLinesEqual(output(), "".join(lines)) def test_unordered_print_works_in_pmap_of_while(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index d31e8a428..b6906a529 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -33,7 +33,6 @@ from jax._src import ad_checkpoint from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src import xla_bridge import numpy as np config.parse_flags_with_absl() @@ -593,11 +592,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase): class EffectOrderingTest(jtu.JaxTestCase): def test_can_execute_python_callback(self): - # TODO(sharadmv): enable this test on GPU and TPU when backends are - # supported - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - log = [] def log_value(x): log.append(x) @@ -616,11 +610,6 @@ class EffectOrderingTest(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_ordered_effect_remains_ordered_across_multiple_devices(self): - # TODO(sharadmv): enable this test on GPU and TPU when backends are - # supported - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") @@ -704,11 +693,6 @@ class ParallelEffectsTest(jtu.JaxTestCase): jax.pmap(f)(jnp.arange(jax.local_device_count())) def test_can_pmap_unordered_callback(self): - # TODO(sharadmv): enable this test on GPU and TPU when backends are - # supported - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - if jax.device_count() < 2: raise unittest.SkipTest("Test requires >= 2 devices.") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ed12e2fdf..db6f5d5c1 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1453,8 +1453,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase): ) def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, mesh_axis_names): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1471,8 +1469,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase): self.assertArraysEqual(out._value, input_data) def test_xla_arr_sharding_mismatch(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( @@ -1496,8 +1492,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase): compiled(arr) def test_gda_auto_shardings_len(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( @@ -1518,8 +1512,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase): ) def test_jit_arr_partial_auto_sharding_array( self, mesh_shape, mesh_axis_names, pspec): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( @@ -1555,9 +1547,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase): @unittest.skip('The error is not raised yet. Enable this back once we raise ' 'the error in pjit again.') def test_pjit_array_error(self): - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 15b1e9d7e..edf0c275f 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -25,7 +25,6 @@ from jax._src import core from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util -from jax._src import xla_bridge from jax._src.lib import xla_client from jax.experimental import maps from jax.experimental import pjit @@ -71,11 +70,6 @@ with_pure_and_io_callbacks = parameterized.named_parameters( class PythonCallbackTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() @@ -494,11 +488,6 @@ class PythonCallbackTest(jtu.JaxTestCase): class PureCallbackTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear() @@ -875,11 +864,6 @@ class PureCallbackTest(jtu.JaxTestCase): class IOCallbackTest(jtu.JaxTestCase): - def setUp(self): - super().setUp() - if xla_bridge.get_backend().runtime_type == 'stream_executor': - raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.') - def tearDown(self): super().tearDown() dispatch.runtime_tokens.clear()