Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34

PiperOrigin-RevId: 684104487
This commit is contained in:
jax authors 2024-10-09 11:15:59 -07:00
commit bcb0f6466a

View File

@ -5600,11 +5600,16 @@ class ReportedIssuesTests(jtu.JaxTestCase):
@staticmethod
def compile_and_exec(module, args, run_on_cpu=False):
backend = jax.lib.xla_bridge.get_backend('METAL')
from jax.extend.backend import get_backend
backend = get_backend('METAL')
if run_on_cpu:
backend = jax.lib.xla_bridge.get_backend('cpu')
backend = get_backend('cpu')
executable = backend.compile(module)
return executable.execute(args)
def put(arg):
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
arguments = [put(arg) for arg in args]
outputs = executable.execute(arguments)
return [np.asarray(x) for x in outputs]
@staticmethod
def jax_metal_supported(target_ver):