mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34
PiperOrigin-RevId: 684104487
This commit is contained in:
commit
bcb0f6466a
@ -5600,11 +5600,16 @@ class ReportedIssuesTests(jtu.JaxTestCase):
|
||||
|
||||
@staticmethod
|
||||
def compile_and_exec(module, args, run_on_cpu=False):
|
||||
backend = jax.lib.xla_bridge.get_backend('METAL')
|
||||
from jax.extend.backend import get_backend
|
||||
backend = get_backend('METAL')
|
||||
if run_on_cpu:
|
||||
backend = jax.lib.xla_bridge.get_backend('cpu')
|
||||
backend = get_backend('cpu')
|
||||
executable = backend.compile(module)
|
||||
return executable.execute(args)
|
||||
def put(arg):
|
||||
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
|
||||
arguments = [put(arg) for arg in args]
|
||||
outputs = executable.execute(arguments)
|
||||
return [np.asarray(x) for x in outputs]
|
||||
|
||||
@staticmethod
|
||||
def jax_metal_supported(target_ver):
|
||||
|
Loading…
x
Reference in New Issue
Block a user