Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Add a prototype IREE backend for JAX.
This is to support experimentation with the combination of JAX and IREE. Many things do not work yet.

PiperOrigin-RevId: 409980064
This commit is contained in:
parent 6fa860d5ac
commit 70b8a6a806
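For orientation, a minimal smoke test of the backend added below might look like the following sketch. It is illustrative only: it assumes the "iree" platform registered in this diff can be requested through jit's existing backend= argument, and, per the commit message, many programs will not compile yet.

# Hypothetical usage sketch (not part of this commit); requires the optional
# iree-compiler and iree-runtime Python packages to be installed.
import jax
import jax.numpy as jnp

def f(x):
  return x + 1.0

f_iree = jax.jit(f, backend="iree")  # request the experimental IREE backend
print(f_iree(jnp.float32(41.0)))     # many primitives may fail to lower yet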
.github/workflows/ci-build.yaml (vendored): 2 changed lines
@@ -132,4 +132,4 @@ jobs:
         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
       run: |
         pytest -n 1 --tb=short docs
-        pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py jax
+        pytest -n 1 --tb=short --doctest-modules --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py jax
jax/_src/iree.py (new file): 157 lines
@@ -0,0 +1,157 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Experimental IREE backend for JAX.
+
+This backend is quite incomplete, but exists to allow experimenting with
+using IREE to compile and run JAX computations instead of XLA.
+"""
+
+# pytype: skip-file
+
+from typing import Any, List, Sequence
+
+import iree.compiler
+from iree.compiler.api import driver as compiler_driver
+from iree import runtime as iree_runtime
+
+from jax._src.lib import xla_client
+import numpy as np
+
+
+class IreeDevice:
+
+  def __init__(self, client):
+    self.id = 0
+    self.host_id = 0
+    self.process_index = 0
+    self.platform = "iree"
+    self.device_kind = "IREE device"
+    self.client = client
+
+  def __str__(self) -> str:
+    return "IreeDevice"
+
+  def transfer_to_infeed(self, literal: Any):
+    raise NotImplementedError("transfer_to_infeed")
+
+  def transfer_from_outfeed(self, shape: xla_client.Shape):
+    raise NotImplementedError("transfer_from_outfeed")
+
+  def live_buffers(self) -> List['IreeBuffer']:
+    raise NotImplementedError("live_buffers")
+
+
+class IreeBuffer(xla_client.DeviceArrayBase):
+
+  def __init__(self, client, device, npy_value):
+    self.client = client
+    self._device = device
+    self._npy_value = np.asarray(npy_value)
+
+  def to_py(self) -> np.ndarray:
+    return self._npy_value
+
+  def to_iree(self):
+    return self._npy_value
+
+
+class IreeExecutable:
+
+  def __init__(self, client, devices, module_object, function_name):
+    self.client = client
+    self.traceback = None
+    self.fingerprint = None
+    self._devices = devices
+    self.module_object = module_object
+    self.function_name = function_name
+
+  def local_devices(self) -> List[IreeDevice]:
+    return self._devices
+
+  def execute(self, arguments: Sequence[IreeBuffer]) -> List[IreeBuffer]:
+    inputs = [arg.to_iree() for arg in arguments]
+    outputs = self.module_object[self.function_name](*inputs)
+    # TODO(phawkins): Have a way to just have it always return the list,
+    # regardless of arity.
+    if not isinstance(outputs, list):
+      outputs = [outputs]
+    return [
+        IreeBuffer(self.client, self._devices[0], output) for output in outputs
+    ]
+
+
+class IreeClient:
+
+  def __init__(self,
+               *,
+               compile_target_backends: Sequence[str] = ("cpu",),
+               runtime_driver: str = "dylib"):
+    self.platform = "iree"
+    self.platform_version = "0.0.1"
+    self.runtime_type = "iree"
+    self.iree_config = iree_runtime.system_api.Config(runtime_driver)
+    self.compiler_options = compiler_driver.CompilerOptions()
+    self.compiler_options.set_input_dialect_mhlo()
+    for target_backend in compile_target_backends:
+      self.compiler_options.add_target_backend(target_backend)
+    self._devices = [IreeDevice(self)]

+  def process_index(self) -> int:
+    return 0
+
+  def device_count(self) -> int:
+    return len(self._devices)
+
+  def devices(self) -> List[IreeDevice]:
+    return self._devices
+
+  def local_devices(self) -> List[IreeDevice]:
+    return self._devices
+
+  def local_device_count(self) -> int:
+    return len(self._devices)
+
+  def get_default_device_assignment(
+      self,
+      num_replicas: int,
+      num_partitions: int = 1) -> List[List[IreeDevice]]:
+    if num_replicas != 1 or num_partitions != 1:
+      raise NotImplementedError("Only single-device computations implemented")
+    return [[self._devices[0]]]
+
+  def compile(self, computation: str,
+              compile_options: xla_client.CompileOptions) -> IreeExecutable:
+    iree_binary = iree.compiler.compile_str(
+        computation, target_backends=["dylib"], input_type="mhlo")
+    # Load it into the runtime.
+    vm_module = iree_runtime.binding.VmModule.from_flatbuffer(iree_binary)
+    module_object = iree_runtime.load_vm_module(vm_module, self.iree_config)
+    return IreeExecutable(self, self._devices, module_object, "main")
+
+  def buffer_from_pyval(
+      self,
+      argument: Any,
+      device: IreeDevice,
+      force_copy: bool = True,
+      host_buffer_semantics: xla_client.HostBufferSemantics = xla_client
+      .HostBufferSemantics.ZERO_COPY
+  ) -> IreeBuffer:
+    # TODO(phawkins): IREE's python API will accept a numpy array directly, but
+    # we may want to explicitly construct a lower-level BufferView to avoid copies.
+    return IreeBuffer(self, device, np.array(argument, copy=True))
+
+
+def iree_client_factory():
+  return IreeClient()
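The client above can also be driven directly, outside JAX's dispatch machinery, which is handy for isolating IREE issues. A rough sketch, assuming an MHLO module in textual form with a public @main function; the module text and the None passed for the compile options are placeholders, not part of this commit.

# Illustrative sketch only; exercises IreeClient / IreeExecutable as defined above.
import numpy as np
from jax._src import iree as jax_iree

mhlo_text = "..."  # textual MHLO module exporting a public @main (contents elided)

client = jax_iree.IreeClient()   # "dylib" runtime driver, "cpu" compile target by default
executable = client.compile(mhlo_text, compile_options=None)  # options are unused above
arg = client.buffer_from_pyval(np.float32(1.0), client.devices()[0])
[result] = executable.execute([arg])
print(result.to_py())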
@@ -22,6 +22,7 @@ XLA. There are also a handful of related casting utilities.
 
 from functools import partial, lru_cache
 import os
+import threading
 from typing import Any, Dict, List, Optional, Tuple, Union
 import warnings
@@ -35,7 +36,13 @@ from . import tpu_driver_client
 from . import xla_client
 from jax._src import util, traceback_util
 import numpy as np
-import threading
 
+iree: Optional[Any]
+
+try:
+  import jax._src.iree as iree  # type: ignore
+except ModuleNotFoundError:
+  iree = None
 
 traceback_util.register_exclusion(__file__)
 
@@ -194,6 +201,9 @@ register_backend_factory('gpu', xla_client.make_gpu_client,
 register_backend_factory(
     'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
 
+if iree is not None:
+  register_backend_factory("iree", iree.iree_client_factory, priority=-100)
+
 
 def backends():
   global _backends
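Two details of this registration are worth noting: the backend only appears when the optional jax._src.iree import above succeeded, and the negative priority keeps CPU/GPU/TPU as the default choices, so IREE must be requested by name. A hedged illustration, assuming the usual platform-selection APIs apply to registered backends:

# Sketch: because of priority=-100 the IREE client never wins default backend
# selection; it has to be asked for explicitly.
import jax
print(jax.devices("iree"))  # -> [IreeDevice], provided the IREE packages are installed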
@@ -43,9 +43,8 @@ from jax.experimental.maps import mesh
 
 
 FLAGS = flags.FLAGS
-flags.DEFINE_enum(
+flags.DEFINE_string(
     'jax_test_dut', '',
-    enum_values=['', 'cpu', 'gpu', 'tpu'],
     help=
     'Describes the device under test in case special consideration is required.'
 )
@@ -405,6 +404,9 @@ def supported_dtypes():
   if device_under_test() == "tpu":
     types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
              np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
+  elif device_under_test() == "iree":
+    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
+             np.uint32, np.float32}
   else:
     types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
              np.uint8, np.uint16, np.uint32, np.uint64,
@@ -89,7 +89,7 @@ def split_dict(dct, names):
   assert not dct
   return lst
 
-def concatenate(xs: Iterable[Sequence[T]]) -> Sequence[T]:
+def concatenate(xs: Iterable[Sequence[T]]) -> List[T]:
   """Concatenates/flattens a list of lists."""
   return list(it.chain.from_iterable(xs))
@@ -261,11 +261,15 @@ class LoweringContext:
   axis_env: xla.AxisEnv
   name_stack: str
 
+  # Should function results be tupled?
+  tuple_results: bool
+
   def __init__(self, platform: str, axis_env: xla.AxisEnv, name_stack: str,
                context: Optional[ir.Context] = None,
                module: Optional[ir.Module] = None,
                ip: Optional[ir.InsertionPoint] = None,
-               symbol_table: Optional[ir.SymbolTable] = None):
+               symbol_table: Optional[ir.SymbolTable] = None,
+               tuple_results: bool = True):
     self.context = context or ir.Context()
     self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
     self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
@@ -273,6 +277,7 @@ class LoweringContext:
     self.platform = platform
     self.axis_env = axis_env
     self.name_stack = name_stack
+    self.tuple_results = tuple_results
     mhlo.register_mhlo_dialect(self.context)
     chlo.register_chlo_dialect(self.context)
@@ -308,7 +313,7 @@ def _flatten_lowering_ir_args(
 def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                        jaxpr: core.ClosedJaxpr, *,
                        tuple_arguments: bool = False,
-                       uniquify_name: bool = True) -> str:
+                       public: bool = False) -> str:
   """Lowers jaxpr and its callees to an IR function.
 
   Assumes that an MLIR context, location, and insertion point are set.
@@ -318,13 +323,21 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
   input_types = map(aval_to_ir_types, jaxpr.in_avals)
   output_types = map(aval_to_ir_types, jaxpr.out_avals)
   flat_input_types = util.flatten(input_types)
-  output_tuple_type = ir.TupleType.get_tuple(util.flatten(output_types))
+  flat_output_types = util.flatten(output_types)
+  if ctx.tuple_results:
+    output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
+    fn_output_types = [output_tuple_type]
+  else:
+    fn_output_types = flat_output_types
   if tuple_arguments:
     input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
-    ftype = ir.FunctionType.get([input_tuple_type], [output_tuple_type])
+    fn_input_types = [input_tuple_type]
   else:
-    ftype = ir.FunctionType.get(flat_input_types, [output_tuple_type])
+    fn_input_types = flat_input_types
+  ftype = ir.FunctionType.get(fn_input_types, fn_output_types)
   func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
+  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
+      "public" if public else "private")
   symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
   entry_block = func_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
@@ -340,7 +353,12 @@ def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
     out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                              jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                              *unflattened_args)
-    std.ReturnOp([mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).result])
+    flat_outputs = util.flatten(out_vals)
+    if ctx.tuple_results:
+      std.ReturnOp([mhlo.TupleOp(output_tuple_type, flat_outputs).result])
+    else:
+      std.ReturnOp(flat_outputs)
 
   return symbol_name
 
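The tuple_results plumbing above exists so that main can return its flattened results directly instead of a single mhlo.tuple, which is the calling convention the IREE path selects later in this diff. A sketch of how the pieces fit together, using the names from this change; it is meant to run inside the lowering module's own namespace (where LoweringContext, xla, ir, and core are in scope), and jaxpr/consts are assumed to come from tracing elsewhere.

# Sketch (module-level names as in this diff): lower a traced jaxpr with
# untupled results, as lower_xla_callable below does for the IREE runtime.
ctx = LoweringContext("iree", xla.AxisEnv(1, (), ()), "", tuple_results=False)
with ctx.context, ir.Location.unknown(ctx.context):
  lower_jaxpr_to_fun(ctx, "main", core.ClosedJaxpr(jaxpr, consts),
                     tuple_arguments=False, public=True)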
@@ -447,16 +465,24 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
   xla.check_backend_matches(backend, ctx.platform)
   output_types = map(aval_to_ir_types, avals_out)
   flat_output_types = util.flatten(output_types)
-  output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
+  if ctx.tuple_results:
+    output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
+    call_output_types = [output_tuple_type]
+  else:
+    call_output_types = flat_output_types
   sub_ctx = ctx.replace(
       name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
   symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                    core.ClosedJaxpr(call_jaxpr, ()))
-  call = std.CallOp([output_tuple_type],
+  call = std.CallOp(call_output_types,
                     ir.FlatSymbolRefAttr.get(symbol_name),
-                    _flatten_lowering_ir_args(args)).result
-  flat_results = [mhlo.GetTupleElementOp(typ, call, _i32_attr(i)).result
-                  for i, typ in enumerate(flat_output_types)]
+                    _flatten_lowering_ir_args(args))
+  if ctx.tuple_results:
+    flat_results = [
+        mhlo.GetTupleElementOp(typ, call.result, _i32_attr(i)).result
+        for i, typ in enumerate(flat_output_types)]
+  else:
+    flat_results = call.results
   return util.unflatten(flat_results, map(len, output_types))
 
 def _xla_call_lower(ctx, avals_in, avals_out, *args,
@@ -774,16 +800,21 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
         "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
         "extra data movement anyway, so maybe you don't want it after all).")
 
-  tuple_args = len(avals_in) > 100  # pass long arg lists as tuple for TPU
-
   ctx = LoweringContext(backend.platform if backend is not None else None,
                         xla.AxisEnv(nreps, (), ()), "")
+
+  backend = xb.get_backend(backend)
+  if backend.runtime_type == "iree":
+    tuple_args = False
+    ctx = ctx.replace(tuple_results=False)
+  else:
+    tuple_args = len(avals_in) > 100  # pass long arg lists as tuple for TPU
 
   with ctx.context, ir.Location.unknown(ctx.context):
     lower_jaxpr_to_fun(ctx, "main", core.ClosedJaxpr(jaxpr, consts),
-                       tuple_arguments=tuple_args, uniquify_name=False)
+                       tuple_arguments=tuple_args, public=True)
 
   assert not any(donated_invars), donated_invars
-  backend = xb.get_backend(backend)
   # TODO(b/203122001): implement buffer donation.
   # if backend.platform in ("gpu", "tpu"):
   #   donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
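The dispatch change keys off backend.runtime_type: the IreeClient added in this commit reports "iree", while XLA's clients report their own runtime identifiers, so existing backends keep the tupled conventions they had before. A small hedged illustration, using the xb alias from the surrounding module; the value printed for the CPU client is not established by this diff.

# Sketch: how the runtime_type check above distinguishes clients.
backend = xb.get_backend("iree")
assert backend.runtime_type == "iree"
print(xb.get_backend("cpu").runtime_type)  # an XLA runtime identifier, e.g. "stream_executor"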
@@ -794,9 +825,10 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars
   #     warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations)))
   ctx.module.operation.verify()
   output = io.StringIO()
-  ctx.module.operation.print(file=output, enable_debug_info=True,
+  ctx.module.operation.print(file=output, #enable_debug_info=True,
                              print_generic_op_form=False)
   module = output.getvalue()
+  # print("MLIR module to be compiled:")
+  # print(module)
   return XlaComputation(
       name, module, False, nreps, device, backend, tuple_args, avals_in,
@@ -1482,7 +1482,8 @@ class _DeviceArray(DeviceArray):  # type: ignore
     if config.jax_enable_checks:
       assert type(aval) is ShapedArray
       npy_value = self._value
-      assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape
+      assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape, (
+          aval, npy_value.shape, npy_value.dtype)
       assert (device is None) or device is device_buffer.device()
 
   def _check_if_deleted(self):