mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove code that disabled tests on "stream_executor" backends.
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more. Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices. PiperOrigin-RevId: 565367453
This commit is contained in:
parent
8acf597eba
commit
bbfba9ace8
@ -1047,6 +1047,7 @@ jax_test(
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -1061,6 +1062,7 @@ jax_test(
|
||||
jax_test(
|
||||
name = "python_callback_test",
|
||||
srcs = ["python_callback_test.py"],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
|
@ -25,7 +25,6 @@ from jax import config
|
||||
from jax.experimental import pjit
|
||||
from jax._src import debugger
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -98,9 +97,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_can_print_value_in_jit(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p x", "c"])
|
||||
|
||||
@jax.jit
|
||||
@ -117,9 +113,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_can_print_multiple_values(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"])
|
||||
|
||||
@jax.jit
|
||||
@ -136,9 +129,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_can_print_context(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["l", "c"])
|
||||
|
||||
@jax.jit
|
||||
@ -161,9 +151,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_can_print_backtrace(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["bt", "c"])
|
||||
|
||||
@jax.jit
|
||||
@ -180,9 +167,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_can_work_with_multiple_stack_frames(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"])
|
||||
|
||||
def f(x):
|
||||
@ -221,9 +205,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_can_use_multiple_breakpoints(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])
|
||||
|
||||
def f(x):
|
||||
@ -249,9 +230,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_works_with_vmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])
|
||||
|
||||
# On TPU, the breakpoints can be reordered inside of vmap but can be fixed
|
||||
@ -281,9 +259,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertEqual(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_works_with_pmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.local_device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
|
||||
@ -309,9 +284,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_works_with_pjit(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.default_backend() != "tpu":
|
||||
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
|
||||
|
||||
@ -361,9 +333,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_debugger_accesses_globals(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])
|
||||
|
||||
@jax.jit
|
||||
@ -379,8 +348,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_can_limit_num_frames(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
stdin, stdout = make_fake_stdin_stdout(["u", "p x", "c"])
|
||||
|
||||
def g():
|
||||
@ -425,9 +392,6 @@ class CliDebuggerTest(jtu.JaxTestCase):
|
||||
self.assertRegex(stdout.getvalue(), expected)
|
||||
|
||||
def test_can_handle_dictionaries_with_unsortable_keys(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
stdin, stdout = make_fake_stdin_stdout(["p x", "p weird_dict",
|
||||
"p weirder_dict", "c"])
|
||||
|
||||
|
@ -27,7 +27,6 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import debugging
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -91,9 +90,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\ny: 3\n")
|
||||
|
||||
def test_can_stage_out_debug_print(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
debug_print('x: {x}', x=x)
|
||||
@ -103,9 +99,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
def test_can_stage_out_debug_print_with_donate_argnums(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.default_backend() not in {"gpu", "tpu"}:
|
||||
raise unittest.SkipTest("Donate argnums not supported.")
|
||||
|
||||
@ -119,9 +112,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
def test_can_stage_out_ordered_print(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
debug_print('x: {x}', x=x, ordered=True)
|
||||
@ -131,9 +121,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
def test_can_stage_out_ordered_print_with_donate_argnums(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.default_backend() not in {"gpu", "tpu"}:
|
||||
raise unittest.SkipTest("Donate argnums not supported.")
|
||||
|
||||
@ -147,9 +134,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
def test_can_stage_out_prints_with_donate_argnums(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.default_backend() not in {"gpu", "tpu"}:
|
||||
raise unittest.SkipTest("Donate argnums not supported.")
|
||||
|
||||
@ -164,9 +148,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\nx: 2\n")
|
||||
|
||||
def test_can_double_stage_out_ordered_print(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@ -177,9 +158,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "x: 2\n")
|
||||
|
||||
def test_can_stage_out_ordered_print_with_pytree(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
struct = dict(foo=x)
|
||||
@ -190,8 +168,6 @@ class DebugPrintTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
|
||||
|
||||
def test_debug_print_should_use_default_layout(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
data = np.array(
|
||||
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14],
|
||||
@ -437,9 +413,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)
|
||||
|
||||
def test_debug_print_in_staged_out_custom_jvp(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@jax.custom_jvp
|
||||
@ -464,9 +437,6 @@ class DebugPrintTransformationTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "goodbye: 2.0 3.0\n")
|
||||
|
||||
def test_debug_print_in_staged_out_custom_vjp(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
@jax.custom_vjp
|
||||
@ -508,9 +478,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_scan(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(xs):
|
||||
def _body(carry, x):
|
||||
debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)
|
||||
@ -528,9 +495,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_for_loop(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def _body(i, x):
|
||||
debug_print("i: {i}", i=i, ordered=ordered)
|
||||
@ -559,9 +523,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_while_loop_body(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def _cond(x):
|
||||
return x < 10
|
||||
@ -582,9 +543,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_while_loop_cond(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def _cond(x):
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
@ -614,9 +572,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_in_batched_while_cond(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def _cond(x):
|
||||
debug_print("x: {x}", x=x, ordered=ordered)
|
||||
@ -674,9 +629,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_cond(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
debug_print("true: {}", x, ordered=ordered)
|
||||
@ -700,9 +652,6 @@ class DebugPrintControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(ordered=[False, True])
|
||||
def test_can_print_inside_switch(self, ordered):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def b1(x):
|
||||
debug_print("b1: {}", x, ordered=ordered)
|
||||
@ -752,9 +701,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
f(jnp.arange(jax.local_device_count()))
|
||||
|
||||
def test_unordered_print_works_in_pmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
|
||||
@ -777,9 +723,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")
|
||||
|
||||
def test_unordered_print_with_pjit(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
debug_print("{}", x, ordered=False)
|
||||
return x
|
||||
@ -805,10 +748,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "140\n")
|
||||
|
||||
def test_nested_pjit_debug_print(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise self.skipTest(
|
||||
'Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
debug_print("{}", x)
|
||||
return x
|
||||
@ -819,9 +758,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")
|
||||
|
||||
def test_unordered_print_of_pjit_of_while(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def cond(carry):
|
||||
i, *_ = carry
|
||||
@ -848,9 +784,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
"[ 4 5 6 7 8 9 10 11]\n")
|
||||
|
||||
def test_unordered_print_of_pjit_of_xmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
def foo(x):
|
||||
idx = lax.axis_index('foo')
|
||||
@ -872,9 +805,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
self._assertLinesEqual(output(), "\n".join(lines))
|
||||
|
||||
def test_unordered_print_with_xmap(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def f(x):
|
||||
debug_print("{}", x, ordered=False)
|
||||
f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
|
||||
@ -887,9 +817,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
self._assertLinesEqual(output(), "".join(lines))
|
||||
|
||||
def test_unordered_print_works_in_pmap_of_while(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
|
||||
|
@ -33,7 +33,6 @@ from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -593,11 +592,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
class EffectOrderingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_can_execute_python_callback(self):
|
||||
# TODO(sharadmv): enable this test on GPU and TPU when backends are
|
||||
# supported
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
log = []
|
||||
def log_value(x):
|
||||
log.append(x)
|
||||
@ -616,11 +610,6 @@ class EffectOrderingTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
|
||||
# TODO(sharadmv): enable this test on GPU and TPU when backends are
|
||||
# supported
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
|
||||
@ -704,11 +693,6 @@ class ParallelEffectsTest(jtu.JaxTestCase):
|
||||
jax.pmap(f)(jnp.arange(jax.local_device_count()))
|
||||
|
||||
def test_can_pmap_unordered_callback(self):
|
||||
# TODO(sharadmv): enable this test on GPU and TPU when backends are
|
||||
# supported
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
if jax.device_count() < 2:
|
||||
raise unittest.SkipTest("Test requires >= 2 devices.")
|
||||
|
||||
|
@ -1453,8 +1453,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape,
|
||||
mesh_axis_names):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
|
||||
input_data = np.arange(
|
||||
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
@ -1471,8 +1469,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
def test_xla_arr_sharding_mismatch(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (4, 2)
|
||||
input_data = np.arange(
|
||||
@ -1496,8 +1492,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
compiled(arr)
|
||||
|
||||
def test_gda_auto_shardings_len(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_input_shape = (4, 2)
|
||||
input_data = np.arange(
|
||||
@ -1518,8 +1512,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_jit_arr_partial_auto_sharding_array(
|
||||
self, mesh_shape, mesh_axis_names, pspec):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
|
||||
global_input_shape = (8, 4)
|
||||
input_data = np.arange(
|
||||
@ -1555,9 +1547,6 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
@unittest.skip('The error is not raised yet. Enable this back once we raise '
|
||||
'the error in pjit again.')
|
||||
def test_pjit_array_error(self):
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
|
||||
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
input_data = np.arange(
|
||||
|
@ -25,7 +25,6 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
@ -71,11 +70,6 @@ with_pure_and_io_callbacks = parameterized.named_parameters(
|
||||
|
||||
class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
@ -494,11 +488,6 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
class PureCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
@ -875,11 +864,6 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
class IOCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_bridge.get_backend().runtime_type == 'stream_executor':
|
||||
raise unittest.SkipTest('Host callback not supported for runtime type: stream_executor.')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
|
Loading…
x
Reference in New Issue
Block a user