[IREE] Allow backend selection via a flag.

Avoid eagerly creating NumPy arrays for IREE buffers.
This commit is contained in:
Peter Hawkins 2022-06-02 16:09:28 -04:00
parent bc877faae0
commit a7e041fa4e

View File

@ -22,14 +22,23 @@ using IREE to compile and run JAX computations instead of XLA.
from __future__ import annotations
import os
import platform
from typing import Any, List, Sequence, Optional
import iree.compiler
from iree import runtime as iree_runtime
import iree.runtime
from jax._src.config import flags
from jax._src.lib import xla_client
import numpy as np
FLAGS = flags.FLAGS
flags.DEFINE_string(
'jax_iree_backend', os.getenv('JAX_IREE_BACKEND', 'dylib'),
'IREE compiler backend to use.')
class IreeDevice:
@ -56,20 +65,20 @@ class IreeDevice:
class IreeBuffer(xla_client.DeviceArrayBase):
def __init__(self, client, device, npy_value):
def __init__(self, client, device, buffer):
self.client = client
self._device = device
assert device is not None
self._npy_value = np.asarray(npy_value)
self._buffer = buffer
def copy_to_device(self, device):
return self
def to_py(self) -> np.ndarray:
return self._npy_value
return np.asarray(self._buffer)
def to_iree(self):
return self._npy_value
return self._buffer
def platform(self):
return self.client.platform
@ -82,7 +91,7 @@ class IreeBuffer(xla_client.DeviceArrayBase):
# overrides repr on base class which expects _value and aval attributes
def __repr__(self):
return f'IreeBuffer({self._npy_value})'
return f'IreeBuffer({self.to_py()})'
class IreeExecutable:
@ -113,12 +122,13 @@ class IreeClient:
def __init__(self,
*,
compile_target_backends: Sequence[str] = ("cpu",),
runtime_driver: str = "dylib"):
runtime_driver: str = None):
self.platform = "iree"
self.platform_version = "0.0.1"
self.runtime_type = "iree"
self.iree_config = iree_runtime.system_api.Config(runtime_driver)
self.runtime_driver = (FLAGS.jax_iree_backend if runtime_driver is None
else runtime_driver)
self.iree_config = iree.runtime.system_api.Config(self.runtime_driver)
self._devices = [IreeDevice(self)]
def process_index(self) -> int:
@ -147,14 +157,18 @@ class IreeClient:
def compile(self, computation: str,
compile_options: xla_client.CompileOptions) -> IreeExecutable:
del compile_options # Ignored.
extra_args = []
# extra_args=["--mlir-print-ir-after-all"]
if platform.system() == "Darwin" and platform.machine() == "arm64":
extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin21.5.0"]
iree_binary = iree.compiler.compile_str(
computation, target_backends=["dylib"], input_type="mhlo",
# extra_args=["--mlir-print-ir-after-all"],
computation, target_backends=[self.runtime_driver], input_type="mhlo",
# extended_diagnostics=True,
extra_args=extra_args,
)
# Load it into the runtime.
vm_module = iree_runtime.VmModule.from_flatbuffer(iree_binary)
module_object = iree_runtime.load_vm_module(vm_module, self.iree_config)
vm_module = iree.runtime.VmModule.from_flatbuffer(iree_binary)
module_object = iree.runtime.load_vm_module(vm_module, self.iree_config)
return IreeExecutable(self, self._devices, module_object, "main")
def buffer_from_pyval(