mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

The main changes here are only indirectly related to gather: we just had to update some other rules (e.g. for comparison, and squeeze) for a simple dynamic-batch-shape gather to work. I also skipped two tests and deleted some old dynamic shape slicing logic because we want to handle that differently. We didn't have to do that removal in this PR, but it's just convenient given I'm looking at indexing again.
215 lines
6.0 KiB
Python
215 lines
6.0 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Experimental IREE backend for JAX.
|
|
|
|
This backend is quite incomplete, but exists to allow experimenting with
|
|
using IREE to compile and run JAX computations instead of XLA.
|
|
"""
|
|
|
|
# pytype: skip-file
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import platform
|
|
from typing import Any, List, Sequence, Optional
|
|
|
|
import iree.compiler
|
|
import iree.runtime
|
|
|
|
from jax._src.config import flags
|
|
from jax._src.lib import xla_client
|
|
import numpy as np
|
|
|
|
FLAGS = flags.FLAGS

# Selects the IREE backend. The default is read from the JAX_IREE_BACKEND
# environment variable at import time and falls back to 'cpu' when unset.
flags.DEFINE_string(
    'jax_iree_backend',
    os.getenv('JAX_IREE_BACKEND', 'cpu'),
    'IREE compiler backend to use.')
|
|
|
|
# Maps a backend name (the value of `jax_iree_backend`) to the target backend
# string handed to the IREE compiler.
iree_compiler_map = {
    "cpu": "llvm-cpu",
    "cuda": "cuda",
    "vmvx": "vmvx",
    "vulkan": "vulkan-spirv",
}
|
|
|
|
# Maps a backend name (the value of `jax_iree_backend`) to the driver name
# used to configure the IREE runtime. Note that both "cpu" and "vmvx" run on
# the "local-task" driver.
iree_runtime_map = {
    "cpu": "local-task",
    "cuda": "cuda",
    "vmvx": "local-task",
    "vulkan": "vulkan",
}
|
|
|
|
class IreeDevice:
  """Device object handed out by an IREE client.

  All identifying attributes are fixed constants; the only per-instance state
  is the back-reference to the owning client. Infeed, outfeed and live-buffer
  queries are not implemented.
  """

  def __init__(self, client):
    """Store `client` and fill in the constant device attributes."""
    self.id = 0
    self.host_id = 0
    self.process_index = 0
    self.platform = "iree"
    self.device_kind = "IREE device"
    self.client = client

  def __str__(self) -> str:
    return "IreeDevice"

  def transfer_to_infeed(self, literal: Any):
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError("transfer_to_infeed")

  def transfer_from_outfeed(self, shape: xla_client.Shape):
    """Not supported; always raises NotImplementedError."""
    # The message names this method (it previously said "transfer_to_outfeed").
    raise NotImplementedError("transfer_from_outfeed")

  def live_buffers(self) -> List[IreeBuffer]:
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError("live_buffers")
|
|
|
|
|
|
class IreeBuffer(xla_client.DeviceArrayBase):
  """Array wrapper around a buffer object produced by or passed to IREE.

  Every operation here is synchronous: copies and readiness checks simply
  return `self`.
  """

  def __init__(self, client, device, buffer):
    """Store the owning client, the device and the wrapped buffer.

    `device` must not be None.
    """
    self.client = client
    self._device = device
    assert device is not None
    self._buffer = buffer

  def copy_to_device(self, device):
    """Return `self`; the `device` argument is ignored."""
    return self

  def __array__(self, dtype=None, context=None):
    """Return the wrapped buffer as a NumPy array.

    `dtype`, when given, is forwarded to `np.asarray` so the result has the
    requested dtype. `context` is accepted but unused.
    """
    return np.asarray(self._buffer, dtype=dtype)

  def to_iree(self):
    """Return the wrapped buffer object unchanged."""
    return self._buffer

  def platform(self):
    """Return the platform attribute of the owning client."""
    return self.client.platform

  def device(self):
    """Return the device this buffer was created with."""
    return self._device

  def block_until_ready(self) -> IreeBuffer:
    return self  # no async

  # overrides repr on base class which expects _value and aval attributes
  def __repr__(self):
    return f'IreeBuffer({np.asarray(self)})'

  @property
  def _value(self):
    return np.asarray(self)

  def copy_to_host_async(self):
    """Return `self`; no asynchronous copy is started."""
    return self
|
|
|
|
class IreeExecutable:
  """A loaded IREE module together with the name of the function to invoke."""

  def __init__(self, client, devices, module_object, function_name):
    """Record the client, devices, loaded module and entry-point name.

    `module_object` is indexed with `function_name` at execution time to
    obtain the callable.
    """
    self.client = client
    self.traceback = None
    self.fingerprint = None
    self._devices = devices
    self.module_object = module_object
    self.function_name = function_name

  def local_devices(self) -> List[IreeDevice]:
    """Return the devices this executable was created with."""
    return self._devices

  def execute(self, arguments: Sequence[IreeBuffer]) -> List[IreeBuffer]:
    """Call the entry point on `arguments` and wrap each result.

    Each argument is unwrapped with `to_iree()` before the call. Every result
    is wrapped in an `IreeBuffer` attached to the first device.
    """
    inputs = [arg.to_iree() for arg in arguments]
    outputs = self.module_object[self.function_name](*inputs)
    # TODO(phawkins): Have a way to just have it always return the list,
    # regardless of arity.
    if not isinstance(outputs, (list, tuple)):
      # A single result comes back bare; normalise it to a one-element list.
      outputs = [outputs]
    return [
        IreeBuffer(self.client, self._devices[0], output) for output in outputs
    ]
|
|
|
|
|
|
class IreeClient:
  """Client that compiles computations with IREE and runs them on one device."""

  def __init__(self,
               *,
               iree_backend: Optional[str] = None):
    """Configure the compiler and runtime drivers for the chosen backend.

    Args:
      iree_backend: backend name, a key of `iree_compiler_map` and
        `iree_runtime_map`. When None, the `jax_iree_backend` flag is used.

    Raises:
      KeyError: if the backend name is not present in the driver maps.
    """
    self.platform = "iree"
    self.platform_version = "0.0.1"
    self.runtime_type = "iree"
    self.iree_backend = (FLAGS.jax_iree_backend if iree_backend is None
                         else iree_backend)
    self.compiler_driver = iree_compiler_map[self.iree_backend]
    self.runtime_driver = iree_runtime_map[self.iree_backend]
    self.iree_config = iree.runtime.system_api.Config(self.runtime_driver)
    # Exactly one device is created per client.
    self._devices = [IreeDevice(self)]

  def process_index(self) -> int:
    return 0

  def device_count(self) -> int:
    return len(self._devices)

  def devices(self) -> List[IreeDevice]:
    return self._devices

  def local_devices(self) -> List[IreeDevice]:
    return self._devices

  def local_device_count(self) -> int:
    return len(self._devices)

  def get_default_device_assignment(
      self,
      num_replicas: int) -> List[IreeDevice]:
    """Return a one-element list holding the first device.

    Raises:
      NotImplementedError: if `num_replicas` is anything other than 1.
    """
    if num_replicas != 1:
      raise NotImplementedError("Only single-device computations implemented")
    return [self._devices[0]]

  def compile(self, computation: str,
              compile_options: xla_client.CompileOptions) -> IreeExecutable:
    """Compile `computation` with IREE and load it into the runtime.

    Args:
      computation: program text, passed to the IREE compiler with
        `input_type="mhlo"`.
      compile_options: accepted for interface compatibility and ignored.

    Returns:
      An `IreeExecutable` whose entry point is the function named "main".
    """
    del compile_options  # Ignored.
    extra_args = []
    # extra_args=["--mlir-print-ir-after-all"]
    if platform.system() == "Darwin" and platform.machine() == "arm64":
      # On Apple Silicon hosts an explicit LLVM target triple is passed.
      extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin21.5.0"]
    iree_binary = iree.compiler.compile_str(
        computation, target_backends=[self.compiler_driver], input_type="mhlo",
        # extended_diagnostics=True,
        extra_args=extra_args,
    )
    # Load it into the runtime.
    vm_module = iree.runtime.VmModule.from_flatbuffer(
        self.iree_config.vm_instance, iree_binary)
    module_object = iree.runtime.load_vm_module(vm_module, self.iree_config)
    return IreeExecutable(self, self._devices, module_object, "main")

  def buffer_from_pyval(
      self,
      argument: Any,
      device: Optional[IreeDevice],
      force_copy: bool = True,
      host_buffer_semantics: xla_client.HostBufferSemantics = xla_client
      .HostBufferSemantics.ZERO_COPY
  ) -> IreeBuffer:
    """Wrap a copy of `argument` in an `IreeBuffer`.

    Args:
      argument: the value to copy; it is converted with `np.array(..., copy=True)`.
      device: target device. When None, the first device is used and
        `argument` must be exactly an `np.ndarray`.
      force_copy: accepted but unused; a copy is always made.
      host_buffer_semantics: accepted but unused.
    """
    # TODO(phawkins): IREE's python API will accept a numpy array directly but
    # may want to explicitly construct a lower level BufferView to avoid copies.
    if device is None:
      assert type(argument) is np.ndarray
      device = self._devices[0]
    return IreeBuffer(self, device, np.array(argument, copy=True))
|
|
|
|
|
|
def iree_client_factory():
  """Return a new `IreeClient` built with its default arguments.

  With no explicit backend, the client takes it from the `jax_iree_backend`
  flag.
  """
  return IreeClient()
|