Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[IREE] Allow backend selection via a flag.
Avoid eagerly creating NumPy arrays for IREE buffers.
This commit is contained in:
Parent: bc877faae0
Commit: a7e041fa4e
@@ -22,14 +22,23 @@ using IREE to compile and run JAX computations instead of XLA.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Any, List, Sequence, Optional
|
||||
|
||||
import iree.compiler
|
||||
from iree import runtime as iree_runtime
|
||||
import iree.runtime
|
||||
|
||||
from jax._src.config import flags
|
||||
from jax._src.lib import xla_client
|
||||
import numpy as np
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
flags.DEFINE_string(
|
||||
'jax_iree_backend', os.getenv('JAX_IREE_BACKEND', 'dylib'),
|
||||
'IREE compiler backend to use.')
|
||||
|
||||
class IreeDevice:
|
||||
|
||||
@@ -56,20 +65,20 @@ class IreeDevice:
|
||||
|
||||
class IreeBuffer(xla_client.DeviceArrayBase):
|
||||
|
||||
def __init__(self, client, device, npy_value):
|
||||
def __init__(self, client, device, buffer):
|
||||
self.client = client
|
||||
self._device = device
|
||||
assert device is not None
|
||||
self._npy_value = np.asarray(npy_value)
|
||||
self._buffer = buffer
|
||||
|
||||
def copy_to_device(self, device):
|
||||
return self
|
||||
|
||||
def to_py(self) -> np.ndarray:
|
||||
return self._npy_value
|
||||
return np.asarray(self._buffer)
|
||||
|
||||
def to_iree(self):
|
||||
return self._npy_value
|
||||
return self._buffer
|
||||
|
||||
def platform(self):
|
||||
return self.client.platform
|
||||
@@ -82,7 +91,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
|
||||
|
||||
# overrides repr on base class which expects _value and aval attributes
|
||||
def __repr__(self):
|
||||
return f'IreeBuffer({self._npy_value})'
|
||||
return f'IreeBuffer({self.to_py()})'
|
||||
|
||||
class IreeExecutable:
|
||||
|
||||
@@ -113,12 +122,13 @@ class IreeClient:
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
compile_target_backends: Sequence[str] = ("cpu",),
|
||||
runtime_driver: str = "dylib"):
|
||||
runtime_driver: str = None):
|
||||
self.platform = "iree"
|
||||
self.platform_version = "0.0.1"
|
||||
self.runtime_type = "iree"
|
||||
self.iree_config = iree_runtime.system_api.Config(runtime_driver)
|
||||
self.runtime_driver = (FLAGS.jax_iree_backend if runtime_driver is None
|
||||
else runtime_driver)
|
||||
self.iree_config = iree.runtime.system_api.Config(self.runtime_driver)
|
||||
self._devices = [IreeDevice(self)]
|
||||
|
||||
def process_index(self) -> int:
|
||||
@@ -147,14 +157,18 @@ class IreeClient:
|
||||
def compile(self, computation: str,
|
||||
compile_options: xla_client.CompileOptions) -> IreeExecutable:
|
||||
del compile_options # Ignored.
|
||||
extra_args = []
|
||||
# extra_args=["--mlir-print-ir-after-all"]
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin21.5.0"]
|
||||
iree_binary = iree.compiler.compile_str(
|
||||
computation, target_backends=["dylib"], input_type="mhlo",
|
||||
# extra_args=["--mlir-print-ir-after-all"],
|
||||
computation, target_backends=[self.runtime_driver], input_type="mhlo",
|
||||
# extended_diagnostics=True,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
# Load it into the runtime.
|
||||
vm_module = iree_runtime.VmModule.from_flatbuffer(iree_binary)
|
||||
module_object = iree_runtime.load_vm_module(vm_module, self.iree_config)
|
||||
vm_module = iree.runtime.VmModule.from_flatbuffer(iree_binary)
|
||||
module_object = iree.runtime.load_vm_module(vm_module, self.iree_config)
|
||||
return IreeExecutable(self, self._devices, module_object, "main")
|
||||
|
||||
def buffer_from_pyval(
|
||||
|
Loading…
x
Reference in New Issue
Block a user