mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
adopt jax.extend.backend
This commit is contained in:
parent
18f48bd52a
commit
1aa32f51ee
@ -5600,11 +5600,16 @@ class ReportedIssuesTests(jtu.JaxTestCase):
|
||||
|
||||
@staticmethod
|
||||
def compile_and_exec(module, args, run_on_cpu=False):
|
||||
backend = jax.lib.xla_bridge.get_backend('METAL')
|
||||
from jax.extend.backend import get_backend
|
||||
backend = get_backend('METAL')
|
||||
if run_on_cpu:
|
||||
backend = jax.lib.xla_bridge.get_backend('cpu')
|
||||
backend = get_backend('cpu')
|
||||
executable = backend.compile(module)
|
||||
return executable.execute(args)
|
||||
def put(arg):
|
||||
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
|
||||
arguments = [put(arg) for arg in args]
|
||||
outputs = executable.execute(arguments)
|
||||
return [np.asarray(x) for x in outputs]
|
||||
|
||||
@staticmethod
|
||||
def jax_metal_supported(target_ver):
|
||||
|
Loading…
x
Reference in New Issue
Block a user