diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index e403daba5..5f1781c3b 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -5600,11 +5600,16 @@ class ReportedIssuesTests(jtu.JaxTestCase): @staticmethod def compile_and_exec(module, args, run_on_cpu=False): - backend = jax.lib.xla_bridge.get_backend('METAL') + from jax.extend.backend import get_backend + backend = get_backend('METAL') if run_on_cpu: - backend = jax.lib.xla_bridge.get_backend('cpu') + backend = get_backend('cpu') executable = backend.compile(module) - return executable.execute(args) + def put(arg): + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + arguments = [put(arg) for arg in args] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] @staticmethod def jax_metal_supported(target_ver):