Remove code that disabled tests on "stream_executor" backends.

These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.

Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.

PiperOrigin-RevId: 565367453
This commit is contained in:
Peter Hawkins 2023-09-14 07:52:07 -07:00 committed by jax authors
parent 8acf597eba
commit bbfba9ace8
6 changed files with 2 additions and 152 deletions

View File

@ -1047,6 +1047,7 @@ jax_test(
"gpu",
"cpu",
],
tags = ["multiaccelerator"],
)
jax_test(
@ -1061,6 +1062,7 @@ jax_test(
jax_test(
name = "python_callback_test",
srcs = ["python_callback_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],

View File

@ -25,7 +25,6 @@ from jax import config
from jax.experimental import pjit
from jax._src import debugger
from jax._src import test_util as jtu
from jax._src import xla_bridge
import jax.numpy as jnp
import numpy as np
@ -98,9 +97,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertEqual(stdout.getvalue(), expected)
def test_debugger_can_print_value_in_jit(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p x", "c"])
@jax.jit
@ -117,9 +113,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertEqual(stdout.getvalue(), expected)
def test_debugger_can_print_multiple_values(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"])
@jax.jit
@ -136,9 +129,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertEqual(stdout.getvalue(), expected)
def test_debugger_can_print_context(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["l", "c"])
@jax.jit
@ -161,9 +151,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_debugger_can_print_backtrace(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["bt", "c"])
@jax.jit
@ -180,9 +167,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_debugger_can_work_with_multiple_stack_frames(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"])
def f(x):
@ -221,9 +205,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_can_use_multiple_breakpoints(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])
def f(x):
@ -249,9 +230,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertEqual(stdout.getvalue(), expected)
def test_debugger_works_with_vmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])
# On TPU, the breakpoints can be reordered inside of vmap but can be fixed
@ -281,9 +259,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertEqual(stdout.getvalue(), expected)
def test_debugger_works_with_pmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.local_device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
@ -309,9 +284,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_debugger_works_with_pjit(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.default_backend() != "tpu":
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
@ -361,9 +333,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_debugger_accesses_globals(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])
@jax.jit
@ -379,8 +348,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_can_limit_num_frames(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"])
def g():
@ -425,9 +392,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
self.assertRegex(stdout.getvalue(), expected)
def test_can_handle_dictionaries_with_unsortable_keys(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict",
"p weirder_dict", "c"])

View File

@ -27,7 +27,6 @@ from jax._src import ad_checkpoint
from jax._src import debugging
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import xla_bridge
import jax.numpy as jnp
import numpy as np
@ -91,9 +90,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\ny: 3\n")
def test_can_stage_out_debug_print(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
debug_print('x: {x}', x=x)
@ -103,9 +99,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
def test_can_stage_out_debug_print_with_donate_argnums(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")
@ -119,9 +112,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
def test_can_stage_out_ordered_print(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
debug_print('x: {x}', x=x, ordered=True)
@ -131,9 +121,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
def test_can_stage_out_ordered_print_with_donate_argnums(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")
@ -147,9 +134,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
def test_can_stage_out_prints_with_donate_argnums(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.default_backend() not in {"gpu", "tpu"}:
raise unittest.SkipTest("Donate argnums not supported.")
@ -164,9 +148,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\nx: 2\n")
def test_can_double_stage_out_ordered_print(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
@jax.jit
def f(x):
@ -177,9 +158,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), "x: 2\n")
def test_can_stage_out_ordered_print_with_pytree(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
struct = dict(foo=x)
@ -190,8 +168,6 @@ class DebugPrintTest(jtu.JaxTestCase):
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
def test_debug_print_should_use_default_layout(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
data = np.array(
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14],
@ -437,9 +413,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)
def test_debug_print_in_staged_out_custom_jvp(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
@jax.custom_jvp
@ -464,9 +437,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
self.assertEqual(output(), "goodbye: 2.0 3.0\n")
def test_debug_print_in_staged_out_custom_vjp(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
@jax.jit
def f(x):
@jax.custom_vjp
@ -508,9 +478,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_scan(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(xs):
def _body(carry, x):
debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)
@ -528,9 +495,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_for_loop(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def _body(i, x):
debug_print("i: {i}", i=i, ordered=ordered)
@ -559,9 +523,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_while_loop_body(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def _cond(x):
return x < 10
@ -582,9 +543,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_while_loop_cond(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def _cond(x):
debug_print("x: {x}", x=x, ordered=ordered)
@ -614,9 +572,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_in_batched_while_cond(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def _cond(x):
debug_print("x: {x}", x=x, ordered=ordered)
@ -674,9 +629,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_cond(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def true_fun(x):
debug_print("true: {}", x, ordered=ordered)
@ -700,9 +652,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
@jtu.sample_product(ordered=[False, True])
def test_can_print_inside_switch(self, ordered):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def b1(x):
debug_print("b1: {}", x, ordered=ordered)
@ -752,9 +701,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
f(jnp.arange(jax.local_device_count()))
def test_unordered_print_works_in_pmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
@ -777,9 +723,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")
def test_unordered_print_with_pjit(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
debug_print("{}", x, ordered=False)
return x
@ -805,10 +748,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
self.assertEqual(output(), "140\n")
def test_nested_pjit_debug_print(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise self.skipTest(
'Host callback not supported for runtime type: stream_executor.')
def f(x):
debug_print("{}", x)
return x
@ -819,9 +758,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")
def test_unordered_print_of_pjit_of_while(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def cond(carry):
i, *_ = carry
@ -848,9 +784,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
"[ 4 5 6 7 8 9 10 11]\n")
def test_unordered_print_of_pjit_of_xmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
def foo(x):
idx = lax.axis_index('foo')
@ -872,9 +805,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
self._assertLinesEqual(output(), "\n".join(lines))
def test_unordered_print_with_xmap(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def f(x):
debug_print("{}", x, ordered=False)
f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
@ -887,9 +817,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
self._assertLinesEqual(output(), "".join(lines))
def test_unordered_print_works_in_pmap_of_while(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")

View File

@ -33,7 +33,6 @@ from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
import numpy as np
config.parse_flags_with_absl()
@ -593,11 +592,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
class EffectOrderingTest(jtu.JaxTestCase):
def test_can_execute_python_callback(self):
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
log = []
def log_value(x):
log.append(x)
@ -616,11 +610,6 @@ class EffectOrderingTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
@ -704,11 +693,6 @@ class ParallelEffectsTest(jtu.JaxTestCase):
jax.pmap(f)(jnp.arange(jax.local_device_count()))
def test_can_pmap_unordered_callback(self):
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")

View File

@ -1453,8 +1453,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
)
def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
mesh_axis_names):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
@ -1471,8 +1469,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
self.assertArraysEqual(out._value, input_data)
def test_xla_arr_sharding_mismatch(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
@ -1496,8 +1492,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
compiled(arr)
def test_gda_auto_shardings_len(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
@ -1518,8 +1512,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
)
def test_jit_arr_partial_auto_sharding_array(
self, mesh_shape, mesh_axis_names, pspec):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
@ -1555,9 +1547,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
@unittest.skip('The error is not raised yet. Enable this back once we raise '
'the error in pjit again.')
def test_pjit_array_error(self):
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(

View File

@ -25,7 +25,6 @@ from jax._src import core
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax.experimental import maps
from jax.experimental import pjit
@ -71,11 +70,6 @@ with_pure_and_io_callbacks = parameterized.named_parameters(
class PythonCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
@ -494,11 +488,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
class PureCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
@ -875,11 +864,6 @@ class PureCallbackTest(jtu.JaxTestCase):
class IOCallbackTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_bridge.get_backend().runtime_type == 'stream_executor':
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()