Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib

PiperOrigin-RevId: 432236852
jax authors 2022-03-03 11:11:02 -08:00
commit cf9a900d78
58 changed files with 5176 additions and 1543 deletions

View File

@ -52,7 +52,12 @@ py_binary(
"//jaxlib:_cuda_prng",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocblas_kernels",
"//jaxlib:hip_gpu_support",
"//jaxlib:_hipblas",
"//jaxlib:_hipsolver",
"//jaxlib:_hipsparse",
"//jaxlib:_hip_linalg",
"//jaxlib:_hip_prng",
]),
deps = ["@bazel_tools//tools/python/runfiles"],
)

View File

@ -383,7 +383,7 @@ def main():
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx900,gfx906,gfx90",
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--rocm_path",
@ -510,6 +510,7 @@ def main():
config_args += ["--config=tpu"]
if args.enable_rocm:
config_args += ["--config=rocm"]
if not args.enable_nccl:
config_args += ["--config=nonccl"]
command = ([bazel_path] + args.bazel_startup_options +

View File

@ -197,11 +197,21 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cublas.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_linalg.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_prng.so"))
if r.Rlocation("__main__/jaxlib/_hipsolver.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsolver.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipblas.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_linalg.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_prng.so"))
if r.Rlocation("__main__/jaxlib/_cusolver.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cusolver.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cublas.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_linalg.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cuda_prng.pyd"))
if r.Rlocation("__main__/jaxlib/_hipsolver.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsolver.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipblas.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_linalg.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hip_prng.pyd"))
if r.Rlocation("__main__/jaxlib/cusolver.py") is not None:
libdevice_dir = os.path.join(jaxlib_dir, "cuda", "nvvm", "libdevice")
os.makedirs(libdevice_dir)
@ -210,12 +220,16 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_linalg.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
if r.Rlocation("__main__/jaxlib/rocblas_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocblas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/rocsolver.py"))
if r.Rlocation("__main__/jaxlib/hipsolver.py") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hipsolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hip_linalg.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hip_prng.py"))
if r.Rlocation("__main__/jaxlib/_cusparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_cusparse.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusparse.py"))
if r.Rlocation("__main__/jaxlib/_hipsparse.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_hipsparse.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/hipsparse.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
mlir_dir = os.path.join(jaxlib_dir, "mlir")
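The `Rlocation(...) is not None` guards above let one packaging script serve CPU, CUDA, and ROCm wheels: targets that were not built are simply absent from the runfiles tree. A minimal sketch of the pattern, assuming the `@bazel_tools` runfiles API this script depends on (the destination path is hypothetical):
```python
# Sketch of the conditional-copy pattern used by prepare_wheel (illustrative).
import shutil
from bazel_tools.tools.python.runfiles import runfiles

r = runfiles.Create()

def maybe_copy(runfiles_path, dst_dir):
  # Rlocation returns None when the artifact was not built into the runfiles
  # tree (e.g. _hipsolver.so in a CUDA-only build), so it can be skipped.
  src = r.Rlocation(runfiles_path)
  if src is not None:
    shutil.copy(src, dst_dir)

maybe_copy("__main__/jaxlib/_hipsolver.so", "/tmp/jaxlib")  # hypothetical destination
```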

View File

@ -0,0 +1,91 @@
FROM ubuntu:bionic
MAINTAINER Reza Rahimi <reza.rahimi@amd.com>
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.0/
ARG ROCM_BUILD_NAME=ubuntu
ARG ROCM_BUILD_NUM=main
ARG ROCM_PATH=/opt/rocm-5.0.0
ARG DEBIAN_FRONTEND=noninteractive
ENV HOME /root/
ENV ROCM_PATH=$ROCM_PATH
RUN apt-get --allow-unauthenticated update && apt install -y wget software-properties-common
RUN apt-get clean all
RUN wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
RUN /bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; then \
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
else \
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
fi'
RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
software-properties-common \
clang-6.0 \
clang-format-6.0 \
curl \
g++-multilib \
git \
vim \
libnuma-dev \
virtualenv \
python3-pip \
pciutils \
wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Add to get ppa
RUN apt-get update
RUN apt-get install -y software-properties-common
# Install rocm pkgs
RUN apt-get update --allow-insecure-repositories && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
rocm-dev rocm-libs rccl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Set up paths
ENV HCC_HOME=$ROCM_PATH/hcc
ENV HIP_PATH=$ROCM_PATH/hip
ENV OPENCL_ROOT=$ROCM_PATH/opencl
ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
ENV PATH="$ROCM_PATH/bin:${PATH}"
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
# Add target file to help determine which device(s) to build for
RUN bash -c 'echo -e "gfx900\ngfx906\ngfx908\ngfx90a\ngfx1030" >> ${ROCM_PATH}/bin/target.lst'
# Need to explicitly create the $ROCM_PATH/.info/version file to work around what seems to be a bazel bug
# The env vars being set via --action_env in .bazelrc and .tf_configure.bazelrc files are sometimes
# not getting set in the build command being spawned by bazel (in theory this should not happen)
# As a consequence ROCM_PATH is sometimes not set for the hipcc commands.
# When hipcc invokes hcc, it specifies $ROCM_PATH/.../include dirs via the `-isystem` options
# If ROCM_PATH is not set, it defaults to /opt/rocm, and as a consequence a dependency is generated on the
# header files included within `/opt/rocm`, which then leads to bazel dependency errors
# Explicitly creating the $ROCM_PATH/.info/version file allows the ROCm path to be set correctly, even when ROCM_PATH
# is not explicitly set, and thus avoids the eventual bazel dependency error.
# The bazel bug needs to be root-caused and addressed, but that is out of our control and may take a long time
# to come to fruition, so we implement this workaround to make do until then
# Filed https://github.com/bazelbuild/bazel/issues/11163 for tracking this
RUN touch ${ROCM_PATH}/.info/version
ENV PATH="/root/bin:/root/.local/bin:$PATH"
# Install python3.9
RUN add-apt-repository ppa:deadsnakes/ppa && \
apt update && \
apt install -y python3.9-dev \
python3-pip \
python3.9-distutils
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.6 1 && \
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2
RUN pip3 install --upgrade setuptools pip
RUN pip3 install absl-py numpy==1.19.5 scipy wheel six setuptools pytest pytest-rerunfailures

build/rocm/README.md Normal file
View File

@ -0,0 +1,23 @@
# JAX Builds on ROCm
This directory contains files and setup instructions to build and test JAX for ROCm in a Docker environment. You can build, test and run JAX on ROCm yourself!
***
### Build JAX-ROCm in docker
1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).
2. Build JAX by running the following command from the JAX root folder.
```
./build/rocm/ci_build.sh --keep_image bash -c "./build/rocm/build_rocm.sh"
```
3. Launch a container: If the build was successful, there should be a Docker image named "jax-rocm:latest" in the list of Docker images (use the `docker images` command to list them).
```
sudo docker run -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --entrypoint /bin/bash jax-rocm:latest
```
***
### Build and Test JAX-ROCm in docker (suitable for CI jobs)
This folder has all the scripts necessary to build and run tests for JAX-ROCm.
The following command will build JAX on ROCm and run all the tests inside Docker (the script should be called from the JAX root folder).
```
./build/rocm/ci_build.sh bash -c "./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py && ./build/rocm/run_multi_gpu.sh"
```
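Once inside the container, a quick sanity check that the ROCm backend is actually active (a minimal snippet, not part of the scripts above):
```python
# Run inside the jax-rocm container; devices should report the GPU platform.
import jax
print(jax.devices())                 # e.g. [GpuDevice(id=0)] on a working ROCm build
print(jax.numpy.ones((2, 2)).sum())  # exercises a trivial kernel on the device
```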

View File

@ -0,0 +1,70 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Common Bash functions used by build scripts
die() {
# Print a message and exit with code 1.
#
# Usage: die <error_message>
# e.g., die "Something bad happened."
echo "$@"
exit 1
}
realpath() {
# Get the real path of a file
# Usage: realpath <file_path>
if [[ "$#" != "1" ]]; then
die "realpath: incorrect usage"
fi
[[ "$1" = /* ]] && echo "$1" || echo "$PWD/${1#./}"
}
to_lower() {
# Convert the string to lower case
# Usage: to_lower <string>
echo "$1" | tr '[:upper:]' '[:lower:]'
}
calc_elapsed_time() {
# Calculate elapsed time. Takes nanosecond format input of the kind output
# by date +'%s%N'
#
# Usage: calc_elapsed_time <START_TIME> <END_TIME>
if [[ $# != "2" ]]; then
die "calc_elapsed_time: incorrect usage"
fi
START_TIME=$1
END_TIME=$2
if [[ ${START_TIME} == *"N" ]]; then
# Nanosecond precision not available
START_TIME=$(echo ${START_TIME} | sed -e 's/N//g')
END_TIME=$(echo ${END_TIME} | sed -e 's/N//g')
ELAPSED="$(expr ${END_TIME} - ${START_TIME}) s"
else
ELAPSED="$(expr $(expr ${END_TIME} - ${START_TIME}) / 1000000) ms"
fi
echo ${ELAPSED}
}

build/rocm/build_rocm.sh Executable file
View File

@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux
ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
ROCM_TF_FORK_BRANCH="develop-upstream"
rm -rf /tmp/tensorflow-upstream || true
git clone -b ${ROCM_TF_FORK_BRANCH} ${ROCM_TF_FORK_REPO} /tmp/tensorflow-upstream
python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream
pip3 install --use-feature=2020-resolver --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
pip3 install --use-feature=2020-resolver --force-reinstall . # installs jax

build/rocm/ci_build.sh Executable file
View File

@ -0,0 +1,119 @@
#!/usr/bin/env bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Usage: ci_build.sh [--dockerfile <DOCKERFILE_PATH> --keep_image]
# <COMMAND>
#
# DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build.
# If this optional value is not supplied (via the --dockerfile flag)
# Dockerfile.rocm (located in the same directory as this script)
# will be used.
# KEEP_IMAGE: (Optional) If this flag is set, the container will be committed as an image
#
# COMMAND: Command to be executed in the docker container
set -eux
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/build_common.sh"
CONTAINER_TYPE="rocm"
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.rocm"
DOCKER_CONTEXT_PATH="${SCRIPT_DIR}"
KEEP_IMAGE="--rm"
POSITIONAL_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--dockerfile)
DOCKERFILE_PATH="$2"
DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}")
shift 2
;;
--keep_image)
KEEP_IMAGE=""
shift 1
;;
*)
POSITIONAL_ARGS+=("$1")
shift
;;
esac
done
if [[ ! -f "${DOCKERFILE_PATH}" ]]; then
die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\""
fi
ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G"
# Helper function to traverse directories up until given file is found.
function upsearch (){
test / == "$PWD" && return || \
test -e "$1" && echo "$PWD" && return || \
cd .. && upsearch "$1"
}
# Set up WORKSPACE and BUILD_TAG. Jenkins will set them for you or we pick
# reasonable defaults if you run it outside of Jenkins.
WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
BUILD_TAG="${BUILD_TAG:-jax_ci}"
# Determine the docker image name
DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}"
# Under Jenkins matrix build, the build tag may contain characters such as
# commas (,) and equal signs (=), which are not valid inside docker image names.
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g')
# Convert to all lower-case, as per requirement of Docker image names
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]')
# Print arguments.
echo "WORKSPACE: ${WORKSPACE}"
echo "COMMAND: ${POSITIONAL_ARGS[*]}"
echo "BUILD_TAG: ${BUILD_TAG}"
echo " (docker container name will be ${DOCKER_IMG_NAME})"
echo ""
echo "Building container (${DOCKER_IMG_NAME})..."
docker build -t ${DOCKER_IMG_NAME} \
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
# Check docker build status
if [[ $? != "0" ]]; then
die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}"
fi
# Run the command inside the container.
echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..."
docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \
-v ${WORKSPACE}:/workspace \
-w /workspace \
${ROCM_EXTRA_PARAMS} \
"${DOCKER_IMG_NAME}" \
"${POSITIONAL_ARGS[@]}"
if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then
echo "Committing the docker container as jax-rocm"
docker stop ${DOCKER_IMG_NAME}
docker commit ${DOCKER_IMG_NAME} jax-rocm
docker rm ${DOCKER_IMG_NAME}
fi
echo "ROCm build was successful!"

build/rocm/run_multi_gpu.sh Executable file
View File

@ -0,0 +1,20 @@
#!/usr/bin/env bash
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux
# Run test modules with multi-GPU requirements. We currently do not have a way to filter tests;
# this issue is tracked in https://github.com/google/jax/issues/7323
python3 -m pytest --reruns 3 -x tests/pmap_test.py
python3 -m pytest --reruns 3 -x tests/multi_device_test.py

build/rocm/run_single_gpu.py Executable file
View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor
GPU_LOCK = threading.Lock()
LAST_CODE = 0
def run_shell_command(cmd, shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
result = subprocess.run(cmd,
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
# sys.exit(result.returncode)
return result.returncode, result.stderr.decode(), result.stdout.decode()
def collect_testmodules():
all_test_files = []
return_code, stderr, stdout = run_shell_command(
["python3", "-m", "pytest", "--collect-only", "tests"])
if return_code != 0:
print(stdout)
print(stderr)
print("Test module discovery failed.")
exit(return_code)
for line in stdout.split("\n"):
match = re.match("<Module (.*)>", line)
if match:
test_file = match.group(1)
all_test_files.append(test_file)
print("---------- collected test modules ----------")
print("Found %d test modules." % (len(all_test_files)))
print("\n".join(all_test_files))
print("--------------------------------------------")
return all_test_files
def run_test(testmodule, gpu_tokens):
global LAST_CODE
with GPU_LOCK:
if LAST_CODE != 0:
return
target_gpu = gpu_tokens.pop()
env_vars = {
"HIP_VISIBLE_DEVICES": str(target_gpu),
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
cmd = ["python3", "-m", "pytest", "--reruns", "3", "-x", testmodule]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)
if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout)
print(stderr)
LAST_CODE = return_code
return
def run_parallel(all_testmodules, p):
print("Running tests with parallelism=", p)
available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules
for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens)
# waiting for all modules to finish
executor.shutdown(wait=True) # wait for all jobs to finish
return
def find_num_gpus():
cmd = ["lspci|grep 'controller'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout)
def main(args):
all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel)
exit(LAST_CODE)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-p",
"--parallel",
type=int,
help="number of tests to run in parallel")
args = parser.parse_args()
if args.parallel is None:
sys_gpu_count = find_num_gpus()
args.parallel = sys_gpu_count
print("%d GPUs detected." % sys_gpu_count)
main(args)

View File

@ -38,7 +38,9 @@ from jax._src.lib import lapack
from jax._src.lib import cuda_linalg
from jax._src.lib import cusolver
from jax._src.lib import cusparse
from jax._src.lib import rocsolver
from jax._src.lib import hip_linalg
from jax._src.lib import hipsolver
from jax._src.lib import hipsparse
from jax._src.lib import xla_client
@ -372,10 +374,10 @@ if cusolver is not None:
partial(_cholesky_cpu_gpu_translation_rule, cusolver.potrf),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, rocsolver.potrf),
partial(_cholesky_cpu_gpu_translation_rule, hipsolver.potrf),
platform='gpu')
# Asymmetric eigendecomposition
@ -571,9 +573,9 @@ if cusolver is not None:
eigh_p, partial(_eigh_cpu_gpu_translation_rule, cusolver.syevd),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, rocsolver.syevd),
eigh_p, partial(_eigh_cpu_gpu_translation_rule, hipsolver.syevd),
platform='gpu')
@ -756,10 +758,10 @@ if cusolver is not None:
partial(_triangular_solve_gpu_translation_rule, cusolver.trsm),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, rocsolver.trsm),
partial(_triangular_solve_gpu_translation_rule, hipsolver.trsm),
platform='gpu')
# Support operation for LU decomposition: Transformation of the pivots returned
@ -835,8 +837,12 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
def _lu_pivots_to_permutation_gpu(ctx, avals_in, avals_out, pivots, *,
permutation_size):
if cuda_linalg:
return [cuda_linalg.lu_pivots_to_permutation(
ctx.builder, pivots, permutation_size=permutation_size)]
if hip_linalg:
return [hip_linalg.lu_pivots_to_permutation(
ctx.builder, pivots, permutation_size=permutation_size)]
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
lu_pivots_to_permutation_p.multiple_results = False
@ -856,6 +862,10 @@ if cuda_linalg:
_lu_pivots_to_permutation_gpu,
platform='gpu')
if hip_linalg:
xla.register_translation(lu_pivots_to_permutation_p,
_lu_pivots_to_permutation_gpu,
platform='gpu')
# LU decomposition
# Computes a pivoted LU decomposition such that
@ -1047,9 +1057,9 @@ if cusolver is not None:
lu_p, partial(_lu_cpu_gpu_translation_rule, cusolver.getrf),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, rocsolver.getrf),
lu_p, partial(_lu_cpu_gpu_translation_rule, hipsolver.getrf),
platform='gpu')
xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
@ -1215,13 +1225,12 @@ if cusolver is not None:
partial(_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, rocsolver.geqrf, rocsolver.orgqr),
partial(_qr_cpu_gpu_translation_rule, hipsolver.geqrf, hipsolver.orgqr),
platform='gpu')
# Singular value decomposition
def svd_impl(operand, full_matrices, compute_uv):
@ -1387,15 +1396,17 @@ if cusolver is not None:
svd_p, partial(_svd_cpu_gpu_translation_rule, cusolver.gesvd),
platform='gpu')
if rocsolver is not None:
if hipsolver is not None:
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, rocsolver.gesvd),
svd_p, partial(_svd_cpu_gpu_translation_rule, hipsolver.gesvd),
platform='gpu')
def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
b, *, m, n, ldb, t):
if cusparse:
return [cusparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
if hipsparse:
return [hipsparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
@ -1407,6 +1418,10 @@ if cusparse is not None and hasattr(cusparse, "gtsv2"):
xla.register_translation(tridiagonal_solve_p,
_tridiagonal_solve_gpu_translation_rule,
platform='gpu')
if hipsparse is not None and hasattr(hipsparse, "gtsv2"):
xla.register_translation(tridiagonal_solve_p,
_tridiagonal_solve_gpu_translation_rule,
platform='gpu')
def _tridiagonal_solve_jax(dl, d, du, b, **kw):
"""Pure JAX implementation of `tridiagonal_solve`."""

View File

@ -22,9 +22,9 @@ import warnings
from typing import Optional, Tuple
__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client',
'xla_extension',
'cuda_linalg', 'cuda_prng', 'cusolver', 'hip_linalg', 'hip_prng',
'hipsolver', 'jaxlib', 'lapack', 'pocketfft', 'pytree',
'tpu_driver_client', 'version', 'xla_client', 'xla_extension',
]
# Before attempting to import jaxlib, warn about experimental
@ -111,26 +111,41 @@ try:
except ImportError:
cusolver = None
try:
import jaxlib.hipsolver as hipsolver # pytype: disable=import-error
except ImportError:
hipsolver = None
try:
import jaxlib.cusparse as cusparse # pytype: disable=import-error
except ImportError:
cusparse = None
try:
import jaxlib.rocsolver as rocsolver # pytype: disable=import-error
import jaxlib.hipsparse as hipsparse # pytype: disable=import-error
except ImportError:
rocsolver = None
hipsparse = None
try:
import jaxlib.cuda_prng as cuda_prng # pytype: disable=import-error
except ImportError:
cuda_prng = None
try:
import jaxlib.hip_prng as hip_prng # pytype: disable=import-error
except ImportError:
hip_prng = None
try:
import jaxlib.cuda_linalg as cuda_linalg # pytype: disable=import-error
except ImportError:
cuda_linalg = None
try:
import jaxlib.hip_linalg as hip_linalg # pytype: disable=import-error
except ImportError:
hip_linalg = None
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
@ -148,6 +163,8 @@ try:
except:
tpu_driver_client = None # type: ignore
# TODO(rocm): check if we need the same for rocm.
cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):

View File

@ -33,6 +33,7 @@ from jax._src.lib import cuda_prng
from jax._src.numpy.lax_numpy import (
_canonicalize_tuple_index, _eliminate_deprecated_list_indexing,
_expand_bool_indices, _register_stackable)
from jax._src.lib import hip_prng
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod
@ -384,11 +385,18 @@ def _threefry2x32_gpu_translation_rule(ctx, avals_in, avals_out, k1, k2, x1,
def _broadcast(x, aval):
return xla_client.ops.BroadcastInDim(
x, aval_out.shape, tuple(range(rank - len(aval.shape), rank)))
if cuda_prng:
return xla.xla_destructure(
ctx.builder,
cuda_prng.threefry2x32(
ctx.builder, (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
else:
return xla.xla_destructure(
ctx.builder,
hip_prng.threefry2x32(
ctx.builder, (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval))))
threefry2x32_p = core.Primitive("threefry2x32")
@ -405,7 +413,9 @@ xla.register_translation(threefry2x32_p, xla.lower_fun(
if cuda_prng:
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
platform='gpu')
if hip_prng:
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
platform='gpu')
@partial(jit, inline=True)
def threefry_2x32(keypair, count):
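As with the linalg rules, the threefry lowering is backend-transparent: whichever of `cuda_prng`/`hip_prng` imported successfully services the custom call, and user-level random-number code is unchanged (illustrative):
```python
# threefry2x32_p lowers to the cuda_prng or hip_prng kernel when one is
# present, and to the pure-XLA fallback otherwise.
import jax

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (4,))  # internally invokes threefry_2x32
```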

View File

@ -23,8 +23,8 @@ product, sparse matrix/matrix product) for two common sparse representations
These routines have reference implementations defined via XLA scatter/gather
operations that will work on any backend, although they are not particularly
performant. On GPU runtimes built against CUDA 11.0 or newer, each operation is
computed efficiently via cusparse.
performant. On GPU runtimes built against CUDA 11.0/ROCm 5.0 or newer, each operation is
computed efficiently via cusparse/hipsparse.
Further down are some examples of potential high-level wrappers for sparse objects.
(API should be considered unstable and subject to change).

View File

@ -26,10 +26,19 @@ from jax.interpreters import xla
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lib import cusparse
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp
try:
from jax._src.lib import cusparse
except ImportError:
cusparse = None
try:
from jax._src.lib import hipsparse
except ImportError:
hipsparse = None
@tree_util.register_pytree_node_class
class COO(JAXSparse):
"""Experimental COO matrix implemented in JAX; API subject to change."""
@ -116,11 +125,14 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
*, shape):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_todense cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"coo_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
shape=shape)
if cusparse is not None:
return [cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
else:
return [hipsparse.coo_todense(ctx.builder, data, row, col, shape=shape)]
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
return coo_todense(data_dot, row, col, shape=shape)
@ -139,7 +151,7 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
platform='gpu')
@ -192,12 +204,16 @@ def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_fromdense cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
nse=nse, index_dtype=index_dtype)
if cusparse is not None:
data, row, col = cusparse.coo_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
else:
data, row, col = hipsparse.coo_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, row, col]
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
@ -229,7 +245,7 @@ ad.primitive_jvps[coo_fromdense_p] = _coo_fromdense_jvp
ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(coo_fromdense_p,
_coo_fromdense_gpu_translation_rule,
platform='gpu')
@ -285,12 +301,16 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
v, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matvec cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
shape=shape, transpose=transpose)
if cusparse is not None:
return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
else:
return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
transpose=transpose)]
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
@ -313,7 +333,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
platform='gpu')
@ -367,12 +387,16 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
B, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"coo_matmat cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"coo_matmat cusparse/hipsprse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
shape=shape, transpose=transpose)
if cusparse is not None:
return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
else:
return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
transpose=transpose)]
def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, shape, transpose):
return coo_matmat(data_dot, row, col, B, shape=shape, transpose=transpose)
@ -392,6 +416,6 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, shape, transpose):
ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
platform='gpu')
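End to end, the registrations above mean the experimental COO routines behave identically on CUDA and ROCm, falling back with a `CuSparseEfficiencyWarning` for dtypes the GPU lowerings reject. A sketch using the functional wrappers that the JVP rules above reference (`coo_fromdense`, `coo_matvec`, `coo_todense`):
```python
# Illustrative COO round-trip; on GPU these dispatch to the
# cusparse/hipsparse translation rules registered above.
import jax.numpy as jnp
from jax.experimental.sparse.coo import coo_fromdense, coo_matvec, coo_todense

mat = jnp.array([[0., 1.], [2., 0.]])
data, row, col = coo_fromdense(mat, nse=2)                    # coo_fromdense_p
y = coo_matvec(data, row, col, jnp.ones(2), shape=mat.shape)  # coo_matvec_p
dense = coo_todense(data, row, col, shape=mat.shape)          # coo_todense_p
```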

View File

@ -27,10 +27,18 @@ from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat_impl, _coo_matvec_impl, _coo_todense_impl
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lib import cusparse
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp
try:
from jax._src.lib import cusparse
except ImportError:
cusparse = None
try:
from jax._src.lib import hipsparse
except ImportError:
hipsparse = None
@tree_util.register_pytree_node_class
class CSR(JAXSparse):
@ -178,11 +186,14 @@ def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, *, shape):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_todense cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_todense_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, shape=shape)
if cusparse:
return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
else:
return [hipsparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
return csr_todense(data_dot, indices, indptr, shape=shape)
@ -201,7 +212,7 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
platform='gpu')
@ -259,12 +270,16 @@ def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
index_dtype):
dtype = avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_fromdense cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
nse=nse, index_dtype=index_dtype)
if cusparse:
data, indices, indptr = cusparse.csr_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
else:
data, indices, indptr = hipsparse.csr_fromdense(
ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return [data, indices, indptr]
def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
@ -295,7 +310,7 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(csr_fromdense_p,
_csr_fromdense_gpu_translation_rule,
platform='gpu')
@ -347,12 +362,16 @@ def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, v, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matvec cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matvec_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, v,
shape=shape, transpose=transpose)
if cusparse:
return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
shape=shape, transpose=transpose)]
else:
return [hipsparse.csr_matvec(ctx.builder, data, indices, indptr, v,
shape=shape, transpose=transpose)]
def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose):
return csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
@ -376,7 +395,7 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
platform='gpu')
@ -429,12 +448,16 @@ def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
indptr, B, *, shape, transpose):
dtype = avals_in[0].dtype
if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
warnings.warn(f"csr_matmat cusparse lowering not available for dtype={dtype}. "
warnings.warn(f"csr_matmat cusparse/hipsparse lowering not available for dtype={dtype}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matmat_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, B,
shape=shape, transpose=transpose)
if cusparse is not None:
return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
shape=shape, transpose=transpose)]
else:
return [hipsparse.csr_matmat(ctx.builder, data, indices, indptr, B,
shape=shape, transpose=transpose)]
def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
return csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose)
@ -456,6 +479,6 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
if cusparse and cusparse.is_supported:
if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
platform='gpu')

View File

@ -21,6 +21,7 @@ load(
"flatbuffer_py_library",
"if_rocm_is_configured",
"pybind_extension",
"rocm_library",
)
licenses(["notice"])
@ -93,15 +94,14 @@ cc_library(
)
cc_library(
name = "rocm_gpu_kernel_helpers",
srcs = if_rocm_is_configured(["rocm_gpu_kernel_helpers.cc"]),
hdrs = if_rocm_is_configured(["rocm_gpu_kernel_helpers.h"]),
name = "hip_gpu_kernel_helpers",
srcs = if_rocm_is_configured(["hip_gpu_kernel_helpers.cc"]),
hdrs = if_rocm_is_configured(["hip_gpu_kernel_helpers.h"]),
copts = [
"-fexceptions",
],
features = ["-use_header_modules"],
deps = [
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
@ -118,9 +118,7 @@ py_library(
"lapack.py",
"pocketfft.py",
"version.py",
] + if_rocm_is_configured([
"rocsolver.py",
]),
],
deps = [":pocketfft_flatbuffers_py"],
)
@ -252,6 +250,23 @@ py_library(
],
)
py_library(
name = "hip_gpu_support",
srcs = [
"hip_linalg.py",
"hip_prng.py",
"hipsolver.py",
"hipsparse.py",
],
deps = [
":_hip_linalg",
":_hip_prng",
":_hipblas",
":_hipsolver",
":_hipsparse",
],
)
cc_library(
name = "cublas_kernels",
srcs = ["cublas_kernels.cc"],
@ -275,6 +290,27 @@ cc_library(
],
)
cc_library(
name = "hipblas_kernels",
srcs = ["hipblas_kernels.cc"],
hdrs = ["hipblas_kernels.h"],
deps = [
":handle_pool",
":hip_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_cublas",
srcs = ["cublas.cc"],
@ -295,6 +331,26 @@ pybind_extension(
],
)
pybind_extension(
name = "_hipblas",
srcs = ["hipblas.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipblas",
deps = [
":hipblas_kernels",
":kernel_pybind11_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
],
)
cc_library(
name = "cusolver_kernels",
srcs = ["cusolver_kernels.cc"],
@ -312,6 +368,23 @@ cc_library(
],
)
cc_library(
name = "hipsolver_kernels",
srcs = ["hipsolver_kernels.cc"],
hdrs = ["hipsolver_kernels.h"],
deps = [
":handle_pool",
":hip_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_cusolver",
srcs = ["cusolver.cc"],
@ -334,6 +407,27 @@ pybind_extension(
],
)
pybind_extension(
name = "_hipsolver",
srcs = ["hipsolver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipsolver",
deps = [
":hip_gpu_kernel_helpers",
":hipsolver_kernels",
":kernel_pybind11_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
],
)
cc_library(
name = "cusparse_kernels",
srcs = ["cusparse_kernels.cc"],
@ -352,6 +446,23 @@ cc_library(
],
)
cc_library(
name = "hipsparse_kernels",
srcs = ["hipsparse_kernels.cc"],
hdrs = ["hipsparse_kernels.h"],
deps = [
":handle_pool",
":hip_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsparse",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_cusparse",
srcs = ["cusparse.cc"],
@ -381,6 +492,34 @@ pybind_extension(
],
)
pybind_extension(
name = "_hipsparse",
srcs = ["hipsparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hipsparse",
deps = [
":hip_gpu_kernel_helpers",
":hipsparse_kernels",
":kernel_pybind11_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsparse",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
],
)
cc_library(
name = "cuda_lu_pivot_kernels",
srcs = [
@ -396,6 +535,21 @@ cc_library(
],
)
cc_library(
name = "hip_lu_pivot_kernels",
srcs = [
"hip_lu_pivot_kernels.cc",
],
hdrs = ["hip_lu_pivot_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels_impl",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_rocm//rocm:rocm_headers",
],
)
cuda_library(
name = "cuda_lu_pivot_kernels_impl",
srcs = [
@ -410,6 +564,20 @@ cuda_library(
],
)
rocm_library(
name = "hip_lu_pivot_kernels_impl",
srcs = [
"hip_lu_pivot_kernels.hip.cc",
],
hdrs = ["hip_lu_pivot_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_cuda_linalg",
srcs = ["cuda_linalg.cc"],
@ -430,6 +598,25 @@ pybind_extension(
],
)
pybind_extension(
name = "_hip_linalg",
srcs = ["hip_linalg.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_hip_linalg",
deps = [
":hip_gpu_kernel_helpers",
":hip_lu_pivot_kernels",
":hip_lu_pivot_kernels_impl",
":kernel_pybind11_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@pybind11",
],
)
cc_library(
name = "cuda_prng_kernels",
srcs = [
@ -445,6 +632,21 @@ cc_library(
],
)
cc_library(
name = "hip_prng_kernels",
srcs = [
"hip_prng_kernels.cc",
],
hdrs = ["hip_prng_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":hip_prng_kernels_impl",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_rocm//rocm:rocm_headers",
],
)
cuda_library(
name = "cuda_prng_kernels_impl",
srcs = [
@ -459,6 +661,20 @@ cuda_library(
],
)
rocm_library(
name = "hip_prng_kernels_impl",
srcs = [
"hip_prng_kernels.hip.cc",
],
hdrs = ["hip_prng_kernels.h"],
deps = [
":hip_gpu_kernel_helpers",
":kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@local_config_rocm//rocm:rocm_headers",
],
)
pybind_extension(
name = "_cuda_prng",
srcs = ["cuda_prng.cc"],
@ -478,37 +694,25 @@ pybind_extension(
],
)
# AMD GPU support (ROCm)
pybind_extension(
name = "rocblas_kernels",
srcs = if_rocm_is_configured(["rocblas.cc"]),
name = "_hip_prng",
srcs = ["hip_prng.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "rocblas_kernels",
module_name = "_hip_prng",
deps = [
":handle_pool",
":hip_gpu_kernel_helpers",
":hip_prng_kernels",
":kernel_pybind11_helpers",
":rocm_gpu_kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:rocblas",
"@local_config_rocm//rocm:rocm_headers",
"@local_config_rocm//rocm:rocsolver",
"@pybind11",
],
)
# TODO(rocm): do we also need to support this?
cc_library(
name = "gpu_kernels",
srcs = ["gpu_kernels.cc"],

View File

@ -0,0 +1,168 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include <stdexcept>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
namespace jax {
namespace {
std::string ErrorString(hipError_t error) { return hipGetErrorString(error); }
std::string ErrorString(hipsparseStatus_t status) {
// TODO(reza): check and see if we can use hipify
switch (status) {
case HIPSPARSE_STATUS_SUCCESS:
return "hipSparse success.";
case HIPSPARSE_STATUS_NOT_INITIALIZED:
return "hipSparse has not been initialized.";
case HIPSPARSE_STATUS_ALLOC_FAILED:
return "hipSparse allocation failed.";
case HIPSPARSE_STATUS_INVALID_VALUE:
return "hipSparse invalid value error.";
case HIPSPARSE_STATUS_ARCH_MISMATCH:
return "hipSparse architecture mismatch error.";
case HIPSPARSE_STATUS_MAPPING_ERROR:
return "hipSpase mapping error.";
case HIPSPARSE_STATUS_EXECUTION_FAILED:
return "hipSparse execution failed.";
case HIPSPARSE_STATUS_INTERNAL_ERROR:
return "hipSparse internal error.";
case HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "hipSparse matrix type not supported error.";
case HIPSPARSE_STATUS_ZERO_PIVOT:
return "hipSparse zero pivot error.";
case HIPSPARSE_STATUS_NOT_SUPPORTED:
return "hipSparse not supported error.";
case HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES:
return "hipSparse insufficient reosourse error.";
default:
return absl::StrCat("Unknown hipSparse error: ", status, ".");
}
}
std::string ErrorString(hipsolverStatus_t status) {
switch (status) {
case HIPSOLVER_STATUS_SUCCESS:
return "hipSolver success.";
case HIPSOLVER_STATUS_NOT_INITIALIZED:
return "hipSolver has not been initialized.";
case HIPSOLVER_STATUS_ALLOC_FAILED:
return "hipSolver allocation failed.";
case HIPSOLVER_STATUS_INVALID_VALUE:
return "hipSolver invalid value error.";
case HIPSOLVER_STATUS_MAPPING_ERROR:
return "hipSolver mapping error.";
case HIPSOLVER_STATUS_EXECUTION_FAILED:
return "hipSolver execution failed.";
case HIPSOLVER_STATUS_INTERNAL_ERROR:
return "hipSolver internal error.";
case HIPSOLVER_STATUS_NOT_SUPPORTED:
return "hipSolver status not supported.";
case HIPSOLVER_STATUS_ARCH_MISMATCH:
return "hipSolver architecture mismatch error.";
case HIPSOLVER_STATUS_HANDLE_IS_NULLPTR:
return "hipSolver null pointer handle error.";
case HIPSOLVER_STATUS_INVALID_ENUM:
return "hipSolver unsupported enum status error.";
default:
return absl::StrCat("Unknown hipSolver error: ", status, ".");
}
}
std::string ErrorString(hipblasStatus_t status) {
switch (status) {
case HIPBLAS_STATUS_SUCCESS:
return "hipBlas success.";
case HIPBLAS_STATUS_NOT_INITIALIZED:
return "hipBlas has not been initialized.";
case HIPBLAS_STATUS_ALLOC_FAILED:
return "hipBlas resource allocation failed.";
case HIPBLAS_STATUS_INVALID_VALUE:
return "hipBlas invalid value error.";
case HIPBLAS_STATUS_MAPPING_ERROR:
return "hipBlas mapping error.";
case HIPBLAS_STATUS_EXECUTION_FAILED:
return "hipBlas execution failed.";
case HIPBLAS_STATUS_INTERNAL_ERROR:
return "hipBlas internal error.";
case HIPBLAS_STATUS_NOT_SUPPORTED:
return "hipBlas not supported error.";
case HIPBLAS_STATUS_ARCH_MISMATCH:
return "hipBlas architecture mismatch.";
case HIPBLAS_STATUS_HANDLE_IS_NULLPTR:
return "hipBlas null pointer handle error.";
case HIPBLAS_STATUS_INVALID_ENUM:
return "hipBlas unsupported enum status error.";
default:
return absl::StrCat("Unknown hipBlas error: ", status, ".");
}
}
template <typename T>
std::string ErrorString(T status, const char* file, std::int64_t line,
const char* expr) {
return absl::StrFormat("%s:%d: operation %s failed: %s", file, line, expr,
ErrorString(status));
}
} // namespace
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
const char* expr) {
if (error != hipSuccess)
return absl::InternalError(ErrorString(error, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPSOLVER_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPSPARSE_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::Status AsStatus(hipblasStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != HIPBLAS_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<void* []>>
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
hipMemcpyHostToDevice, stream)));
return std::move(host_ptrs);
}
} // namespace jax

View File

@ -0,0 +1,66 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
#define JAXLIB_HIP_GPU_KERNEL_HELPERS_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "rocm/include/hipsolver.h"
#include "rocm/include/hipsparse.h"
#define JAX_AS_STATUS(expr) jax::AsStatus(expr, __FILE__, __LINE__, #expr)
#define JAX_THROW_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) \
throw std::runtime_error(std::string(s___.message())); \
}
#define JAX_RETURN_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) \
return s___; \
}
namespace jax {
// Used via JAX_AS_STATUS(expr) macro.
absl::Status AsStatus(hipError_t error, const char* file, std::int64_t line,
const char* expr);
absl::Status AsStatus(hipsolverStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(hipsparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(hipblasStatus_t status, const char* file,
std::int64_t line, const char* expr);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>>
MakeBatchPointers(hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size);
} // namespace jax
#endif // JAXLIB_HIP_GPU_KERNEL_HELPERS_H_

51
jaxlib/hip_linalg.cc Normal file
View File

@ -0,0 +1,51 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/hip_lu_pivot_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
namespace jax {
namespace {
std::string
BuildHipLuPivotsToPermutationDescriptor(std::int64_t batch_size,
std::int32_t pivot_size,
std::int32_t permutation_size) {
return PackDescriptorAsString(LuPivotsToPermutationDescriptor{
batch_size, pivot_size, permutation_size});
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["hip_lu_pivots_to_permutation"] =
EncapsulateFunction(HipLuPivotsToPermutation);
return dict;
}
PYBIND11_MODULE(_hip_linalg, m) {
m.def("registrations", &Registrations);
m.def("hip_lu_pivots_to_permutation_descriptor",
[](std::int64_t batch_size, std::int32_t pivot_size,
std::int32_t permutation_size) {
std::string result = BuildHipLuPivotsToPermutationDescriptor(
batch_size, pivot_size, permutation_size);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace jax
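The descriptor built here is an opaque byte string that XLA hands back, untouched, to the custom call in jaxlib/hip_lu_pivot_kernels.cc. A sketch of the round trip, assuming the PackDescriptorAsString/UnpackDescriptor pair from jaxlib's kernel helpers behaves as it is used throughout these files:

#include <cassert>
#include <string>
#include "jaxlib/hip_lu_pivot_kernels.h"
#include "jaxlib/kernel_helpers.h"

void DescriptorRoundTrip() {
  std::string opaque = jax::PackDescriptorAsString(
      jax::LuPivotsToPermutationDescriptor{/*batch_size=*/4, /*pivot_size=*/3,
                                           /*permutation_size=*/3});
  // XLA passes `opaque` verbatim to the registered custom call, which
  // recovers the struct:
  auto s = jax::UnpackDescriptor<jax::LuPivotsToPermutationDescriptor>(
      opaque.data(), opaque.size());
  assert(s.ok() && (*s)->batch_size == 4);
}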

63
jaxlib/hip_linalg.py Normal file

@@ -0,0 +1,63 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import numpy as np
from jaxlib import xla_client
try:
from . import _hip_linalg
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
pass
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def lu_pivots_to_permutation(c, pivots, *, permutation_size):
"""Kernel for the transformation of pivots to permutations on GPU."""
pivots_shape = c.get_shape(pivots)
dims = pivots_shape.dimensions()
dtype = np.dtype(np.int32)
assert pivots_shape.element_type() == dtype
batch_size = _prod(dims[:-1])
pivot_size = dims[-1]
opaque = _hip_linalg.hip_lu_pivots_to_permutation_descriptor(
batch_size, pivot_size, permutation_size)
pivots_layout = tuple(range(len(dims) - 1, -1, -1))
pivots_shape_with_layout = xla_client.Shape.array_shape(
dtype, dims, pivots_layout)
permutations_layout = pivots_layout
permutations_dims = list(dims)
permutations_dims[-1] = permutation_size
permutations_shape_with_layout = xla_client.Shape.array_shape(
dtype, permutations_dims, permutations_layout)
return xla_client.ops.CustomCallWithLayout(
c,
b"hip_lu_pivots_to_permutation",
operands=(pivots,),
shape_with_layout=permutations_shape_with_layout,
operand_shapes_with_layout=(pivots_shape_with_layout,),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)

48
jaxlib/hip_lu_pivot_kernels.cc Normal file

@@ -0,0 +1,48 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_lu_pivot_kernels.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
absl::Status HipLuPivotsToPermutation_(hipStream_t stream, void** buffers,
const char* opaque,
std::size_t opaque_len) {
auto s =
UnpackDescriptor<LuPivotsToPermutationDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchLuPivotsToPermutationKernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
return absl::OkStatus();
}
} // namespace
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status) {
auto s = HipLuPivotsToPermutation_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
absl::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace jax

43
jaxlib/hip_lu_pivot_kernels.h Normal file

@@ -0,0 +1,43 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_LU_PIVOT_KERNELS_H_
#define JAXLIB_HIP_LU_PIVOT_KERNELS_H_
#include <cstddef>
#include <cstdint>
#include <string>
#include "rocm/include/hip/hip_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
struct LuPivotsToPermutationDescriptor {
std::int64_t batch_size;
std::int32_t pivot_size;
std::int32_t permutation_size;
};
void LaunchLuPivotsToPermutationKernel(
hipStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor);
void HipLuPivotsToPermutation(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len,
XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIP_LU_PIVOT_KERNELS_H_

77
jaxlib/hip_lu_pivot_kernels.hip.cc Normal file

@@ -0,0 +1,77 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_lu_pivot_kernels.h"
#include <array>
#include <iostream>
namespace jax {
namespace {
__device__ void ComputePermutation(const std::int32_t* pivots,
std::int32_t* permutation_out,
const std::int32_t pivot_size,
const std::int32_t permutation_size) {
for (int i = 0; i < permutation_size; ++i) {
permutation_out[i] = i;
}
// Compute the permutation from a sequence of transpositions encoded in the
// pivot array by applying the transpositions in order on the identity
// permutation.
for (int i = 0; i < pivot_size; ++i) {
if ((pivots[i] < 0) || (pivots[i] >= permutation_size)) {
continue;
}
std::int32_t swap_temporary = permutation_out[i];
permutation_out[i] = permutation_out[pivots[i]];
permutation_out[pivots[i]] = swap_temporary;
}
}
__global__ void LuPivotsToPermutationKernel(
const std::int32_t* pivots, std::int32_t* permutation_out,
const std::int64_t batch_size, const std::int32_t pivot_size,
const std::int32_t permutation_size) {
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < batch_size; idx += blockDim.x * gridDim.x) {
    // Compute the permutation for this batch element.
ComputePermutation(pivots + idx * pivot_size,
permutation_out + idx * permutation_size, pivot_size,
permutation_size);
}
}
} // namespace
void LaunchLuPivotsToPermutationKernel(
hipStream_t stream, void** buffers,
LuPivotsToPermutationDescriptor descriptor) {
const std::int32_t* pivots =
reinterpret_cast<const std::int32_t*>(buffers[0]);
std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);
const int block_dim = 128;
const std::int64_t grid_dim = std::min<std::int64_t>(
1024, (descriptor.batch_size + block_dim - 1) / block_dim);
LuPivotsToPermutationKernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(
pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
descriptor.permutation_size);
}
} // namespace jax
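The transposition-to-permutation logic is easiest to sanity-check on the host. A sketch mirroring ComputePermutation, with a worked case: pivots (2, 2, 2) and permutation_size 3 start from the identity (0, 1, 2), swap positions 0 and 2, then 1 and 2, then 2 with itself, ending at (2, 0, 1):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::int32_t> PivotsToPermutationHost(
    const std::vector<std::int32_t>& pivots, std::int32_t permutation_size) {
  std::vector<std::int32_t> perm(permutation_size);
  for (std::int32_t i = 0; i < permutation_size; ++i) perm[i] = i;
  // Apply the LAPACK-style transpositions in order on the identity.
  for (std::size_t i = 0; i < pivots.size(); ++i) {
    if (pivots[i] < 0 || pivots[i] >= permutation_size) continue;
    std::swap(perm[i], perm[pivots[i]]);
  }
  return perm;
}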

43
jaxlib/hip_prng.cc Normal file

@@ -0,0 +1,43 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_prng_kernels.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
namespace {
std::string BuildHipThreeFry2x32Descriptor(std::int64_t n) {
return PackDescriptorAsString(ThreeFry2x32Descriptor{n});
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["hip_threefry2x32"] = EncapsulateFunction(HipThreeFry2x32);
return dict;
}
PYBIND11_MODULE(_hip_prng, m) {
m.def("registrations", &Registrations);
m.def("hip_threefry2x32_descriptor", [](std::int64_t n) {
std::string result = BuildHipThreeFry2x32Descriptor(n);
return pybind11::bytes(result);
});
}
} // namespace
} // namespace jax

56
jaxlib/hip_prng.py Normal file

@@ -0,0 +1,56 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import operator
import numpy as np
from jaxlib import xla_client
try:
from . import _hip_prng
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
pass
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def threefry2x32(c, keys, data):
"""ThreeFry2x32 kernel for GPU."""
assert len(keys) == 2, keys
assert len(data) == 2, data
dims = c.get_shape(keys[0]).dimensions()
dtype = np.dtype(np.uint32)
for x in itertools.chain(keys, data):
x_shape = c.get_shape(x)
assert x_shape.element_type() == dtype
assert dims == x_shape.dimensions(), (dims, x_shape)
ndims = len(dims)
opaque = _hip_prng.hip_threefry2x32_descriptor(_prod(dims))
layout = tuple(range(ndims - 1, -1, -1))
shape = xla_client.Shape.array_shape(dtype, dims, layout)
return xla_client.ops.CustomCallWithLayout(
c,
b"hip_threefry2x32",
operands=(keys[0], keys[1], data[0], data[1]),
shape_with_layout=xla_client.Shape.tuple_shape([shape, shape]),
operand_shapes_with_layout=(shape, ) * 4,
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)

45
jaxlib/hip_prng_kernels.cc Normal file

@@ -0,0 +1,45 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_prng_kernels.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace {
absl::Status HipThreeFry2x32_(hipStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
auto s = UnpackDescriptor<ThreeFry2x32Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
LaunchThreeFry2x32Kernel(stream, buffers, **s);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipGetLastError()));
return absl::OkStatus();
}
} // namespace
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = HipThreeFry2x32_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
absl::string_view message = s.message();
XlaCustomCallStatusSetFailure(status, message.data(), message.length());
}
}
} // namespace jax

39
jaxlib/hip_prng_kernels.h Normal file

@@ -0,0 +1,39 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIP_PRNG_KERNELS_H_
#define JAXLIB_HIP_PRNG_KERNELS_H_
#include <cstddef>
#include <cstdint>
#include <string>
#include "rocm/include/hip/hip_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
struct ThreeFry2x32Descriptor {
std::int64_t n;
};
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor);
void HipThreeFry2x32(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIP_PRNG_KERNELS_H_

116
jaxlib/hip_prng_kernels.hip.cc Normal file

@@ -0,0 +1,116 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hip_prng_kernels.h"
#include <array>
#include <cstddef>
namespace jax {
namespace {
__global__ void
ThreeFry2x32Kernel(const std::uint32_t* key0, const std::uint32_t* key1,
const std::uint32_t* data0, const std::uint32_t* data1,
std::uint32_t* out0, std::uint32_t* out1, std::int64_t n) {
for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
idx += blockDim.x * gridDim.x) {
// Rotation distances specified by the Threefry2x32 algorithm.
std::uint32_t rotations[8] = {13, 15, 26, 6, 17, 29, 16, 24};
std::uint32_t x[2];
std::uint32_t ks[3];
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
ks[2] = 0x1BD11BDA;
ks[0] = key0[idx];
x[0] = data0[idx];
ks[2] = ks[2] ^ key0[idx];
ks[1] = key1[idx];
x[1] = data1[idx];
ks[2] = ks[2] ^ key1[idx];
auto rotate_left = [](std::uint32_t v, std::uint32_t distance) {
return (v << distance) | (v >> (32 - distance));
};
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [&](std::uint32_t* v, std::uint32_t rotation) {
v[0] += v[1];
v[1] = rotate_left(v[1], rotation);
v[1] ^= v[0];
};
// There are no known statistical flaws with 13 rounds of Threefry2x32.
// We are conservative and use 20 rounds.
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1];
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 1u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[2];
x[1] = x[1] + ks[0] + 2u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[0];
x[1] = x[1] + ks[1] + 3u;
for (int i = 4; i < 8; ++i) {
round(x, rotations[i]);
}
x[0] = x[0] + ks[1];
x[1] = x[1] + ks[2] + 4u;
for (int i = 0; i < 4; ++i) {
round(x, rotations[i]);
}
out0[idx] = x[0] + ks[2];
out1[idx] = x[1] + ks[0] + 5u;
}
}
} // namespace
void LaunchThreeFry2x32Kernel(hipStream_t stream, void** buffers,
ThreeFry2x32Descriptor descriptor) {
std::array<const std::uint32_t*, 2> keys;
keys[0] = reinterpret_cast<const std::uint32_t*>(buffers[0]);
keys[1] = reinterpret_cast<const std::uint32_t*>(buffers[1]);
std::array<const std::uint32_t*, 2> data;
data[0] = reinterpret_cast<const std::uint32_t*>(buffers[2]);
data[1] = reinterpret_cast<const std::uint32_t*>(buffers[3]);
std::array<std::uint32_t*, 2> out;
out[0] = reinterpret_cast<std::uint32_t*>(buffers[4]);
out[1] = reinterpret_cast<std::uint32_t*>(buffers[5]);
const int block_dim = 128;
const std::int64_t grid_dim =
std::min<std::int64_t>(1024, (descriptor.n + block_dim - 1) / block_dim);
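  // Example sizing: descriptor.n = 1,000,000 gives grid_dim = 1024 (the cap);
  // the kernel's grid-stride loop then covers the remaining elements.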
ThreeFry2x32Kernel<<<grid_dim, block_dim, /*dynamic_shared_mem_bytes=*/0,
stream>>>(keys[0], keys[1], data[0], data[1], out[0],
out[1], descriptor.n);
}
} // namespace jax
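For verification, a host-side scalar mirror of the per-element hash above (a sketch; with these 20 rounds it should reproduce the Random123 known-answer vector for all-zero keys and counters, 0x6b200159 / 0x99ba4efe, though that expectation is ours and is not asserted anywhere in this change):

#include <cstdint>
#include <utility>

std::pair<std::uint32_t, std::uint32_t> ThreeFry2x32Host(
    std::uint32_t key0, std::uint32_t key1, std::uint32_t data0,
    std::uint32_t data1) {
  const std::uint32_t rot[8] = {13, 15, 26, 6, 17, 29, 16, 24};
  std::uint32_t ks[3] = {key0, key1, 0x1BD11BDA ^ key0 ^ key1};
  std::uint32_t x[2] = {data0 + ks[0], data1 + ks[1]};
  auto rotate_left = [](std::uint32_t v, std::uint32_t d) {
    return (v << d) | (v >> (32 - d));
  };
  auto round = [&](std::uint32_t r) {
    x[0] += x[1];
    x[1] = rotate_left(x[1], r);
    x[1] ^= x[0];
  };
  for (int i = 0; i < 4; ++i) round(rot[i]);
  x[0] += ks[1]; x[1] += ks[2] + 1u;
  for (int i = 4; i < 8; ++i) round(rot[i]);
  x[0] += ks[2]; x[1] += ks[0] + 2u;
  for (int i = 0; i < 4; ++i) round(rot[i]);
  x[0] += ks[0]; x[1] += ks[1] + 3u;
  for (int i = 4; i < 8; ++i) round(rot[i]);
  x[0] += ks[1]; x[1] += ks[2] + 4u;
  for (int i = 0; i < 4; ++i) round(rot[i]);
  return {x[0] + ks[2], x[1] + ks[0] + 5u};
}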

93
jaxlib/hipblas.cc Normal file

@@ -0,0 +1,93 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/hipblas_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
namespace jax {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a HipblasType.
HipblasType DtypeToHipblasType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, HipblasType>({
{{'f', 4}, HipblasType::F32},
{{'f', 8}, HipblasType::F64},
{{'c', 8}, HipblasType::C64},
{{'c', 16}, HipblasType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
// Returns the descriptor for a TrsmBatched operation.
std::pair<size_t, py::bytes>
BuildTrsmBatchedDescriptor(const py::dtype& dtype, int batch, int m, int n,
bool left_side, bool lower, bool trans_a,
bool conj_a, bool unit_diagonal) {
size_t size = batch * sizeof(void*);
TrsmBatchedDescriptor desc;
desc.type = DtypeToHipblasType(dtype);
desc.batch = batch;
desc.m = m;
desc.n = n;
desc.side = left_side ? HIPBLAS_SIDE_LEFT : HIPBLAS_SIDE_RIGHT;
desc.uplo = lower ? HIPBLAS_FILL_MODE_LOWER : HIPBLAS_FILL_MODE_UPPER;
desc.trans = trans_a ? (conj_a ? HIPBLAS_OP_C : HIPBLAS_OP_T) : HIPBLAS_OP_N;
desc.diag = unit_diagonal ? HIPBLAS_DIAG_UNIT : HIPBLAS_DIAG_NON_UNIT;
return {size, PackDescriptor(desc)};
}
// Returns the descriptor for a GetrfBatched operation.
std::pair<size_t, py::bytes> BuildGetrfBatchedDescriptor(const py::dtype& dtype,
int b, int n) {
HipblasType type = DtypeToHipblasType(dtype);
size_t size = b * sizeof(void*);
return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})};
}
py::dict Registrations() {
py::dict dict;
dict["hipblas_trsm_batched"] = EncapsulateFunction(TrsmBatched);
dict["hipblas_getrf_batched"] = EncapsulateFunction(GetrfBatched);
return dict;
}
PYBIND11_MODULE(_hipblas, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_batched_descriptor", &BuildTrsmBatchedDescriptor);
m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor);
}
} // namespace
} // namespace jax

222
jaxlib/hipblas_kernels.cc Normal file

@@ -0,0 +1,222 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hipblas_kernels.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using BlasHandlePool = HandlePool<hipblasHandle_t, hipStream_t>;
template <>
/*static*/ absl::StatusOr<BlasHandlePool::Handle>
BlasHandlePool::Borrow(hipStream_t stream) {
BlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipblasHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
namespace {
// Returns the size in bytes of one element of the given HipblasType.
int SizeOfHipblasType(HipblasType type) {
switch (type) {
case HipblasType::F32:
return sizeof(float);
case HipblasType::F64:
return sizeof(double);
case HipblasType::C64:
return sizeof(hipComplex);
case HipblasType::C128:
return sizeof(hipDoubleComplex);
}
}
} // namespace
// Batched triangular solve: trsmbatched
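// Buffer layout, as arranged by the Python side in hipblas.py: buffers[0] = a,
// buffers[1] = b, buffers[2] = the output (a copy of b solved in place), and
// buffers[3]/buffers[4] = scratch holding the per-batch pointer arrays for a
// and b that hipblas*TrsmBatched consumes.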
static absl::Status TrsmBatched_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<TrsmBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[1], SizeOfHipblasType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)));
}
const int lda = d.side == HIPBLAS_SIDE_LEFT ? d.m : d.n;
const int ldb = d.m;
auto a_batch_host = MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfHipblasType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host = MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfHipblasType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
switch (d.type) {
case HipblasType::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
// TODO(reza): is the following statement correct for rocm?
// NOTE(phawkins): if alpha is in GPU memory, cuBlas seems to segfault.
const float alpha = 1.0f;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasStrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb, d.batch)));
break;
}
case HipblasType::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case HipblasType::C64: {
hipblasComplex** a_batch_ptrs = static_cast<hipblasComplex**>(buffers[3]);
hipblasComplex** b_batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
const hipblasComplex alpha = hipblasComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<hipblasComplex**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)));
break;
}
case HipblasType::C128: {
hipblasDoubleComplex** a_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[3]);
hipblasDoubleComplex** b_batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[4]);
const hipblasDoubleComplex alpha = hipblasDoubleComplex(1.0f, 0.0f);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZtrsmBatched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<hipblasDoubleComplex**>(a_batch_ptrs), lda, b_batch_ptrs,
ldb, d.batch)));
break;
}
}
return absl::OkStatus();
}
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = TrsmBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Batched LU decomposition: getrfbatched
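// Buffer layout: buffers[0] = input matrices, buffers[1] = output (factored in
// place over a copy of the input), buffers[2] = pivots, buffers[3] = info, and
// buffers[4] = scratch for the per-batch pointer array.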
static absl::Status GetrfBatched_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfBatchedDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfBatchedDescriptor& d = **s;
auto h = BlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[0] != buffers[1]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfHipblasType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfHipblasType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
switch (d.type) {
case HipblasType::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasSgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::F64: {
double** batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasDgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::C64: {
hipblasComplex** batch_ptrs = static_cast<hipblasComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasCgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
case HipblasType::C128: {
hipblasDoubleComplex** batch_ptrs =
static_cast<hipblasDoubleComplex**>(buffers[4]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipblasZgetrfBatched(
handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch)));
break;
}
}
return absl::OkStatus();
}
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = GetrfBatched_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
} // namespace jax

61
jaxlib/hipblas_kernels.h Normal file

@@ -0,0 +1,61 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPBLAS_KERNELS_H_
#define JAXLIB_HIPBLAS_KERNELS_H_
#include <cstddef>
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
// Set of types known to Hipblas.
enum class HipblasType {
F32,
F64,
C64,
C128,
};
// Batched triangular solve: trsmbatched
struct TrsmBatchedDescriptor {
HipblasType type;
int batch, m, n;
hipblasSideMode_t side;
hipblasFillMode_t uplo;
hipblasOperation_t trans;
hipblasDiagType_t diag;
};
void TrsmBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Batched LU decomposition: getrfbatched
struct GetrfBatchedDescriptor {
HipblasType type;
int batch, n;
};
void GetrfBatched(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIPBLAS_KERNELS_H_

369
jaxlib/hipsolver.cc Normal file

@@ -0,0 +1,369 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/hipsolver_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsolver.h"
namespace jax {
namespace {
namespace py = pybind11;
// Converts a NumPy dtype to a HipsolverType.
HipsolverType DtypeToHipsolverType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, HipsolverType>({
{{'f', 4}, HipsolverType::F32},
{{'f', 8}, HipsolverType::F64},
{{'c', 8}, HipsolverType::C64},
{{'c', 16}, HipsolverType::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
// potrf: Cholesky decomposition
// Returns the workspace size and a descriptor for a potrf operation.
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
std::int64_t workspace_size;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
if (b == 1) {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(float);
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(double);
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(hipComplex);
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZpotrf_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork)));
workspace_size = lwork * sizeof(hipDoubleComplex);
break;
}
} else {
    // TODO(rocm): remove this once CUDA and HIP share the same batched potrf
    // API. HIP's batched potrf differs from CUDA's: the workspace must also
    // hold the array of per-batch matrix pointers, so we allocate extra space
    // for it alongside the solver workspace.
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(float)) + (b * sizeof(float*));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(double)) + (b * sizeof(double*));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(hipComplex)) + (b * sizeof(hipComplex*));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZpotrfBatched_bufferSize(handle.get(), uplo, n,
/*A=*/nullptr,
/*lda=*/n, &lwork, b)));
workspace_size = (lwork * sizeof(hipDoubleComplex)) + (b * sizeof(hipDoubleComplex*));
break;
}
}
return {workspace_size,
PackDescriptor(PotrfDescriptor{type, uplo, b, n, lwork})};
}
// getrf: LU decomposition
// Returns the workspace size and a descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})};
}
// geqrf: QR decomposition
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZgeqrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork)));
break;
}
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})};
}
// orgqr/ungqr: apply elementary Householder transformations
// Returns the workspace size and a descriptor for an orgqr operation.
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
int m, int n, int k) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverSorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverDorgqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverCungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsolverZungqr_bufferSize(handle.get(), m, n, k,
/*A=*/nullptr,
/*lda=*/m,
/*tau=*/nullptr, &lwork)));
break;
}
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})};
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
// Returns the workspace size and a descriptor for a syevd operation.
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverSsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverDsyevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverCheevd_bufferSize(handle.get(), jobz, uplo, n, /*A=*/nullptr,
/*lda=*/n, /*W=*/nullptr, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork)));
break;
}
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
}
// Singular value decomposition using QR algorithm: gesvd
// Returns the workspace size and a descriptor for a gesvd operation.
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
signed char jobu, jobvt;
if (compute_uv) {
if (full_matrices) {
jobu = jobvt = 'A';
} else {
jobu = jobvt = 'S';
}
} else {
jobu = jobvt = 'N';
}
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverSgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverDgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverCgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsolverZgesvd_bufferSize(handle.get(), jobu, jobvt, m, n, &lwork)));
break;
}
return {lwork,
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
}
py::dict Registrations() {
py::dict dict;
dict["hipsolver_potrf"] = EncapsulateFunction(Potrf);
dict["hipsolver_getrf"] = EncapsulateFunction(Getrf);
dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
// dict["cusolver_syevj"] = EncapsulateFunction(Syevj); not supported by
// ROCm yet
dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
// dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
// ROCm yet
return dict;
}
PYBIND11_MODULE(_hipsolver, m) {
m.def("registrations", &Registrations);
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
// m.def("build_syevj_descriptor", &BuildSyevjDescriptor); not supported by
// ROCm yet
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
// ROCm yet
}
} // namespace
} // namespace jax

381
jaxlib/hipsolver.py Normal file

@@ -0,0 +1,381 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import numpy as np
from jaxlib import xla_client
try:
from . import _hipblas
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
pass
try:
from . import _hipsolver
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
pass
_ops = xla_client.ops
_Shape = xla_client.Shape
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def trsm(c,
a,
b,
left_side=False,
lower=False,
trans_a=False,
conj_a=False,
diag=False):
"""Batched triangular solve.
XLA implements unbatched triangular solve directly, so we need only implement
the batched case."""
b_shape = c.get_shape(b)
dtype = b_shape.element_type()
dims = b_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
k = m if left_side else n
a_shape = c.get_shape(a)
if (batch_dims + (k, k) != a_shape.dimensions()
or a_shape.element_type() != dtype):
raise ValueError("Argument mismatch for trsm, got {} and {}".format(
a_shape, b_shape))
if conj_a and not trans_a:
raise NotImplementedError(
"Conjugation without transposition not supported")
lwork, opaque = _hipblas.build_trsm_batched_descriptor(
np.dtype(dtype), batch, m, n, left_side, lower, trans_a, conj_a, diag)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = _ops.CustomCallWithLayout(
c,
b"hipblas_trsm_batched",
operands=(a, b),
shape_with_layout=_Shape.tuple_shape(
(_Shape.array_shape(dtype, b_shape.dimensions(), layout),
_Shape.array_shape(np.dtype(np.int8), (lwork, ), (0, )),
_Shape.array_shape(np.dtype(np.int8), (lwork, ), (0, )))),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, a_shape.dimensions(), layout),
_Shape.array_shape(dtype, b_shape.dimensions(), layout),
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)
def potrf(c, a, lower):
"""Cholesky decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
lwork, opaque = _hipsolver.build_potrf_descriptor(np.dtype(dtype), lower,
batch, n)
kernel = b"hipsolver_potrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, ),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (n, n), (num_bd, num_bd + 1) +
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(np.dtype(np.int8), (lwork, ), (0, )),
)),
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (n, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
def getrf(c, a):
"""LU decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
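  # Heuristic, mirroring the CUDA path: prefer hipblas's batched getrf for
  # many small square matrices; otherwise loop hipsolver's getrf per batch
  # element.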
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = _hipblas.build_getrf_batched_descriptor(
np.dtype(dtype), batch, m)
workspace = _Shape.array_shape(np.dtype(np.int8), (lwork, ), (0, ))
kernel = b"hipblas_getrf_batched"
else:
lwork, opaque = _hipsolver.build_getrf_descriptor(np.dtype(dtype), batch,
m, n)
workspace = _Shape.array_shape(dtype, (lwork, ), (0, ))
kernel = b"hipsolver_getrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, ),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), (num_bd, num_bd + 1) +
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims + (min(m, n), ),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
workspace,
)),
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
def geqrf(c, a):
"""QR decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
lwork, opaque = _hipsolver.build_geqrf_descriptor(np.dtype(dtype), batch, m,
n)
workspace = _Shape.array_shape(dtype, (lwork, ), (0, ))
kernel = b"hipsolver_geqrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, ),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), (num_bd, num_bd + 1) +
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(dtype, batch_dims + (min(m, n), ),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
workspace,
)),
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
def orgqr(c, a, tau):
"""Product of elementary Householder reflections."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
tau_dims = c.get_shape(tau).dimensions()
assert tau_dims[:-1] == dims[:-2]
k = tau_dims[-1]
lwork, opaque = _hipsolver.build_orgqr_descriptor(np.dtype(dtype), batch, m,
n, k)
workspace = _Shape.array_shape(dtype, (lwork, ), (0, ))
kernel = b"hipsolver_orgqr"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, tau),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), (num_bd, num_bd + 1) +
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
workspace,
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n), (num_bd, num_bd + 1) +
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(dtype, batch_dims + (k, ),
tuple(range(num_bd, -1, -1))),
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1))
def syevd(c, a, lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  # TODO(rocm): ROCm does not support the Jacobi (syevj) method.
kernel = b"hipsolver_syevd"
lwork, opaque = _hipsolver.build_syevd_descriptor(np.dtype(dtype), lower,
batch, n)
eigvals_type = _real_type(dtype)
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, ),
shape_with_layout=_Shape.tuple_shape(
(_Shape.array_shape(dtype, dims, layout),
_Shape.array_shape(np.dtype(eigvals_type), batch_dims + (n, ),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(dtype, (lwork, ), (0, )))),
operand_shapes_with_layout=(_Shape.array_shape(dtype, dims, layout), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
def gesvd(c, a, full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_shape = c.get_shape(a)
dims = a_shape.dimensions()
dtype = a_shape.element_type()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = _prod(batch_dims)
singular_vals_dtype = np.dtype(_real_type(dtype))
  # TODO(rocm): ROCm does not support the Jacobi (gesvdj) method; on CUDA, JAX
  # uses it for small matrices.
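  # When m < n, hipsolver's gesvd (which requires m >= n) is applied to the
  # transpose: the swapped matrix_layout below reinterprets the row-major
  # (m, n) batch as column-major (n, m), and U and V^T trade places in the
  # outputs.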
if m < n:
lwork, opaque = _hipsolver.build_gesvd_descriptor(np.dtype(dtype), b, n, m,
compute_uv,
full_matrices)
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd, ) + scalar_layout
matrix_layout = (num_bd + 1, num_bd) + scalar_layout
out = _ops.CustomCallWithLayout(
c,
b"hipsolver_gesvd",
operands=(a, ),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n), ),
vector_layout),
_Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
_Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
_Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
_Shape.array_shape(dtype, (lwork, ), (0, )),
)),
operand_shapes_with_layout=(_Shape.array_shape(dtype,
batch_dims + (m, n),
matrix_layout), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
vt = _ops.GetTupleElement(out, 2)
u = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
else:
lwork, opaque = _hipsolver.build_gesvd_descriptor(np.dtype(dtype), b, m, n,
compute_uv,
full_matrices)
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd, ) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = _ops.CustomCallWithLayout(
c,
b"hipsolver_gesvd",
operands=(a, ),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n), ),
vector_layout),
_Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
_Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
_Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
_Shape.array_shape(dtype, (lwork, ), (0, )),
)),
operand_shapes_with_layout=(_Shape.array_shape(dtype,
batch_dims + (m, n),
matrix_layout), ),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion.
API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
vt = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
if not full_matrices:
u = _ops.Slice(u, (0, ) * len(dims), batch_dims + (m, min(m, n)),
(1, ) * len(dims))
vt = _ops.Slice(vt, (0, ) * len(dims), batch_dims + (min(m, n), n),
(1, ) * len(dims))
return s, u, vt, info

620
jaxlib/hipsolver_kernels.cc Normal file

@@ -0,0 +1,620 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hipsolver_kernels.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsolver.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
template <>
/*static*/ absl::StatusOr<SolverHandlePool::Handle>
SolverHandlePool::Borrow(hipStream_t stream) {
SolverHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipsolverHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
static int SizeOfHipsolverType(HipsolverType type) {
switch (type) {
case HipsolverType::F32:
return sizeof(float);
case HipsolverType::F64:
return sizeof(double);
case HipsolverType::C64:
return sizeof(hipFloatComplex);
case HipsolverType::C128:
return sizeof(hipDoubleComplex);
}
}
// potrf: Cholesky decomposition
static absl::Status Potrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const PotrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[2]);
void* workspace = buffers[3];
if (d.batch == 1) {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSpotrf(handle.get(), d.uplo, d.n, a, d.n,
static_cast<float*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDpotrf(handle.get(), d.uplo, d.n, a, d.n,
static_cast<double*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrf(
handle.get(), d.uplo, d.n, a, d.n,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrf(
handle.get(), d.uplo, d.n, a, d.n,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
break;
}
}
} else {
auto buffer_ptrs_host =
MakeBatchPointers(stream, buffers[1], workspace, d.batch,
SizeOfHipsolverType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(buffer_ptrs_host.status());
// Make sure that accesses to buffer_ptrs_host complete before we delete it.
// TODO(phawkins): avoid synchronization here.
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipStreamSynchronize(stream)));
// The workspace begins with d.batch device pointers (written by
// MakeBatchPointers above, one per matrix in buffers[1]), followed by the
// hipsolver scratch area.
switch (d.type) {
case HipsolverType::F32: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<float**>(workspace), d.n,
reinterpret_cast<float*>(static_cast<char*>(workspace) +
d.batch * sizeof(float*)),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::F64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<double**>(workspace), d.n,
reinterpret_cast<double*>(static_cast<char*>(workspace) +
d.batch * sizeof(double*)),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::C64: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCpotrfBatched(
handle.get(), d.uplo, d.n,
static_cast<hipFloatComplex**>(workspace), d.n,
reinterpret_cast<hipFloatComplex*>(static_cast<char*>(workspace) +
d.batch * sizeof(hipFloatComplex*)),
d.lwork, info, d.batch)));
break;
}
case HipsolverType::C128: {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
handle.get(), d.uplo, d.n,
static_cast<hipDoubleComplex**>(workspace), d.n,
reinterpret_cast<hipDoubleComplex*>(static_cast<char*>(workspace) +
d.batch * sizeof(hipDoubleComplex*)),
d.lwork, info, d.batch)));
break;
}
}
}
return absl::OkStatus();
}
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Potrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// getrf: LU decomposition
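// Buffer layout: buffers[0] = input matrices (batch x m x n), buffers[1] =
// output LU factors, buffers[2] = pivot indices (batch x min(m, n) ints),
// buffers[3] = per-batch int info, buffers[4] = hipsolver workspace.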
static absl::Status Getrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<float*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<double*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverCgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<hipFloatComplex*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgetrf(
handle.get(), d.m, d.n, a, d.m,
static_cast<hipDoubleComplex*>(workspace), d.lwork, ipiv, info)));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
}
return absl::OkStatus();
}
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Getrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// geqrf: QR decomposition
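// Buffer layout: buffers[0] = input matrices (batch x m x n), buffers[1] =
// output factors, buffers[2] = tau (batch x min(m, n) Householder scalars),
// buffers[3] = per-batch int info, buffers[4] = hipsolver workspace.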
static absl::Status Geqrf_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* tau = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
static_cast<float*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* tau = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDgeqrf(handle.get(), d.m, d.n, a, d.m, tau,
static_cast<double*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgeqrf(
handle.get(), d.m, d.n, a, d.m, tau,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgeqrf(
handle.get(), d.m, d.n, a, d.m, tau,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += std::min(d.m, d.n);
++info;
}
break;
}
}
return absl::OkStatus();
}
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// orgqr/ungqr: apply elementary Householder transformations
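// Buffer layout: buffers[0] = input matrices (batch x m x n), buffers[1] =
// tau (batch x k Householder scalars), buffers[2] = output matrices Q,
// buffers[3] = per-batch int info, buffers[4] = hipsolver workspace.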
static absl::Status Orgqr_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const OrgqrDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[2] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[2], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[3]);
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
void* workspace = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[2]);
float* tau = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<float*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[2]);
double* tau = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<double*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[2]);
hipFloatComplex* tau = static_cast<hipFloatComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCungqr(
handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<hipFloatComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[2]);
hipDoubleComplex* tau = static_cast<hipDoubleComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZungqr(
handle.get(), d.m, d.n, d.k, a, d.m, tau,
static_cast<hipDoubleComplex*>(workspace), d.lwork, info)));
a += d.m * d.n;
tau += d.k;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
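// Buffer layout: buffers[0] = input matrices (batch x n x n), buffers[1] =
// output eigenvectors, buffers[2] = eigenvalues w (batch x n, real-valued
// even for complex inputs), buffers[3] = per-batch int info, buffers[4] =
// hipsolver workspace.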
static absl::Status Syevd_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SyevdDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevd(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work), d.lwork, info)));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// TODO(rocm): add Syevj_ apis when support from hipsolver is ready
// Singular value decomposition using QR algorithm: gesvd
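// Buffer layout: buffers[0] = input matrices (batch x m x n), buffers[1] =
// working copy of the input (overwritten), buffers[2] = singular values s
// (batch x min(m, n), real-valued), buffers[3] = u, buffers[4] = vt,
// buffers[5] = per-batch int info, buffers[6] = hipsolver workspace.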
static absl::Status Gesvd_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GesvdDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverSgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
u, d.m, vt, d.n, static_cast<float*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<double*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
hipFloatComplex* u = static_cast<hipFloatComplex*>(buffers[3]);
hipFloatComplex* vt = static_cast<hipFloatComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<hipFloatComplex*>(work), d.lwork, /*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
hipDoubleComplex* u = static_cast<hipDoubleComplex*>(buffers[3]);
hipDoubleComplex* vt = static_cast<hipDoubleComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<hipDoubleComplex*>(work), d.lwork,
/*rwork=*/nullptr, info)));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// TODO(rocm): add Gesvdj_ apis when support from hipsolver is ready
} // namespace jax

109
jaxlib/hipsolver_kernels.h Normal file
View File

@ -0,0 +1,109 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPSOLVER_KERNELS_H_
#define JAXLIB_HIPSOLVER_KERNELS_H_
#include "absl/status/statusor.h"
#include "jaxlib/handle_pool.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipblas.h"
#include "rocm/include/hipsolver.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using SolverHandlePool = HandlePool<hipsolverHandle_t, hipStream_t>;
template <>
absl::StatusOr<SolverHandlePool::Handle>
SolverHandlePool::Borrow(hipStream_t stream);
// Set of types known to Hipsolver.
enum class HipsolverType {
F32,
F64,
C64,
C128,
};
// potrf: Cholesky decomposition
struct PotrfDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
std::int64_t batch, n;
int lwork;
};
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// getrf: LU decomposition
struct GetrfDescriptor {
HipsolverType type;
int batch, m, n, lwork;
};
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// geqrf: QR decomposition
struct GeqrfDescriptor {
HipsolverType type;
int batch, m, n, lwork;
};
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// orgqr/ungqr: apply elementary Householder transformations
struct OrgqrDescriptor {
HipsolverType type;
int batch, m, n, k, lwork;
};
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
struct SyevdDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Singular value decomposition using QR algorithm: gesvd
struct GesvdDescriptor {
HipsolverType type;
int batch, m, n;
int lwork;
signed char jobu, jobvt;
};
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIPSOLVER_KERNELS_H_

571
jaxlib/hipsparse.cc Normal file
View File

@ -0,0 +1,571 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rocm/include/hipsparse.h"
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/hipsparse_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace py = pybind11;
namespace jax {
namespace {
hipsparseIndexType_t DtypeToHipSparseIndexType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, hipsparseIndexType_t>({
{{'u', 2}, HIPSPARSE_INDEX_16U},
{{'i', 4}, HIPSPARSE_INDEX_32I},
{{'i', 8}, HIPSPARSE_INDEX_64I},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported index dtype: %s", py::repr(np_type)));
}
return it->second;
}
// TODO(rocm): add more hip data types when supported
hipDataType DtypeToHipDataType(const py::dtype& np_type) {
static auto* types =
new absl::flat_hash_map<std::pair<char, int>, hipDataType>(
{{{'f', 2}, HIP_R_16F},
{{'c', 4}, HIP_C_16F},
{{'f', 4}, HIP_R_32F},
{{'c', 8}, HIP_C_32F},
{{'f', 8}, HIP_R_64F},
{{'c', 16}, HIP_C_64F}});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported data dtype: %s", py::repr(np_type)));
}
return it->second;
}
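// The Build*Descriptor functions below run on the host at lowering time: each
// borrows a hipsparse handle, queries the scratch-buffer size for the
// operation using placeholder (non-null) pointers, and returns
// {buffer_size, packed descriptor}. The packed bytes are later passed as the
// `opaque` argument of the matching custom call, which unpacks them and
// performs the actual computation on device buffers.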
// Returns the descriptor for a Sparse matrix.
SparseMatDescriptor BuildSparseMatDescriptor(const py::dtype& data_dtype,
const py::dtype& index_dtype,
int rows, int cols, int nnz) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
hipsparseIndexType_t index_type = DtypeToHipSparseIndexType(index_dtype);
return SparseMatDescriptor{value_type, index_type, rows, cols, nnz};
}
// Returns the descriptor for a Dense matrix.
DenseMatDescriptor BuildDenseMatDescriptor(const py::dtype& data_dtype,
int rows, int cols) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
return DenseMatDescriptor{value_type, rows, cols};
}
// Returns the descriptor for a Dense vector.
DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
int size) {
hipDataType value_type = DtypeToHipDataType(data_dtype);
return DenseVecDescriptor{value_type, size};
}
// CsrToDense: Convert CSR matrix to dense matrix
// Returns the buffer size and descriptor for a CsrToDense operation.
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
// buffer_size does not reference these pointers, but does error on NULL.
// TODO(jakevdp): check whether this is documented.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrToDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrFromDense: Convert dense matrix to CSR matrix
// Returns the descriptor for a CsrFromDense operation.
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty, d.index_type,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrMatvec: Product of CSR matrix and dense vector.
// Returns the descriptor for a CsrMatvec operation.
std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
HipConst alpha = HipOne(y.type);
HipConst beta = HipZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CsrMatvecDescriptor{A, x, y, op})};
}
// CsrMatmat: Product of CSR matrix and dense matrix.
// Returns the descriptor for a CsrMatmat operation.
std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty, A.index_type,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, HIPSPARSE_ORDER_ROW)));
size_t buffer_size;
HipConst alpha = HipOne(C.type);
HipConst beta = HipZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CsrMatmatDescriptor{A, B, C, op_A})};
}
// CooToDense: Convert COO matrix to dense matrix
// Returns the descriptor for a CooToDense operation.
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSparseToDense_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
// CooFromDense: Convert dense matrix to COO matrix
// Returns the descriptor for a CooFromDense operation.
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor d =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, empty, d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz, empty, empty, empty,
d.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
size_t buffer_size;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_bufferSize(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
return {buffer_size, PackDescriptor(d)};
}
// CooMatvec: Product of COO matrix and dense vector.
// Returns the descriptor for a CooMatvec operation.
std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
const py::dtype& data_dtype, const py::dtype& x_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseVecDescriptor x =
BuildDenseVecDescriptor(x_dtype, transpose ? rows : cols);
DenseVecDescriptor y =
BuildDenseVecDescriptor(compute_dtype, transpose ? cols : rows);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
hipsparseOperation_t op = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, x.size, empty, x.type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, y.size, empty, y.type)));
size_t buffer_size;
HipConst alpha = HipOne(y.type);
HipConst beta = HipZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
HIPSPARSE_MV_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
return {buffer_size, PackDescriptor(CooMatvecDescriptor{A, x, y, op})};
}
// CooMatmat: Product of COO matrix and dense matrix.
// Returns the descriptor for a CooMatmat operation.
std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
const py::dtype& data_dtype, const py::dtype& b_dtype,
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
int cols, int BCcols, int nnz, bool transpose) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
SparseMatDescriptor A =
BuildSparseMatDescriptor(data_dtype, index_dtype, rows, cols, nnz);
DenseMatDescriptor B =
BuildDenseMatDescriptor(b_dtype, transpose ? rows : cols, BCcols);
DenseMatDescriptor C =
BuildDenseMatDescriptor(compute_dtype, transpose ? cols : rows, BCcols);
hipsparseOperation_t op_A = transpose ? HIPSPARSE_OPERATION_TRANSPOSE
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, HIPSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, HIPSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, HIPSPARSE_ORDER_ROW)));
size_t buffer_size;
HipConst alpha = HipOne(C.type);
HipConst beta = HipZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM_bufferSize(
handle.get(), op_A, HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, HIPSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
return {buffer_size, PackDescriptor(CooMatmatDescriptor{A, B, C, op_A})};
}
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
}
template <typename F>
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
auto h = SparseHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
size_t size;
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
/*du=*/nullptr, /*B=*/nullptr, ldb, &size)));
return size;
}
size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
return Gtsv2BufferSize(hipsparseSgtsv2_bufferSizeExt, m, n, ldb);
}
size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
return Gtsv2BufferSize(hipsparseDgtsv2_bufferSizeExt, m, n, ldb);
}
py::dict Registrations() {
py::dict dict;
dict["hipsparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["hipsparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["hipsparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
dict["hipsparse_csr_matmat"] = EncapsulateFunction(CsrMatmat);
dict["hipsparse_coo_todense"] = EncapsulateFunction(CooToDense);
dict["hipsparse_coo_fromdense"] = EncapsulateFunction(CooFromDense);
dict["hipsparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict["hipsparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
dict["hipsparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
dict["hipsparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
return dict;
}
PYBIND11_MODULE(_hipsparse, m) {
m.attr("hipsparse_supported") = py::bool_(true);
m.def("registrations", &Registrations);
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
m.def("build_csr_matmat_descriptor", &BuildCsrMatmatDescriptor);
m.def("build_coo_todense_descriptor", &BuildCooToDenseDescriptor);
m.def("build_coo_fromdense_descriptor", &BuildCooFromDenseDescriptor);
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}
} // namespace
} // namespace jax

332
jaxlib/hipsparse.py Normal file
View File

@ -0,0 +1,332 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
hipsparse wrappers for performing sparse matrix computations in JAX on ROCm
"""
import numpy as np
from jax._src.lib import xla_client
try:
from . import _hipsparse
except ImportError:
_hipsparse = None
else:
for _name, _value in _hipsparse.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
is_supported: bool = _hipsparse is not None and _hipsparse.hipsparse_supported
_ops = xla_client.ops
_Shape = xla_client.Shape
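# Usage sketch (builder and operand names are hypothetical): given an XLA
# builder `c` and operands `data`, `indices`, `indptr` describing an m-by-n
# CSR matrix,
#     dense = csr_todense(c, data, indices, indptr, shape=(m, n))
# emits the "hipsparse_csr_todense" custom call registered above.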
def _validate_csr(c, data, indices, indptr, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
nnz, = c.get_shape(data).dimensions()
assert c.get_shape(indices).dimensions() == (nnz,)
assert c.get_shape(indptr).element_type() == index_dtype
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
return data_dtype, index_dtype, nnz
def _validate_coo(c, data, row, col, shape):
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
nnz, = c.get_shape(data).dimensions()
assert c.get_shape(row).dimensions() == (nnz,)
assert c.get_shape(col).element_type() == index_dtype
assert c.get_shape(col).dimensions() == (nnz,)
return data_dtype, index_dtype, nnz
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
buffer_size, opaque = _hipsparse.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_csr_todense",
operands=(data, indices, indptr),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def csr_fromdense(c, mat, *, nnz, index_dtype):
"""CSR from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
shape = c.get_shape(mat).dimensions()
rows, cols = shape
buffer_size, opaque = _hipsparse.build_csr_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_csr_fromdense",
operands=(mat,),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (shape[0] + 1,), (0,)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()
if compute_dtype is None:
compute_dtype = data_dtype
buffer_size, opaque = _hipsparse.build_csr_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_csr_matvec",
operands=(data, indices, indptr, x),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,))
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
"""CSR from dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
buffer_size, opaque = _hipsparse.build_csr_matmat_descriptor(
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_csr_matmat",
operands=(data, indices, indptr, B),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
buffer_size, opaque = _hipsparse.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_coo_todense",
operands=(data, row, col),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def coo_fromdense(c, mat, *, nnz, index_dtype):
"""COO from dense matrix."""
data_dtype = np.dtype(c.get_shape(mat).element_type())
shape = c.get_shape(mat).dimensions()
rows, cols = shape
buffer_size, opaque = _hipsparse.build_coo_fromdense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_coo_fromdense",
operands=(mat,),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(np.dtype(np.int8), (buffer_size,), (0,)),
)),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return tuple(_ops.GetTupleElement(out, i) for i in range(3))
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
"""COO matrix/vector multiply."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()
if compute_dtype is None:
compute_dtype = data_dtype
buffer_size, opaque = _hipsparse.build_coo_matvec_descriptor(
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_coo_matvec",
operands=(data, row, col, x),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
"""COO from dense matrix."""
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
_, Ccols = B_shape
if compute_dtype is None:
compute_dtype = data_dtype
buffer_size, opaque = _hipsparse.build_coo_matmat_descriptor(
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_coo_matmat",
operands=(data, row, col, B),
operand_shapes_with_layout=(
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size, Ccols), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING,
)
return _ops.GetTupleElement(out, 0)
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
"""Calls `hipsparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
f32 = (t == np.float32)
dl_shape, d_shape, du_shape, B_shape = map(c.get_shape, (dl, d, du, B))
if f32:
buffer_size = _hipsparse.gtsv2_f32_buffer_size(m, n, ldb)
else:
buffer_size = _hipsparse.gtsv2_f64_buffer_size(m, n, ldb)
out = xla_client.ops.CustomCallWithLayout(
c,
b"hipsparse_gtsv2_" + (b"f32" if f32 else b"f64"),
operands=(dl, d, du, B),
operand_shapes_with_layout=(dl_shape, d_shape, du_shape, B_shape),
shape_with_layout=_Shape.tuple_shape(
(_Shape.array_shape(np.dtype(t), (ldb, n), (1, 0)),
_Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=_hipsparse.build_gtsv2_descriptor(m, n, ldb),
has_side_effect=False,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)

533
jaxlib/hipsparse_kernels.cc Normal file
View File

@ -0,0 +1,533 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/hipsparse_kernels.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/hip_gpu_kernel_helpers.h"
#include "jaxlib/kernel_helpers.h"
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
SparseHandlePool::Borrow(hipStream_t stream) {
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
hipsparseHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
HipConst HipZero(hipDataType type) {
HipConst c;
std::memset(&c, 0, sizeof(c));
return c;
}
HipConst HipOne(hipDataType type) {
HipConst c;
std::memset(&c, 0, sizeof(c));
// TODO(rocm): add more data type if new rocm support
switch (type) {
// TODO(jakevdp): 16F/16BF here might break on big endian platforms.
case HIP_R_16F:
case HIP_C_16F:
c.u16[0] = 0b0011110000000000;  // 0x3C00, 1.0 in little-endian float16
break;
case HIP_R_32F:
case HIP_C_32F:
c.f32[0] = 1.0;
break;
case HIP_R_64F:
case HIP_C_64F:
c.f64[0] = 1.0;
break;
}
return c;
}
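// HipConst (declared in jaxlib/hipsparse_kernels.h) is a small union with
// members for each supported scalar width; hipsparseSpMV/SpMM read alpha and
// beta through host pointers, so HipOne/HipZero materialize the 1.0 and 0.0
// bit patterns on the host once per call.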
static absl::Status CsrToDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[2],
/*csrColInd=*/buffers[1],
/*csrValues=*/buffers[0], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrFromDense: Convert dense matrix to CSR matrix
static absl::Status CsrFromDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz,
/*csrRowOffsets=*/buffers[3],
/*csrColInd=*/buffers[2],
/*csrValues=*/buffers[1], d.index_type, d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CsrMatvec: Product of CSR matrix and dense vector.
static absl::Status CsrMatvec_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CsrMatvecDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.y.type);
HipConst beta = HipZero(d.y.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
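// With alpha = 1 and beta = 0 as set above, hipsparseSpMV computes
// y = op(A) x. Illustrative host-side reference for the non-transposed CSR
// case (hypothetical helper, not part of this file; assumes float data):
std::vector<float> CsrMatvecHost(int rows,
                                 const std::vector<float>& values,
                                 const std::vector<int>& col_ind,
                                 const std::vector<int>& row_offsets,
                                 const std::vector<float>& x) {
  std::vector<float> y(rows, 0.0f);
  for (int r = 0; r < rows; ++r) {
    // y[r] is the dot product of row r of A with x.
    for (int i = row_offsets[r]; i < row_offsets[r + 1]; ++i) {
      y[r] += values[i] * x[col_ind[i]];
    }
  }
  return y;
}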
// CsrMatmat: Product of CSR matrix and dense matrix.
static absl::Status CsrMatmat_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CsrMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CsrMatmatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* csr_values = buffers[0];
void* csr_col_ind = buffers[1];
void* csr_row_offsets = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.C.type);
HipConst beta = HipZero(d.C.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCsr(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, csr_row_offsets, csr_col_ind,
csr_values, d.A.index_type, d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO,
d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CsrMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooToDense: Convert COO matrix to dense matrix
static absl::Status CooToDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_a, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[1],
/*cooColInd=*/buffers[2],
/*cooValues=*/buffers[0], d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.rows, d.cols,
/*ld=*/d.cols, buffers[3], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSparseToDense(handle.get(), mat_a, mat_b,
HIPSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
return absl::OkStatus();
}
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooToDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooFromDense: Convert dense matrix to COO matrix
static absl::Status CooFromDense_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SparseMatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SparseMatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
hipsparseDnMatDescr_t mat_a = 0;
hipsparseSpMatDescr_t mat_b = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_a, d.rows, d.cols,
/*ld=*/d.cols, buffers[0], d.value_type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseCreateCoo(&mat_b, d.rows, d.cols, d.nnz,
/*cooRowInd=*/buffers[2],
/*cooColInd=*/buffers[3],
/*cooValues=*/buffers[1], d.index_type,
HIPSPARSE_INDEX_BASE_ZERO, d.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_analysis(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDenseToSparse_convert(
handle.get(), mat_a, mat_b, HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffers[4])));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_b)));
return absl::OkStatus();
}
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooFromDense_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
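// COO stores one (row, column, value) triplet per nonzero; the wiring above
// writes the extracted values to buffers[1], row indices to buffers[2], and
// column indices to buffers[3]. Illustrative host-side sketch of the
// extraction (hypothetical helper, not part of this file; assumes a
// row-major float input and row-major nonzero ordering):
struct CooTriplets {
  std::vector<float> values;
  std::vector<int> row_ind;
  std::vector<int> col_ind;
};
CooTriplets CooFromDenseHost(int rows, int cols,
                             const std::vector<float>& dense) {
  CooTriplets coo;
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      float v = dense[r * cols + c];
      if (v != 0.0f) {  // record every nonzero as a triplet
        coo.values.push_back(v);
        coo.row_ind.push_back(r);
        coo.col_ind.push_back(c);
      }
    }
  }
  return coo;
}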
// CooMatvec: Product of COO matrix and dense vector.
static absl::Status CooMatvec_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatvecDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CooMatvecDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* xbuf = buffers[3];
void* ybuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.y.type);
HipConst beta = HipZero(d.y.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnVecDescr_t vec_x = 0;
hipsparseDnVecDescr_t vec_y = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_x, d.x.size, xbuf, d.x.type)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipsparseCreateDnVec(&vec_y, d.y.size, ybuf, d.y.type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, HIPSPARSE_MV_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_x)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnVec(vec_y)));
return absl::OkStatus();
}
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatvec_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// CooMatmat: Product of COO matrix and dense matrix.
static absl::Status CooMatmat_(hipStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<CooMatmatDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const CooMatmatDescriptor& d = **s;
auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
void* coo_values = buffers[0];
void* coo_row_ind = buffers[1];
void* coo_col_ind = buffers[2];
void* Bbuf = buffers[3];
void* Cbuf = buffers[4];
void* buf = buffers[5];
// TODO(rocm): check the following statement for rocm
// TODO(jakevdp): alpha and beta should be user-specifiable, but constants
// are sufficient for basic matvec operations.
// Note that, contrary to cusparse docs, alpha and beta must be host pointers
// or else the operation will segfault.
HipConst alpha = HipOne(d.C.type);
HipConst beta = HipZero(d.C.type);
hipsparseSpMatDescr_t mat_a = 0;
hipsparseDnMatDescr_t mat_b = 0;
hipsparseDnMatDescr_t mat_c = 0;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, HIPSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, HIPSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseSpMM(
handle.get(), d.op_A, /*opB=*/HIPSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, HIPSPARSE_SPMM_ALG_DEFAULT, buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_b)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsparseDestroyDnMat(mat_c)));
return absl::OkStatus();
}
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = CooMatmat_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
template <typename T, typename F>
static absl::Status gtsv2(F computeGtsv2, hipStream_t stream, void** buffers,
const char* opaque, std::size_t opaque_len) {
  auto h = SparseHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
auto s = UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const Gtsv2Descriptor& descriptor = **s;
int m = descriptor.m;
int n = descriptor.n;
int ldb = descriptor.ldb;
const T* dl = (const T*)(buffers[0]);
const T* d = (const T*)(buffers[1]);
const T* du = (const T*)(buffers[2]);
  const T* B = (const T*)(buffers[3]);
T* X = (T*)(buffers[4]);
void* buffer = buffers[5];
  // The solution X is written in place to B. We therefore need to copy the
  // contents of B into the output buffer X and pass that into the kernel as B.
  // Once copy insertion is supported for custom call aliasing, we could alias B
  // with X and avoid the copy; the code below is written defensively assuming B
  // and X might alias, but today we know they will not.
  // TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
size_t B_bytes = ldb * n * sizeof(T);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipMemcpyAsync(X, B, B_bytes, hipMemcpyDeviceToDevice, stream)));
}
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer)));
return absl::OkStatus();
}
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<float>(hipsparseSgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status) {
auto s = gtsv2<double>(hipsparseDgtsv2, stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
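// gtsv2 solves the tridiagonal system A X = B for n right-hand sides, where
// dl, d, and du hold the sub-, main-, and super-diagonals of the m x m matrix
// A and each column of B has leading dimension ldb. Illustrative
// single-right-hand-side host reference using the Thomas algorithm
// (hypothetical helper, not part of this file; assumes no pivoting is needed,
// whereas the library routine pivots):
void Gtsv2Host(int m, std::vector<double> dl, std::vector<double> d,
               std::vector<double> du, std::vector<double>& b) {
  // dl[0] and du[m-1] are unused, following the LAPACK-style convention.
  for (int i = 1; i < m; ++i) {  // forward elimination
    double w = dl[i] / d[i - 1];
    d[i] -= w * du[i - 1];
    b[i] -= w * b[i - 1];
  }
  b[m - 1] /= d[m - 1];  // back substitution
  for (int i = m - 2; i >= 0; --i) {
    b[i] = (b[i] - du[i] * b[i + 1]) / d[i];
  }
}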
} // namespace jax

jaxlib/hipsparse_kernels.h Normal file
View File

@ -0,0 +1,150 @@
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_HIPSPARSE_KERNELS_H_
#define JAXLIB_HIPSPARSE_KERNELS_H_
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "jaxlib/handle_pool.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/hipsparse.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// Some functionality defined here is only available in CUSPARSE 11.3 or
// newer; this guard is carried over from the CUDA implementation
// (CUSPARSE_VERSION is undefined when building against ROCm, so #if checks
// on this macro evaluate to false).
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
namespace jax {
using SparseHandlePool = HandlePool<hipsparseHandle_t, hipStream_t>;
template <>
/*static*/ absl::StatusOr<SparseHandlePool::Handle>
SparseHandlePool::Borrow(hipStream_t stream);
union HipConst {
int8_t i8[2];
int16_t i16[2];
int32_t i32[2];
int64_t i64[2];
uint8_t u8[2];
uint16_t u16[2];
uint32_t u32[2];
uint64_t u64[2];
float f32[2];
double f64[2];
};
HipConst HipZero(hipDataType type);
HipConst HipOne(hipDataType type);
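// HipZero and HipOne return a HipConst whose leading bytes encode 0 or 1 in
// the element type named by `type`; kernels pass its address as the host-side
// alpha/beta pointer that hipsparseSpMV/SpMM expect. Illustrative sketch for
// the real-valued cases (hypothetical helper; the actual definitions live in
// hipsparse_kernels.cc and cover every supported hipDataType):
inline HipConst HipOneSketch(hipDataType type) {
  HipConst c;
  c.u64[0] = 0;  // start from all-zero bytes, which read as 0 in every type
  c.u64[1] = 0;
  if (type == HIP_R_32F) {
    c.f32[0] = 1.0f;  // real single precision
  } else if (type == HIP_R_64F) {
    c.f64[0] = 1.0;  // real double precision
  }
  return c;
}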
struct SparseMatDescriptor {
hipDataType value_type;
hipsparseIndexType_t index_type;
int rows, cols, nnz;
};
struct DenseMatDescriptor {
hipDataType type;
int rows, cols;
};
struct DenseVecDescriptor {
hipDataType type;
int size;
};
// CsrToDense: Convert CSR matrix to dense matrix
void CsrToDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrFromDense: Convert dense matrix to CSR matrix
void CsrFromDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatvec: Product of CSR matrix and dense vector.
struct CsrMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
hipsparseOperation_t op;
};
void CsrMatvec(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CsrMatmat: Product of CSR matrix and dense matrix.
struct CsrMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
hipsparseOperation_t op_A;
};
void CsrMatmat(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooToDense: Convert COO matrix to dense matrix
void CooToDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooFromDense: Convert dense matrix to COO matrix
void CooFromDense(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatvec: Product of COO matrix and dense vector.
struct CooMatvecDescriptor {
SparseMatDescriptor A;
DenseVecDescriptor x, y;
hipsparseOperation_t op;
};
void CooMatvec(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// CooMatmat: Product of COO matrix and dense matrix.
struct CooMatmatDescriptor {
SparseMatDescriptor A;
DenseMatDescriptor B, C;
hipsparseOperation_t op_A;
};
void CooMatmat(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
struct Gtsv2Descriptor {
int m, n, ldb;
};
void gtsv2_f32(hipStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
void gtsv2_f64(hipStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len, XlaCustomCallStatus* status);
} // namespace jax
#endif // JAXLIB_HIPSPARSE_KERNELS_H_

View File

@ -17,12 +17,13 @@
load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_library = "pyx_library")
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library", _flatbuffer_py_library = "flatbuffer_py_library")
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
# lint tools.
cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_library = native.py_library
pyx_library = _pyx_library
pybind_extension = _pybind_extension

View File

@ -1,986 +0,0 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "rocm/include/rocblas.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_runtime_api.h"
#include "rocm/include/rocsolver.h"
#include "absl/base/casts.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/rocm_gpu_kernel_helpers.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
absl::Status AsStatus(rocblas_status status) {
switch (status) {
case rocblas_status_success:
return absl::OkStatus();
default:
return absl::InternalError(rocblas_status_to_string(status));
}
}
using rocBlasHandlePool = HandlePool<rocblas_handle, hipStream_t>;
template <>
/*static*/ absl::StatusOr<rocBlasHandlePool::Handle> rocBlasHandlePool::Borrow(
hipStream_t stream) {
rocBlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
rocblas_handle handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(AsStatus(rocblas_create_handle(&handle)))
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(AsStatus(rocblas_set_stream(handle, stream)))
}
return rocBlasHandlePool::Handle(pool, handle, stream);
}
namespace {
namespace py = pybind11;
// Set of types known to Rocsolver.
enum class Type {
F32,
F64,
C64,
C128,
};
// Converts a NumPy dtype to a Type.
Type DtypeToType(const py::dtype& np_type) {
static auto* types = new absl::flat_hash_map<std::pair<char, int>, Type>({
{{'f', 4}, Type::F32},
{{'f', 8}, Type::F64},
{{'c', 8}, Type::C64},
{{'c', 16}, Type::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
int SizeOfType(Type type) {
switch (type) {
case Type::F32:
return sizeof(float);
case Type::F64:
return sizeof(double);
case Type::C64:
return sizeof(rocblas_float_complex);
case Type::C128:
return sizeof(rocblas_double_complex);
}
}
// The buffers[] are all allocated in rocsolver.py, where the number of
// buffers and their sizes are determined / hardcoded to match what is
// expected here.
//##########################
// rocblas
//##########################
// Batched triangular solve: Trsm
struct TrsmDescriptor {
Type type;
int batch, m, n;
rocblas_side side;
rocblas_fill uplo;
rocblas_operation trans;
rocblas_diagonal diag;
};
// Returns the descriptor for a Trsm operation.
std::pair<size_t, py::bytes> BuildTrsmDescriptor(const py::dtype& dtype,
int batch, int m, int n,
bool left_side, bool lower,
bool trans_a, bool conj_a,
bool unit_diagonal) {
std::int64_t lwork =
batch * sizeof(void*); // number of bytes needed for the batch pointers
TrsmDescriptor desc;
desc.type = DtypeToType(dtype);
desc.batch = batch;
desc.m = m;
desc.n = n;
desc.side = left_side ? rocblas_side_left : rocblas_side_right;
desc.uplo = lower ? rocblas_fill_lower : rocblas_fill_upper;
desc.trans = trans_a ? (conj_a ? rocblas_operation_conjugate_transpose
: rocblas_operation_transpose)
: rocblas_operation_none;
desc.diag = unit_diagonal ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;
return {lwork, PackDescriptor(desc)};
}
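// Each Build*Descriptor returns the workspace size in bytes together with the
// packed descriptor; the bytes travel through XLA's opaque string and are
// decoded in the kernel by UnpackDescriptor, so both sides must agree on the
// struct layout. Illustrative sketch of that round trip (hypothetical
// helpers; the real PackDescriptor/UnpackDescriptor add error checking):
template <typename T>
std::string PackBytesSketch(const T& descriptor) {
  // Byte-copy a trivially copyable descriptor struct.
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}
template <typename T>
T UnpackBytesSketch(const char* opaque, size_t opaque_len) {
  (void)opaque_len;  // the real UnpackDescriptor rejects mismatched lengths
  return *reinterpret_cast<const T*>(opaque);
}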
absl::Status Trsm_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<TrsmDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const TrsmDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// b is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[2] != buffers[1]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[2], buffers[1], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
const int lda = d.side == rocblas_side_left ? d.m : d.n;
const int ldb = d.m;
if (d.batch == 1) {
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[0]);
float* b = static_cast<float*>(buffers[2]);
const float alpha = 1.0f;
JAX_RETURN_IF_ERROR(AsStatus(
rocblas_strsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
d.n, &alpha, const_cast<float*>(a), lda, b, ldb)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[0]);
double* b = static_cast<double*>(buffers[2]);
const double alpha = 1.0;
JAX_RETURN_IF_ERROR(AsStatus(
rocblas_dtrsm(handle.get(), d.side, d.uplo, d.trans, d.diag, d.m,
d.n, &alpha, const_cast<double*>(a), lda, b, ldb)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[0]);
rocblas_float_complex* b =
static_cast<rocblas_float_complex*>(buffers[2]);
const rocblas_float_complex alpha = {1.0f, 0.0f};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_float_complex*>(a), lda, b, ldb)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[0]);
rocblas_double_complex* b =
static_cast<rocblas_double_complex*>(buffers[2]);
        const rocblas_double_complex alpha = {1.0, 0.0};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_double_complex*>(a), lda, b, ldb)))
break;
}
}
} else {
auto a_batch_host =
MakeBatchPointers(stream, buffers[0], buffers[3], d.batch,
SizeOfType(d.type) * lda * lda);
JAX_RETURN_IF_ERROR(a_batch_host.status());
auto b_batch_host =
MakeBatchPointers(stream, buffers[2], buffers[4], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(b_batch_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
float** b_batch_ptrs = static_cast<float**>(buffers[4]);
const float alpha = 1.0f;
JAX_RETURN_IF_ERROR(AsStatus(rocblas_strsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<float**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)))
break;
}
case Type::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
double** b_batch_ptrs = static_cast<double**>(buffers[4]);
const double alpha = 1.0;
JAX_RETURN_IF_ERROR(AsStatus(rocblas_dtrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<double**>(a_batch_ptrs), lda, b_batch_ptrs, ldb,
d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** a_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[3]);
rocblas_float_complex** b_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[4]);
const rocblas_float_complex alpha = {1.0f, 0.0f};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ctrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_float_complex**>(a_batch_ptrs), lda,
b_batch_ptrs, ldb, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** a_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[3]);
rocblas_double_complex** b_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[4]);
        const rocblas_double_complex alpha = {1.0, 0.0};
JAX_RETURN_IF_ERROR(AsStatus(rocblas_ztrsm_batched(
handle.get(), d.side, d.uplo, d.trans, d.diag, d.m, d.n, &alpha,
const_cast<rocblas_double_complex**>(a_batch_ptrs), lda,
b_batch_ptrs, ldb, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
void Trsm(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Trsm_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
//##########################
// rocsolver
//##########################
// potrf: Cholesky decomposition
struct PotrfDescriptor {
Type type;
rocblas_fill uplo;
std::int64_t batch, n;
};
// Returns the descriptor for a potrf operation.
std::pair<int, py::bytes> BuildPotrfDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
Type type = DtypeToType(dtype);
rocblas_fill uplo = lower ? rocblas_fill_lower : rocblas_fill_upper;
std::int64_t lwork =
b * sizeof(void*); // number of bytes needed for the batch pointers
return {lwork, PackDescriptor(PotrfDescriptor{type, uplo, b, n})};
}
absl::Status Potrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<PotrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const PotrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.n * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* info = static_cast<int*>(buffers[2]);
if (d.batch == 1) {
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_spotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_dpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_cpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_zpotrf(handle.get(), d.uplo, d.n, a, d.n, info)))
break;
}
}
} else {
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfType(d.type) * d.n * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[3]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_spotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[3]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** a_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[3]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** a_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[3]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zpotrf_batched(
handle.get(), d.uplo, d.n, a_batch_ptrs, d.n, info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
void Potrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Potrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// getrf: LU decomposition
struct GetrfDescriptor {
Type type;
int batch, m, n;
};
// Returns the descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
Type type = DtypeToType(dtype);
std::int64_t lwork =
b * sizeof(void*); // number of bytes needed for the batch pointers
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
}
absl::Status Getrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GetrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* ipiv = static_cast<int*>(buffers[2]);
int* info = static_cast<int*>(buffers[3]);
if (d.batch == 1) {
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgetrf(handle.get(), d.m, d.n, a, d.m, ipiv, info)))
break;
}
}
} else {
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[4], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** batch_ptrs = static_cast<float**>(buffers[4]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::F64: {
double** batch_ptrs = static_cast<double**>(buffers[4]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[4]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[4]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgetrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
ipiv, std::min(d.m, d.n), info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
void Getrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Getrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// geqrf: QR decomposition
struct GeqrfDescriptor {
Type type;
int batch, m, n;
};
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
Type type = DtypeToType(dtype);
std::int64_t lwork =
b * sizeof(void*); // number of bytes needed for the batch pointers
return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n})};
}
absl::Status Geqrf_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GeqrfDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GeqrfDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
  // tau (buffers[2]) receives the Householder scalars computed by geqrf.
if (d.batch == 1) {
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float* tau = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_sgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double* tau = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_dgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[2]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_cgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[2]);
JAX_RETURN_IF_ERROR(
AsStatus(rocsolver_zgeqrf(handle.get(), d.m, d.n, a, d.m, tau)))
break;
}
}
} else {
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[3], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** batch_ptrs = static_cast<float**>(buffers[3]);
float* tau = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::F64: {
double** batch_ptrs = static_cast<double**>(buffers[3]);
double* tau = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[3]);
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[2]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[3]);
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[2]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgeqrf_batched(handle.get(), d.m, d.n, batch_ptrs, d.m,
tau, std::min(d.m, d.n), d.batch)))
break;
}
}
}
return absl::OkStatus();
}
void Geqrf(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Geqrf_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// orgqr/ungqr: apply elementary Householder transformations
struct OrgqrDescriptor {
Type type;
int batch, m, n, k;
};
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
int m, int n, int k) {
Type type = DtypeToType(dtype);
std::int64_t lwork = 0;
return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k})};
}
absl::Status Orgqr_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<OrgqrDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const OrgqrDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[2] != buffers[0]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[2], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
switch (d.type) {
// orgqr
case Type::F32: {
float* a = static_cast<float*>(buffers[2]);
float* tau = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[2]);
double* tau = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
break;
}
// ungqr
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[2]);
rocblas_float_complex* tau =
static_cast<rocblas_float_complex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[2]);
rocblas_double_complex* tau =
static_cast<rocblas_double_complex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zungqr(handle.get(), d.m, d.n, d.k, a, d.m, tau)))
a += d.m * d.n;
tau += d.k;
}
break;
}
}
return absl::OkStatus();
}
void Orgqr(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Orgqr_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
// not implemented yet in rocsolver
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// not implemented yet in rocsolver
// Singular value decomposition using QR algorithm: gesvd
struct GesvdDescriptor {
Type type;
int batch, m, n;
rocblas_svect jobu, jobvt;
};
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
Type type = DtypeToType(dtype);
std::int64_t lwork =
b * sizeof(void*); // number of bytes needed for the batch pointers
rocblas_svect jobu, jobvt;
if (compute_uv) {
if (full_matrices) {
jobu = jobvt = rocblas_svect_all;
} else {
jobu = jobvt = rocblas_svect_singular;
}
} else {
jobu = jobvt = rocblas_svect_none;
}
return {lwork, PackDescriptor(GesvdDescriptor{type, b, m, n, jobu, jobvt})};
}
absl::Status Gesvd_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const GesvdDescriptor& d = **s;
auto h = rocBlasHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
// a is INOUT, so we copy the input to the output and use that if they are not
// already the same
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(AsStatus(hipMemcpyAsync(
buffers[1], buffers[0], SizeOfType(d.type) * d.batch * d.m * d.n,
hipMemcpyDeviceToDevice, stream)))
}
int* info = static_cast<int*>(buffers[5]);
const rocblas_int lda = d.m;
const rocblas_int ldu = d.m;
const rocblas_int ldv = d.n;
if (d.batch == 1) {
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_sgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_dgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::C64: {
rocblas_float_complex* a =
static_cast<rocblas_float_complex*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
rocblas_float_complex* u =
static_cast<rocblas_float_complex*>(buffers[3]);
rocblas_float_complex* vt =
static_cast<rocblas_float_complex*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_cgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
case Type::C128: {
rocblas_double_complex* a =
static_cast<rocblas_double_complex*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
rocblas_double_complex* u =
static_cast<rocblas_double_complex*>(buffers[3]);
rocblas_double_complex* vt =
static_cast<rocblas_double_complex*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(
rocsolver_zgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, lda, s,
u, ldu, vt, ldv, e, rocblas_inplace, info)))
break;
}
}
} else {
const rocblas_stride stride_s = std::min(d.m, d.n);
const rocblas_stride stride_u = ldu * d.m;
const rocblas_stride stride_v = ldv * d.n;
const rocblas_stride stride_e = std::min(d.m, d.n) - 1;
auto a_ptrs_host =
MakeBatchPointers(stream, buffers[1], buffers[7], d.batch,
SizeOfType(d.type) * d.m * d.n);
JAX_RETURN_IF_ERROR(a_ptrs_host.status());
// TODO(phawkins): ideally we would not need to synchronize here, but to
// avoid it we need a way to keep the host-side buffer alive until the copy
// completes.
JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)))
switch (d.type) {
case Type::F32: {
float** a_batch_ptrs = static_cast<float**>(buffers[7]);
float* s = static_cast<float*>(buffers[2]);
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_sgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch)))
break;
}
case Type::F64: {
double** a_batch_ptrs = static_cast<double**>(buffers[7]);
double* s = static_cast<double*>(buffers[2]);
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_dgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch)))
break;
}
case Type::C64: {
rocblas_float_complex** a_batch_ptrs =
static_cast<rocblas_float_complex**>(buffers[7]);
float* s = static_cast<float*>(buffers[2]);
rocblas_float_complex* u =
static_cast<rocblas_float_complex*>(buffers[3]);
rocblas_float_complex* vt =
static_cast<rocblas_float_complex*>(buffers[4]);
float* e = static_cast<float*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_cgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch)))
break;
}
case Type::C128: {
rocblas_double_complex** a_batch_ptrs =
static_cast<rocblas_double_complex**>(buffers[7]);
double* s = static_cast<double*>(buffers[2]);
rocblas_double_complex* u =
static_cast<rocblas_double_complex*>(buffers[3]);
rocblas_double_complex* vt =
static_cast<rocblas_double_complex*>(buffers[4]);
double* e = static_cast<double*>(buffers[6]);
JAX_RETURN_IF_ERROR(AsStatus(rocsolver_zgesvd_batched(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a_batch_ptrs, lda, s,
stride_s, u, ldu, stride_u, vt, ldv, stride_v, e, stride_e,
rocblas_inplace, info, d.batch)))
break;
}
}
}
return absl::OkStatus();
}
void Gesvd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Gesvd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Singular value decomposition using Jacobi algorithm: gesvdj
// not implemented yet in rocsolver
py::dict Registrations() {
py::dict dict;
dict["rocblas_trsm"] = EncapsulateFunction(Trsm);
  // there are different versions of getrf in cublas and cusolver,
  // but with rocm there is just one in rocsolver
dict["rocsolver_potrf"] = EncapsulateFunction(Potrf);
dict["rocsolver_getrf"] = EncapsulateFunction(Getrf);
dict["rocsolver_geqrf"] = EncapsulateFunction(Geqrf);
dict["rocsolver_orgqr"] = EncapsulateFunction(Orgqr);
// dict["rocsolver_syevd"] = EncapsulateFunction(Syevd);
// dict["rocsolver_syevj"] = EncapsulateFunction(Syevj);
dict["rocsolver_gesvd"] = EncapsulateFunction(Gesvd);
// dict["rocsolver_gesvdj"] = EncapsulateFunction(Gesvdj);
return dict;
}
PYBIND11_MODULE(rocblas_kernels, m) {
m.def("registrations", &Registrations);
m.def("build_trsm_descriptor", &BuildTrsmDescriptor);
m.def("build_potrf_descriptor", &BuildPotrfDescriptor);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
// m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
// m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor);
}
} // namespace
} // namespace jax

View File

@ -1,47 +0,0 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/rocm_gpu_kernel_helpers.h"
#include <stdexcept>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
namespace jax {
absl::Status AsStatus(hipError_t error) {
if (error != hipSuccess) {
return absl::InternalError(
absl::StrCat("ROCm operation failed: ", hipGetErrorString(error)));
}
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<void* []>> MakeBatchPointers(
hipStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {
char* ptr = static_cast<char*>(buffer);
auto host_ptrs = absl::make_unique<void*[]>(batch);
for (int i = 0; i < batch; ++i) {
host_ptrs[i] = ptr;
ptr += batch_elem_size;
}
JAX_RETURN_IF_ERROR(
AsStatus(hipMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch,
hipMemcpyHostToDevice, stream)));
return std::move(host_ptrs);
}
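// The returned host array owns the storage the asynchronous copy reads from,
// so callers must keep it alive (e.g., by synchronizing the stream) before it
// goes out of scope. Illustrative sketch of the calling convention
// (hypothetical caller, not part of this file):
absl::Status UseBatchPointersSketch(hipStream_t stream, void* batch_buffer,
                                    void* dev_ptrs, int batch,
                                    int batch_elem_size) {
  auto host_ptrs = MakeBatchPointers(stream, batch_buffer, dev_ptrs, batch,
                                     batch_elem_size);
  JAX_RETURN_IF_ERROR(host_ptrs.status());
  // Keep *host_ptrs alive until the copy enqueued above has completed.
  JAX_RETURN_IF_ERROR(AsStatus(hipStreamSynchronize(stream)));
  // dev_ptrs now holds `batch` device pointers usable by *_batched kernels.
  return absl::OkStatus();
}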
} // namespace jax

View File

@ -1,47 +0,0 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_ROCM_GPU_KERNEL_HELPERS_H_
#define JAXLIB_ROCM_GPU_KERNEL_HELPERS_H_
#include <memory>
#include "rocm/include/hip/hip_runtime_api.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#define JAX_RETURN_IF_ERROR(expr) \
{ \
auto s___ = (expr); \
if (!s___.ok()) return s___; \
}
namespace jax {
absl::Status AsStatus(hipError_t error);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
// synchronization) until the copy enqueued by MakeBatchPointers on `stream`
// completes.
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(hipStream_t stream,
void* buffer,
void* dev_ptrs,
int batch,
int batch_elem_size);
} // namespace jax
#endif // JAXLIB_ROCM_GPU_KERNEL_HELPERS_H_

View File

@ -1,354 +0,0 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import numpy as np
from jaxlib import xla_client
try:
from jaxlib import rocblas_kernels
for _name, _value in rocblas_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
pass
# we have a single module for both rocsolver and rocblas functions
rocsolver_kernels = rocblas_kernels
_ops = xla_client.ops
_Shape = xla_client.Shape
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def trsm(c,
a,
b,
left_side=False,
lower=False,
trans_a=False,
conj_a=False,
diag=False):
"""triangular solve"""
b_shape = c.get_shape(b)
dtype = b_shape.element_type()
dims = b_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
k = m if left_side else n
a_shape = c.get_shape(a)
if (batch_dims + (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype):
raise ValueError("Argument mismatch for trsm, got {} and {}".format(
a_shape, b_shape))
if conj_a and not trans_a:
raise NotImplementedError("Conjugation without transposition not supported")
lwork, opaque = rocblas_kernels.build_trsm_descriptor(np.dtype(dtype), batch, m, n,
left_side, lower, trans_a,
conj_a, diag)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = _ops.CustomCallWithLayout(
c,
b"rocblas_trsm",
operands=(a, b),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, b_shape.dimensions(),
layout), # buffers[2] (b, OUT)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)), # buffers[3] (a batch pointers)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)))), # buffers[4] (b batch pointers)
operand_shapes_with_layout=(
_Shape.array_shape(dtype, a_shape.dimensions(), layout), # buffers[0] (a)
_Shape.array_shape(dtype, b_shape.dimensions(), layout), # buffers[1] (b, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0)
def potrf(c, a, lower):
"""Cholesky decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
lwork, opaque = rocsolver_kernels.build_potrf_descriptor(np.dtype(dtype), lower,
batch, n)
kernel = b"rocsolver_potrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (n, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[1] (a, OUT)
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))), # buffers[2] (info)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)), # buffers[3] (a batch pointers)
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (n, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return _ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1)
def getrf(c, a):
"""LU decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
lwork, opaque = rocsolver_kernels.build_getrf_descriptor(np.dtype(dtype), batch, m, n)
kernel = b"rocsolver_getrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[1] (a, OUT)
_Shape.array_shape(np.dtype(np.int32), batch_dims + (min(m, n),),
tuple(range(num_bd, -1, -1))), # buffers[2] (ipiv)
_Shape.array_shape(np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))), # buffers[3] (info)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)), # buffers[4] (a batch pointers)
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
_ops.GetTupleElement(out, 2))
def geqrf(c, a):
"""QR decomposition."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
lwork, opaque = rocsolver_kernels.build_geqrf_descriptor(np.dtype(dtype), batch, m, n)
kernel = b"rocsolver_geqrf"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[1] (a, OUT)
_Shape.array_shape(dtype, batch_dims + (min(m, n),),
tuple(range(num_bd, -1, -1))), # buffers[2] (tau)
# buffers[3] (a batch pointers)
_Shape.array_shape(np.dtype(np.int8), (lwork,), (0,)),
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
# rocsolver geqrf does not return info
return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1), None)
def orgqr(c, a, tau):
"""Product of elementary Householder reflections."""
a_shape = c.get_shape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
tau_dims = c.get_shape(tau).dimensions()
assert tau_dims[:-1] == dims[:-2]
k = tau_dims[-1]
_, opaque = rocsolver_kernels.build_orgqr_descriptor(np.dtype(dtype), batch, m, n, k)
kernel = b"rocsolver_orgqr"
out = _ops.CustomCallWithLayout(
c,
kernel,
operands=(a, tau),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[2] (a OUT)
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
), # buffers[0] (a, IN)
_Shape.array_shape(dtype, batch_dims + (k,),
tuple(range(num_bd, -1, -1))), # buffers[1] (tau IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
return (_ops.GetTupleElement(out, 0), None) # rocSOLVER orgqr does not return info

def syevd(c, a, lower=False):
  raise NotImplementedError(
      "Symmetric (Hermitian) eigendecomposition is not yet implemented in rocSOLVER")

def gesvd(c, a, full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_shape = c.get_shape(a)
dims = a_shape.dimensions()
dtype = a_shape.element_type()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = _prod(batch_dims)
singular_vals_dtype = np.dtype(_real_type(dtype))
# gesvdj is not yet implemented in rocSOLVER
# if m < 32 and n < 32:
# ...
# elif m < n:
if m < n:
lwork, opaque = rocsolver_kernels.build_gesvd_descriptor(np.dtype(dtype), b, n, m,
compute_uv, full_matrices)
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd + 1, num_bd) + scalar_layout
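# With the column dimension minor-most here, rocSOLVER (which expects
# column-major data) effectively receives a^T of shape (n, m), so the U and
# V^T it returns come back swapped relative to a; hence the buffer comments
# below.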
out = _ops.CustomCallWithLayout(
c,
b"rocsolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[1] (a, OUT)
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
vector_layout), # buffers[2] (s)
# buffers[3] (u; actually vt)
_Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
# buffers[4] (vt; actually u)
_Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
_Shape.array_shape(np.dtype(np.int32), batch_dims,
scalar_layout), # buffers[5] (info)
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n) - 1,),
vector_layout), # buffers[6] (e)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)), # buffers[7] (a batch pointers)
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[0] (a, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
vt = _ops.GetTupleElement(out, 2)
u = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
else:
lwork, opaque = rocsolver_kernels.build_gesvd_descriptor(np.dtype(dtype), b, m, n,
compute_uv, full_matrices)
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + scalar_layout
matrix_layout = (num_bd, num_bd + 1) + scalar_layout
out = _ops.CustomCallWithLayout(
c,
b"rocsolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[1] (a, OUT)
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
vector_layout), # buffers[2] (s)
_Shape.array_shape(dtype, batch_dims + (m, m),
matrix_layout), # buffers[3] (u)
_Shape.array_shape(dtype, batch_dims + (n, n),
matrix_layout), # buffers[4] (vt)
_Shape.array_shape(np.dtype(np.int32), batch_dims,
scalar_layout), # buffers[5] (info)
_Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n) - 1,),
vector_layout), # buffers[6] (e)
_Shape.array_shape(np.dtype(np.int8), (lwork,),
(0,)), # buffers[7] (a batch pointers)
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, batch_dims + (m, n),
matrix_layout), # buffers[0] (a, IN)
),
opaque=opaque,
api_version=xla_client.ops.CustomCallApiVersion
.API_VERSION_STATUS_RETURNING)
s = _ops.GetTupleElement(out, 1)
u = _ops.GetTupleElement(out, 2)
vt = _ops.GetTupleElement(out, 3)
info = _ops.GetTupleElement(out, 4)
if not full_matrices:
u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)), (1,) * len(dims))
vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n), (1,) * len(dims))
return s, u, vt, info
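
For reference, here is a minimal sketch of driving one of these wrappers
directly through the XLA client API. It is illustrative only: the builder
name and shapes are made up, and in practice JAX invokes these wrappers from
its GPU translation rules rather than by hand.

import numpy as np
from jaxlib import xla_client

# Build a computation that LU-factorizes a batch of four 3x3 float32
# matrices with the getrf wrapper defined above.
builder = xla_client.XlaBuilder("getrf_example")
a = xla_client.ops.Parameter(
    builder, 0,
    xla_client.Shape.array_shape(np.dtype(np.float32), (4, 3, 3)))
lu, ipiv, info = getrf(builder, a)
xla_client.ops.Tuple(builder, [lu, ipiv, info])
computation = builder.build()  # ready to hand to a ROCm backend for compilation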

View File

@ -226,6 +226,7 @@ class CustomLinearSolveTest(jtu.JaxTestCase):
order=2,
rtol={jnp.float32: 6e-2, jnp.float64: 2e-3})
@jtu.skip_on_devices("rocm") # rtol and atol needs to be adjusted for ROCm
def test_custom_linear_solve_cholesky(self):
def positive_definite_solve(a, b):

View File

@ -120,7 +120,6 @@ class FftTest(jtu.JaxTestCase):
for s in _get_fftn_test_s(shape, axes)
for norm in FFT_NORMS
))
@jtu.skip_on_devices("rocm")
def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
@ -139,7 +138,6 @@ class FftTest(jtu.JaxTestCase):
tol = 0.15
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
@jtu.skip_on_devices("rocm")
def testIrfftTranspose(self):
# regression test for https://github.com/google/jax/issues/6223
def build_matrix(linear_func, size):
@ -199,7 +197,6 @@ class FftTest(jtu.JaxTestCase):
for shape in [(10,)]
for n in [None, 1, 7, 13, 20]
for axis in [-1, 0]))
@jtu.skip_on_devices("rocm")
def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
@ -264,7 +261,6 @@ class FftTest(jtu.JaxTestCase):
for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]
for norm in FFT_NORMS
))
@jtu.skip_on_devices("rocm")
def testFft2(self, inverse, real, shape, dtype, axes, norm):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)

View File

@ -2315,7 +2315,7 @@ class HostCallbackCallTest(jtu.JaxTestCase):
expected_res = np.linalg.eigvals(m)
self.assertAllClose(expected_res, fun(m))
@jtu.skip_on_devices("gpu")
def test_call_doc_example_hlo(self):
"""Examples from the documentation: simplest, call a function."""

View File

@ -1512,6 +1512,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
for full in [False, True]
for w in [False, True]
for cov in [False, True, "unscaled"]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov):
rng = jtu.rand_default(self.rng())
tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5}

View File

@ -395,6 +395,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)
@jtu.skip_on_devices("rocm") # rtol and atol needs to be adjusted for ROCm
def testSphHarmOrderZeroDegreeOne(self):
"""Tests the spherical harmonics of order one and degree zero."""
theta = jnp.array([2.0])

View File

@ -542,6 +542,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for full_matrices in [False, True]
for compute_uv in [False, True]
for hermitian in ([False, True] if m == n else [False])))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
@ -797,6 +798,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 4), (2, 70, 7), (2000, 7), (7, 1000), (70, 7, 2),
(2, 0, 0), (3, 0, 2), (1, 0)]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinv(self, shape, dtype):
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
@ -811,6 +813,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
# TODO(phawkins): 1e-1 seems like a very loose tolerance.
jtu.check_grads(jnp.linalg.pinv, args_maker(), 2, rtol=1e-1, atol=2e-1)
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testPinvGradIssue2792(self):
def f(p):
a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2)
@ -849,6 +852,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype}
for shape in [(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50)]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMatrixRank(self, shape, dtype):
if (jnp.issubdtype(dtype, np.complexfloating) and
jtu.device_under_test() == "tpu"):
@ -899,7 +903,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
]
for rcond in [-1, None, 0.5]
for dtype in float_types + complex_types))
@jtu.skip_on_devices("tpu") # SVD not implemented on TPU.
@jtu.skip_on_devices("tpu","rocm") # SVD not implemented on TPU. will be fixed in ROCm-5.1
def testLstsq(self, lhs_shape, rhs_shape, dtype, rcond):
rng = jtu.rand_default(self.rng())
np_fun = partial(np.linalg.lstsq, rcond=rcond)
@ -1514,6 +1518,7 @@ class LaxLinalgTest(jtu.JaxTestCase):
eigvals_all[first:(last + 1)], eigvals_index, atol=atol)
@parameterized.parameters(np.float32, np.float64)
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def test_tridiagonal_solve(self, dtype):
dl = np.array([0.0, 2.0, 3.0], dtype=dtype)
d = np.ones(3, dtype=dtype)

View File

@ -69,6 +69,8 @@ class MultiDeviceTest(jtu.JaxTestCase):
self.assertEqual(data.device_buffer.device(), device)
def test_computation_follows_data(self):
if jax.device_count() < 5:
self.skipTest("test requires 5 devices")
devices = self.get_devices()
# By default, computation is placed (uncommitted) on device 0
@ -197,6 +199,8 @@ class MultiDeviceTest(jtu.JaxTestCase):
self.assert_committed_to_device(z, devices[1])
def test_broadcast(self):
if jax.device_count() < 3:
self.skipTest("test requires 3 devices")
devices = self.get_devices()
z = 1 + jnp.ones((2, 3))
@ -205,6 +209,8 @@ class MultiDeviceTest(jtu.JaxTestCase):
self.assert_committed_to_device(y, devices[2])
def test_transpose(self):
if jax.device_count() < 3:
self.skipTest("test requires 3 devices")
devices = self.get_devices()
x = jnp.ones((2, 3))

View File

@ -66,6 +66,7 @@ class QdwhTest(jtu.JaxTestCase):
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([8, 10, 20], [6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def testQdwhUnconvergedAfterMaxNumberIterations(
self, m, n, log_cond):
"""Tests unconvergence after maximum number of iterations."""
@ -136,6 +137,7 @@ class QdwhTest(jtu.JaxTestCase):
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([6, 8], [6, 4])
for log_cond in np.linspace(1, 4, 4)))
@jtu.skip_on_devices("rocm") # will be solved rocm-5.1
def testQdwhWithRandomMatrix(self, m, n, log_cond):
"""Tests qdwh with random input."""
rng = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)

View File

@ -885,6 +885,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for dim in [1, 3, 5]
for dtype in float_dtypes
for method in ['svd', 'eigh', 'cholesky']))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm-5.1
def testMultivariateNormal(self, dim, dtype, method):
r = self.rng()
mean = r.randn(dim)
@ -922,6 +923,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for cov_batch_size in [(), (3,), (2, 3)]
for shape in [(), (1,), (5,)]
for method in ['cholesky', 'svd', 'eigh']))
@jtu.skip_on_devices("rocm") # will be solved in rocm-5.1
def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size,
shape, method):
r = self.rng()

View File

@ -54,7 +54,6 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase):
for n in [None, 1, 7, 13, 20]
for axis in [-1, 0]
for norm in [None, 'ortho']))
@jtu.skip_on_devices("rocm")
def testDct(self, shape, dtype, n, axis, norm):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
@ -72,7 +71,6 @@ class LaxBackedScipyFftTests(jtu.JaxTestCase):
for axes in _get_dctn_test_axes(shape)
for s in _get_dctn_test_s(shape, axes)
for norm in [None, 'ortho']))
@jtu.skip_on_devices("rocm")
def testDctn(self, shape, dtype, s, axes, norm):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)

View File

@ -112,6 +112,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def testDetrend(self, shape, dtype, axis, type, bp):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -139,6 +140,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
for detrend in ['constant', 'linear', False]
for boundary in [None, 'even', 'odd', 'zeros']
for padded in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed in ROCm 5.1
def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, boundary, padded,
timeaxis):
@ -190,6 +192,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
@jtu.skip_on_devices("rocm") # will be fixed in next ROCm version
def testCsdAgainstNumpy(
self, *, xshape, yshape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -240,6 +243,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
for detrend in ['constant', 'linear', False]
for scaling in ['density', 'spectrum']
for average in ['mean']))
@jtu.skip_on_devices("rocm") # will be fixed in next rocm release
def testCsdWithSameParamAgainstNumpy(
self, *, shape, dtype, fs, window, nperseg, noverlap, nfft,
detrend, scaling, timeaxis, average):
@ -297,6 +301,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
for return_onesided in [True, False]
for scaling in ['density', 'spectrum']
for average in ['mean', 'median']))
@jtu.skip_on_devices("rocm") # will be fixed in next ROCm release
def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
noverlap, nfft, detrend, return_onesided,
scaling, timeaxis, average):

View File

@ -31,6 +31,7 @@ from jax.experimental import sparse
from jax.experimental.sparse.bcoo import BCOOInfo
from jax import lax
from jax._src.lib import cusparse
from jax._src.lib import hipsparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import tree_util
@ -282,6 +283,7 @@ class cuSparseTest(jtu.JaxTestCase):
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in all_dtypes
for transpose in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def test_csr_matvec(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -386,6 +388,7 @@ class cuSparseTest(jtu.JaxTestCase):
for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
for dtype in all_dtypes
for transpose in [True, False]))
@jtu.skip_on_devices("rocm") # will be fixed in rocm-5.1
def test_coo_matmat(self, shape, dtype, transpose):
op = lambda M: M.T if transpose else M
@ -421,7 +424,9 @@ class cuSparseTest(jtu.JaxTestCase):
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
def test_gpu_translation_rule(self):
version = xla_bridge.get_backend().platform_version
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
if version.split()[0] != "rocm":
cuda_version = None if version == "<unknown>" else int(
version.split()[-1])
if cuda_version is None or cuda_version < 11000:
self.assertFalse(cusparse and cusparse.is_supported)
self.assertNotIn(sparse.csr_todense_p,
@ -430,6 +435,10 @@ class cuSparseTest(jtu.JaxTestCase):
self.assertTrue(cusparse and cusparse.is_supported)
self.assertIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])
else:
self.assertTrue(hipsparse and hipsparse.is_supported)
self.assertIn(sparse.csr_todense_p,
xla._backend_specific_translations["gpu"])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(

View File

@ -51,6 +51,7 @@ class SvdTest(jtu.JaxTestCase):
'm': m, 'n': n, 'log_cond': log_cond}
for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSvdWithRectangularInput(self, m, n, log_cond):
"""Tests SVD with rectangular input."""
with jax.default_matmul_precision('float32'):
@ -111,6 +112,7 @@ class SvdTest(jtu.JaxTestCase):
'm': m, 'r': r, 'log_cond': log_cond}
for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
for log_cond in np.linspace(1, 3, 3)))
@jtu.skip_on_devices("rocm") # will be fixed on rocm-5.1
def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
"""Tests SVD with rank-deficient input."""
with jax.default_matmul_precision('float32'):