adopt jax.extend.backend

This commit is contained in:
Shuhan Ding 2024-10-07 12:27:35 -07:00
parent 18f48bd52a
commit 1aa32f51ee

View File

@ -5600,11 +5600,16 @@ class ReportedIssuesTests(jtu.JaxTestCase):
@staticmethod
def compile_and_exec(module, args, run_on_cpu=False):
backend = jax.lib.xla_bridge.get_backend('METAL')
from jax.extend.backend import get_backend
backend = get_backend('METAL')
if run_on_cpu:
backend = jax.lib.xla_bridge.get_backend('cpu')
backend = get_backend('cpu')
executable = backend.compile(module)
return executable.execute(args)
def put(arg):
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
arguments = [put(arg) for arg in args]
outputs = executable.execute(arguments)
return [np.asarray(x) for x in outputs]
@staticmethod
def jax_metal_supported(target_ver):