Add a prototype IREE backend for JAX.

This is to support experimentation with the combination of JAX/IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
This commit is contained in:
Peter Hawkins 2021-11-15 07:56:34 -08:00 committed by jax authors
parent 6fa860d5ac
commit 70b8a6a806
8 changed files with 226 additions and 22 deletions

View File

@ -132,4 +132,4 @@ jobs:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
run: |
pytest -n 1 --tb=short docs
pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py jax
pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py jax

157
jax/_src/iree.py Normal file
View File

@ -0,0 +1,157 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental IREE backend for JAX.
This backend is quite incomplete, but exists to allow experimenting with
using IREE to compile and run JAX computations instead of XLA.
"""
# pytype: skip-file
from typing import Any, List, Sequence
import iree.compiler
from iree.compiler.api import driver as compiler_driver
from iree import runtime as iree_runtime
from jax._src.lib import xla_client
import numpy as np
class IreeDevice:
def __init__(self, client):
self.id = 0
self.host_id = 0
self.process_index = 0
self.platform = "iree"
self.device_kind = "IREE device"
self.client = client
def __str__(self) -> str:
return "IreeDevice"
def transfer_to_infeed(self, literal: Any):
raise NotImplementedError("transfer_to_infeed")
def transfer_from_outfeed(self, shape: xla_client.Shape):
raise NotImplementedError("transfer_to_outfeed")
def live_buffers(self) -> List['IreeBuffer']:
raise NotImplementedError("live_buffers")
class IreeBuffer(xla_client.DeviceArrayBase):
def __init__(self, client, device, npy_value):
self.client = client
self._device = device
self._npy_value = np.asarray(npy_value)
def to_py(self) -> np.ndarray:
return self._npy_value
def to_iree(self):
return self._npy_value
class IreeExecutable:
def __init__(self, client, devices, module_object, function_name):
self.client = client
self.traceback = None
self.fingerprint = None
self._devices = devices
self.module_object = module_object
self.function_name = function_name
def local_devices(self) -> List[IreeDevice]:
return self._devices
def execute(self, arguments: Sequence[IreeBuffer]) -> List[IreeBuffer]:
inputs = [arg.to_iree() for arg in arguments]
outputs = self.module_object[self.function_name](*inputs)
# TODO(phawkins): Have a way to just have it always return the list,
# regardless of arity.
if not isinstance(outputs, list):
outputs = [outputs]
return [
IreeBuffer(self.client, self._devices[0], output) for output in outputs
]
class IreeClient:
def __init__(self,
*,
compile_target_backends: Sequence[str] = ("cpu",),
runtime_driver: str = "dylib"):
self.platform = "iree"
self.platform_version = "0.0.1"
self.runtime_type = "iree"
self.iree_config = iree_runtime.system_api.Config(runtime_driver)
self.compiler_options = compiler_driver.CompilerOptions()
self.compiler_options.set_input_dialect_mhlo()
for target_backend in compile_target_backends:
self.compiler_options.add_target_backend(target_backend)
self._devices = [IreeDevice(self)]
def process_index(self) -> int:
return 0
def device_count(self) -> int:
return len(self._devices)
def devices(self) -> List[IreeDevice]:
return self._devices
def local_devices(self) -> List[IreeDevice]:
return self._devices
def local_device_count(self) -> int:
return len(self._devices)
def get_default_device_assignment(
self,
num_replicas: int,
num_partitions: int = 1) -> List[List[IreeDevice]]:
if num_replicas != 1 or num_partitions != 1:
raise NotImplementedError("Only single-device computations implemented")
return [[self._devices[0]]]
def compile(self, computation: str,
compile_options: xla_client.CompileOptions) -> IreeExecutable:
iree_binary = iree.compiler.compile_str(
computation, target_backends=["dylib"], input_type="mhlo")
# Load it into the runtime.
vm_module = iree_runtime.binding.VmModule.from_flatbuffer(iree_binary)
module_object = iree_runtime.load_vm_module(vm_module, self.iree_config)
return IreeExecutable(self, self._devices, module_object, "main")
def buffer_from_pyval(
self,
argument: Any,
device: IreeDevice,
force_copy: bool = True,
host_buffer_semantics: xla_client.HostBufferSemantics = xla_client
.HostBufferSemantics.ZERO_COPY
) -> IreeBuffer:
# TODO(phawkins): IREE's python API will accept a numpy array directly but
# may want to explicitly construct a lower level BufferView to avoid copies.
return IreeBuffer(self, device, np.array(argument, copy=True))
def iree_client_factory():
return IreeClient()

View File

@ -22,6 +22,7 @@ XLA. There are also a handful of related casting utilities.
from functools import partial, lru_cache
import os
import threading
from typing import Any, Dict, List, Optional, Tuple, Union
import warnings
@ -35,7 +36,13 @@ from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
import numpy as np
import threading
iree: Optional[Any]
try:
import jax._src.iree as iree # type: ignore
except ModuleNotFoundError:
iree = None
traceback_util.register_exclusion(__file__)
@ -194,6 +201,9 @@ register_backend_factory('gpu', xla_client.make_gpu_client,
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
if iree is not None:
register_backend_factory("iree", iree.iree_client_factory, priority=-100)
def backends():
global _backends

View File

@ -43,9 +43,8 @@ from jax.experimental.maps import mesh
FLAGS = flags.FLAGS
flags.DEFINE_enum(
flags.DEFINE_string(
'jax_test_dut', '',
enum_values=['', 'cpu', 'gpu', 'tpu'],
help=
'Describes the device under test in case special consideration is required.'
)
@ -405,6 +404,9 @@ def supported_dtypes():
if device_under_test() == "tpu":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
elif device_under_test() == "iree":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
np.uint32, np.float32}
else:
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,

View File

@ -89,7 +89,7 @@ def split_dict(dct, names):
assert not dct
return lst
def concatenate(xs: Iterable[Sequence[T]]) -> Sequence[T]:
def concatenate(xs: Iterable[Sequence[T]]) -> List[T]:
"""Concatenates/flattens a list of lists."""
return list(it.chain.from_iterable(xs))

View File

@ -261,11 +261,15 @@ class LoweringContext:
axis_env: xla.AxisEnv
name_stack: str
# Should function results be tupled?
tuple_results: bool
def __init__(self, platform: str, axis_env: xla.AxisEnv, name_stack: str,
context: Optional[ir.Context] = None,
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None):
symbol_table: Optional[ir.SymbolTable] = None,
tuple_results: bool = True):
self.context = context or ir.Context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
@ -273,6 +277,7 @@ class LoweringContext:
self.platform = platform
self.axis_env = axis_env
self.name_stack = name_stack
self.tuple_results = tuple_results
mhlo.register_mhlo_dialect(self.context)
chlo.register_chlo_dialect(self.context)
@ -308,7 +313,7 @@ def _flatten_lowering_ir_args(
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
jaxpr: core.ClosedJaxpr, *,
tuple_arguments: bool = False,
uniquify_name: bool = True) -> str:
public: bool = False) -> str:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
@ -318,13 +323,21 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
input_types = map(aval_to_ir_types, jaxpr.in_avals)
output_types = map(aval_to_ir_types, jaxpr.out_avals)
flat_input_types = util.flatten(input_types)
output_tuple_type = ir.TupleType.get_tuple(util.flatten(output_types))
flat_output_types = util.flatten(output_types)
if ctx.tuple_results:
output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
fn_output_types = [output_tuple_type]
else:
fn_output_types = flat_output_types
if tuple_arguments:
input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
ftype = ir.FunctionType.get([input_tuple_type], [output_tuple_type])
fn_input_types = [input_tuple_type]
else:
ftype = ir.FunctionType.get(flat_input_types, [output_tuple_type])
fn_input_types = flat_input_types
ftype = ir.FunctionType.get(fn_input_types, fn_output_types)
func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
@ -340,7 +353,12 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
*unflattened_args)
std.ReturnOp([mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).result])
flat_outputs = util.flatten(out_vals)
if ctx.tuple_results:
std.ReturnOp([mhlo.TupleOp(output_tuple_type, flat_outputs).result])
else:
std.ReturnOp(flat_outputs)
return symbol_name
@ -447,16 +465,24 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
xla.check_backend_matches(backend, ctx.platform)
output_types = map(aval_to_ir_types, avals_out)
flat_output_types = util.flatten(output_types)
output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
if ctx.tuple_results:
output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
call_output_types = [output_tuple_type]
else:
call_output_types = flat_output_types
sub_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
core.ClosedJaxpr(call_jaxpr, ()))
call = std.CallOp([output_tuple_type],
call = std.CallOp(call_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
_flatten_lowering_ir_args(args)).result
flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
for i, typ in enumerate(flat_output_types)]
_flatten_lowering_ir_args(args))
if ctx.tuple_results:
flat_results = [
mhlo.GetTupleElementOp(typ, call.result, _i32_attr(i)).result
for i, typ in enumerate(flat_output_types)]
else:
flat_results = call.results
return util.unflatten(flat_results, map(len, output_types))
def _xla_call_lower(ctx, avals_in, avals_out, *args,
@ -774,16 +800,21 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
"extra data movement anyway, so maybe you don't want it after all).")
tuple_args = len(avals_in) > 100 # pass long arg lists as tuple for TPU
ctx = LoweringContext(backend.platform if backend is not None else None,
xla.AxisEnv(nreps, (), ()), "")
backend = xb.get_backend(backend)
if backend.runtime_type == "iree":
tuple_args = False
ctx = ctx.replace(tuple_results=False)
else:
tuple_args = len(avals_in) > 100 # pass long arg lists as tuple for TPU
with ctx.context, ir.Location.unknown(ctx.context):
lower_jaxpr_to_fun(ctx, "main", core.ClosedJaxpr(jaxpr, consts),
tuple_arguments=tuple_args, uniquify_name=False)
tuple_arguments=tuple_args, public=True)
assert not any(donated_invars), donated_invars
backend = xb.get_backend(backend)
# TODO(b/203122001): implement buffer donation.
# if backend.platform in ("gpu", "tpu"):
# donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
@ -794,9 +825,10 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
# warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
ctx.module.operation.verify()
output = io.StringIO()
ctx.module.operation.print(file=output, enable_debug_info=True,
ctx.module.operation.print(file=output, #enable_debug_info=True,
print_generic_op_form=False)
module = output.getvalue()
# print("MLIR module to be compiled:")
# print(module)
return XlaComputation(
name, module, False, nreps, device, backend, tuple_args, avals_in,

View File

@ -1482,7 +1482,8 @@ class _DeviceArray(DeviceArray): # type: ignore
if config.jax_enable_checks:
assert type(aval) is ShapedArray
npy_value = self._value
assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape
assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape, (
aval, npy_value.shape, npy_value.dtype)
assert (device is None) or device is device_buffer.device()
def _check_if_deleted(self):

View File

@ -22,3 +22,5 @@ ignore_errors = True
ignore_missing_imports = True
[mypy-jaxlib.mlir.*]
ignore_missing_imports = True
[mypy-iree.*]
ignore_missing_imports = True