Mirror of https://github.com/ROCm/jax.git

Add support for linear algebra ops on GPU using Cusolver:

* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition

Also make the LU decomposition tests less sensitive to the exact decomposition:
check that we have a valid decomposition, not precisely the same one scipy
returns.
This commit is contained in:
parent 7c060435bb
commit ed3e2308c1
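With these kernels in place, the factorizations below run through cuSolver when a GPU backend is present. A minimal usage sketch against the JAX API of this era; the shapes and data are illustrative only, and whether each call actually hits the GPU depends on having a CUDA-enabled jaxlib:

import numpy as onp
import jax.numpy as np
import jax.scipy.linalg as jsp_linalg

a = onp.random.RandomState(0).randn(64, 64).astype(onp.float32)

# Symmetric eigendecomposition (cusolver syevd on GPU).
w, v = np.linalg.eigh((a + a.T) / 2)

# Singular values only (cusolver gesvd on GPU).
s = np.linalg.svd(a, compute_uv=False)

# LU decomposition (cusolver getrf on GPU).
p, l, u = jsp_linalg.lu(a)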
@@ -14,6 +11,11 @@
 # JAX is Autograd and XLA
 
+load(
+    "@org_tensorflow//tensorflow/core:platform/default/cuda_build_defs.bzl",
+    "if_cuda_is_configured",
+)
+
 licenses(["notice"])  # Apache 2
 
 package(default_visibility = ["//visibility:public"])
@@ -26,7 +31,9 @@ sh_binary(
         "//jaxlib",
         "//jaxlib:lapack.so",
         "//jaxlib:pytree",
-    ],
+    ] + if_cuda_is_configured([
+        "//jaxlib:cusolver_kernels",
+    ]),
     deps = ["@bazel_tools//tools/bash/runfiles"],
 )
@@ -54,7 +54,11 @@ fi
 # new location.
 cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
 cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
+if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
+  cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
+fi
 cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
+cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib"
 cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
   "${TARGET}/jaxlib"
 sed \
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from functools import partial
 import numpy as onp
 
 from jax.numpy import lax_numpy as np
@@ -34,6 +35,7 @@ from jax.core import Primitive
 from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule,
                      _float, _complex, _input_dtype, _broadcasting_select)
 from jax.lib import lapack
+from jax.lib import cusolver
 
 # traceables
@@ -81,6 +83,11 @@ def _T(x): return np.swapaxes(x, -1, -2)
 def _H(x): return np.conj(_T(x))
 def symmetrize(x): return (x + _H(x)) / 2
 
+def _unpack_tuple(f, n):
+  def g(c, *args, **kwargs):
+    t = f(c, *args, **kwargs)
+    return (c.GetTupleElement(t, i) for i in range(n))
+  return g
+
 # primitives
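The `_unpack_tuple` helper above bridges two jaxlib calling conventions: older functions return a single tuple-valued XLA op, newer ones return the elements individually. A hypothetical sketch of its effect, assuming the definition above; `FakeBuilder` is a made-up stand-in for the XLA computation builder, mocking only the one method the shim uses:

class FakeBuilder:
  def GetTupleElement(self, t, i):
    return t[i]

def old_style_rule(c, x):  # pretend this returns one tuple-valued XLA op
  return (x, 2 * x)

new_style_rule = _unpack_tuple(old_style_rule, 2)
print(list(new_style_rule(FakeBuilder(), 3)))  # [3, 6]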
@@ -123,13 +130,18 @@ def _nan_like(c, operand):
   nan = c.Constant(onp.array(onp.nan, dtype=dtype))
   return c.Broadcast(nan, shape.dimensions())
 
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if hasattr(lapack, "potrf"):
+  _cpu_potrf = lapack.potrf
+else:
+  _cpu_potrf = _unpack_tuple(lapack.jax_potrf, 2)
+
 def cholesky_cpu_translation_rule(c, operand):
   shape = c.GetShape(operand)
   dtype = shape.element_type().type
   if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
-    potrf_output = lapack.jax_potrf(c, operand, lower=True)
-    result = c.GetTupleElement(potrf_output, 0)
-    info = c.GetTupleElement(potrf_output, 1)
+    result, info = _cpu_potrf(c, operand, lower=True)
     return c.Select(c.Eq(info, c.ConstantS32Scalar(0)), result,
                     _nan_like(c, result))
   else:
@@ -163,15 +175,18 @@ def eig_abstract_eval(operand):
     w = vl = vr = operand
   return core.AbstractTuple((w, vl, vr))
 
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if hasattr(lapack, "geev"):
+  _cpu_geev = lapack.geev
+else:
+  _cpu_geev = _unpack_tuple(lapack.jax_geev, 4)
+
 def eig_cpu_translation_rule(c, operand):
   shape = c.GetShape(operand)
   batch_dims = shape.dimensions()[:-2]
-  geev_out = lapack.jax_geev(c, operand)
-  w = c.GetTupleElement(geev_out, 0)
-  vl = c.GetTupleElement(geev_out, 1)
-  vr = c.GetTupleElement(geev_out, 2)
-  ok = c.Eq(c.GetTupleElement(geev_out, 3), c.ConstantS32Scalar(0))
+  w, vl, vr, info = _cpu_geev(c, operand)
+  ok = c.Eq(info, c.ConstantS32Scalar(0))
   w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
                            _nan_like(c, w))
   vl = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vl,
@@ -219,13 +234,11 @@ def eigh_abstract_eval(operand, lower):
     v, w = operand, operand
   return core.AbstractTuple((v, w))
 
-def eigh_cpu_translation_rule(c, operand, lower):
+def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
   shape = c.GetShape(operand)
   batch_dims = shape.dimensions()[:-2]
-  syevd_out = lapack.jax_syevd(c, operand, lower=lower)
-  v = c.GetTupleElement(syevd_out, 0)
-  w = c.GetTupleElement(syevd_out, 1)
-  ok = c.Eq(c.GetTupleElement(syevd_out, 2), c.ConstantS32Scalar(0))
+  v, w, info = syevd_impl(c, operand, lower=lower)
+  ok = c.Eq(info, c.ConstantS32Scalar(0))
   v = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), v,
                            _nan_like(c, v))
   w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
@@ -267,7 +280,22 @@ eigh_p.def_impl(eigh_impl)
 eigh_p.def_abstract_eval(eigh_abstract_eval)
 xla.translations[eigh_p] = eigh_translation_rule
 ad.primitive_jvps[eigh_p] = eigh_jvp_rule
-xla.backend_specific_translations['cpu'][eigh_p] = eigh_cpu_translation_rule
+
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if hasattr(lapack, "syevd"):
+  _cpu_syevd = lapack.syevd
+else:
+  _cpu_syevd = _unpack_tuple(lapack.jax_syevd, 3)
+
+xla.backend_specific_translations['cpu'][eigh_p] = partial(
+    _eigh_cpu_gpu_translation_rule, _cpu_syevd)
+
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if cusolver:
+  xla.backend_specific_translations['gpu'][eigh_p] = partial(
+      _eigh_cpu_gpu_translation_rule, cusolver.syevd)
 batching.primitive_batchers[eigh_p] = eigh_batching_rule
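The registrations above share one translation-rule body across backends, specializing it with functools.partial so that only the kernel implementation differs. A self-contained sketch of the pattern; the impl functions here are stand-ins, not real translation rules:

from functools import partial

def shared_rule(impl, operand, lower):
  # Single rule body; `impl` supplies the backend-specific kernel call.
  return impl(operand, lower)

cpu_impl = lambda operand, lower: ("cpu", operand, lower)
gpu_impl = lambda operand, lower: ("gpu", operand, lower)

rules = {"cpu": partial(shared_rule, cpu_impl),
         "gpu": partial(shared_rule, gpu_impl)}
print(rules["gpu"]("A", True))  # ('gpu', 'A', True)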
@@ -522,14 +550,13 @@ def _lu_batching_rule(batched_args, batch_dims):
   x = batching.bdim_at_front(x, bd)
   return lu_p.bind(x), 0
 
-def _lu_cpu_translation_rule(c, operand):
+def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
   shape = c.GetShape(operand)
   batch_dims = shape.dimensions()[:-2]
-  getrf_out = lapack.jax_getrf(c, operand)
-  lu = c.GetTupleElement(getrf_out, 0)
+  lu, pivot, info = getrf_impl(c, operand)
   # Subtract 1 from the pivot to get 0-based indices.
-  pivot = c.Sub(c.GetTupleElement(getrf_out, 1), c.ConstantS32Scalar(1))
-  ok = c.Eq(c.GetTupleElement(getrf_out, 2), c.ConstantS32Scalar(0))
+  pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
+  ok = c.Eq(info, c.ConstantS32Scalar(0))
   lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu,
                             _nan_like(c, lu))
   return c.Tuple(lu, pivot)
@@ -541,7 +568,20 @@ lu_p.def_abstract_eval(_lu_abstract_eval)
 xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
 ad.primitive_jvps[lu_p] = _lu_jvp_rule
 batching.primitive_batchers[lu_p] = _lu_batching_rule
-xla.backend_specific_translations['cpu'][lu_p] = _lu_cpu_translation_rule
+
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if hasattr(lapack, "getrf"):
+  _cpu_getrf = lapack.getrf
+else:
+  _cpu_getrf = _unpack_tuple(lapack.jax_getrf, 3)
+
+xla.backend_specific_translations['cpu'][lu_p] = partial(
+    _lu_cpu_gpu_translation_rule, _cpu_getrf)
+
+if cusolver:
+  xla.backend_specific_translations['gpu'][lu_p] = partial(
+      _lu_cpu_gpu_translation_rule, cusolver.getrf)
 
 
 def lu_pivots_to_permutation(swaps, m):
@@ -681,16 +721,13 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
   dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim
   return core.pack((s, U, Vt)), core.pack((ds, dU, dV.T))
 
-def svd_cpu_translation_rule(c, operand, full_matrices, compute_uv):
+def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
   shape = c.GetShape(operand)
   dtype = shape.element_type().type
   if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
-    gesdd_out = lapack.jax_gesdd(c, operand, full_matrices=full_matrices,
-                                 compute_uv=compute_uv)
-    s = c.GetTupleElement(gesdd_out, 0)
-    u = c.GetTupleElement(gesdd_out, 1)
-    vt = c.GetTupleElement(gesdd_out, 2)
-    ok = c.Eq(c.GetTupleElement(gesdd_out, 3), c.ConstantS32Scalar(0))
+    s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
+                                compute_uv=compute_uv)
+    ok = c.Eq(info, c.ConstantS32Scalar(0))
     s = _broadcasting_select(c, c.Reshape(ok, None, (1,)), s,
                              _nan_like(c, s))
     u = _broadcasting_select(c, c.Reshape(ok, None, (1, 1)), u,
@@ -711,7 +748,22 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
 svd_p = Primitive('svd')
 svd_p.def_impl(svd_impl)
 svd_p.def_abstract_eval(svd_abstract_eval)
-xla.translations[svd_p] = svd_translation_rule
-xla.backend_specific_translations['cpu'][svd_p] = svd_cpu_translation_rule
 ad.primitive_jvps[svd_p] = svd_jvp_rule
 batching.primitive_batchers[svd_p] = svd_batching_rule
+xla.translations[svd_p] = svd_translation_rule
+
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if hasattr(lapack, "gesdd"):
+  _cpu_gesdd = lapack.gesdd
+else:
+  _cpu_gesdd = _unpack_tuple(lapack.jax_gesdd, 4)
+
+xla.backend_specific_translations['cpu'][svd_p] = partial(
+    _svd_cpu_gpu_translation_rule, _cpu_gesdd)
+
+# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
+# 0.1.23.
+if cusolver:
+  xla.backend_specific_translations['gpu'][svd_p] = partial(
+      _svd_cpu_gpu_translation_rule, cusolver.gesvd)
@@ -47,3 +47,10 @@ try:
   from jaxlib import pytree
 except ImportError:
   pytree = None
+
+# TODO(phawkins): make the import unconditional when the minimum Jaxlib version
+# has been increased to 0.1.23.
+try:
+  from jaxlib import cusolver
+except ImportError:
+  cusolver = None
jaxlib/BUILD (31 changed lines)
@@ -29,7 +29,10 @@ pyx_library(
 
 py_library(
     name = "jaxlib",
-    srcs = ["version.py"],
+    srcs = [
+        "cusolver.py",
+        "version.py"
+    ],
 )
 
 tf_pybind_extension(
@@ -52,3 +55,29 @@ tf_pybind_extension(
         "@pybind11",
     ],
 )
+
+tf_pybind_extension(
+    name = "cusolver_kernels",
+    srcs = ["cusolver.cc"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+        "-Wno-c++98-c++11-compat",
+    ],
+    features = ["-use_header_modules"],
+    module_name = "cusolver_kernels",
+    deps = [
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@local_config_cuda//cuda:cudart",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@local_config_cuda//cuda:cusolver",
+        "@pybind11",
+    ],
+)
jaxlib/cusolver.cc (new file, 585 lines)
@@ -0,0 +1,585 @@
/* Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"

namespace jax {
namespace {

namespace py = pybind11;

void ThrowIfError(cudaError_t error) {
  if (error != cudaSuccess) {
    throw std::runtime_error("CUDA operation failed");
  }
}

void ThrowIfErrorStatus(cusolverStatus_t status) {
  switch (status) {
    case CUSOLVER_STATUS_SUCCESS:
      return;
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      throw std::runtime_error("cuSolver has not been initialized");
    case CUSOLVER_STATUS_ALLOC_FAILED:
      throw std::runtime_error("cuSolver allocation failed");
    case CUSOLVER_STATUS_INVALID_VALUE:
      throw std::runtime_error("cuSolver invalid value error");
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      throw std::runtime_error("cuSolver architecture mismatch error");
    case CUSOLVER_STATUS_MAPPING_ERROR:
      throw std::runtime_error("cuSolver mapping error");
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      throw std::runtime_error("cuSolver execution failed");
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      throw std::runtime_error("cuSolver internal error");
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      throw std::invalid_argument("cuSolver matrix type not supported error");
    case CUSOLVER_STATUS_NOT_SUPPORTED:
      throw std::runtime_error("cuSolver not supported error");
    case CUSOLVER_STATUS_ZERO_PIVOT:
      throw std::runtime_error("cuSolver zero pivot error");
    case CUSOLVER_STATUS_INVALID_LICENSE:
      throw std::runtime_error("cuSolver invalid license error");
    default:
      throw std::runtime_error("Unknown cuSolver error");
  }
}

// To avoid creating cusolver contexts in the middle of execution, we maintain
// a pool of them.
class SolverHandlePool {
 public:
  SolverHandlePool() = default;

  // RAII class representing a cusolver handle borrowed from the pool. Returns
  // the handle to the pool on destruction.
  class Handle {
   public:
    Handle() = default;
    ~Handle() {
      if (pool_) {
        pool_->Return(handle_);
      }
    }

    Handle(Handle const&) = delete;
    Handle(Handle&& other) {
      pool_ = other.pool_;
      handle_ = other.handle_;
      other.pool_ = nullptr;
      other.handle_ = nullptr;
    }
    Handle& operator=(Handle const&) = delete;
    Handle& operator=(Handle&& other) {
      pool_ = other.pool_;
      handle_ = other.handle_;
      other.pool_ = nullptr;
      other.handle_ = nullptr;
      return *this;
    }

    cusolverDnHandle_t get() { return handle_; }

   private:
    friend class SolverHandlePool;
    Handle(SolverHandlePool* pool, cusolverDnHandle_t handle)
        : pool_(pool), handle_(handle) {}
    SolverHandlePool* pool_ = nullptr;
    cusolverDnHandle_t handle_ = nullptr;
  };

  // Borrows a handle from the pool. If 'stream' is non-null, sets the stream
  // associated with the handle.
  static Handle Borrow(cudaStream_t stream = nullptr);

 private:
  static SolverHandlePool* Instance();

  void Return(cusolverDnHandle_t handle);

  absl::Mutex mu_;
  std::vector<cusolverDnHandle_t> handles_ GUARDED_BY(mu_);
};

/*static*/ SolverHandlePool* SolverHandlePool::Instance() {
  static auto* pool = new SolverHandlePool;
  return pool;
}

/*static*/ SolverHandlePool::Handle SolverHandlePool::Borrow(
    cudaStream_t stream) {
  SolverHandlePool* pool = Instance();
  absl::MutexLock lock(&pool->mu_);
  cusolverDnHandle_t handle;
  if (pool->handles_.empty()) {
    ThrowIfErrorStatus(cusolverDnCreate(&handle));
  } else {
    handle = pool->handles_.back();
    pool->handles_.pop_back();
  }
  if (stream) {
    ThrowIfErrorStatus(cusolverDnSetStream(handle, stream));
  }
  return Handle(pool, handle);
}

void SolverHandlePool::Return(cusolverDnHandle_t handle) {
  absl::MutexLock lock(&mu_);
  handles_.push_back(handle);
}

// Set of types known to Cusolver.
enum class Type {
  F32,
  F64,
  C64,
  C128,
};

// Converts a NumPy dtype to a Type.
Type DtypeToType(const py::dtype& np_type) {
  static auto* types = new absl::flat_hash_map<std::pair<char, int>, Type>({
      {{'f', 4}, Type::F32},
      {{'f', 8}, Type::F64},
      {{'c', 8}, Type::C64},
      {{'c', 16}, Type::C128},
  });
  auto it = types->find({np_type.kind(), np_type.itemsize()});
  if (it == types->end()) {
    throw std::invalid_argument(
        absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
  }
  return it->second;
}

int SizeOfType(Type type) {
  switch (type) {
    case Type::F32:
      return sizeof(float);
    case Type::F64:
      return sizeof(double);
    case Type::C64:
      return sizeof(cuComplex);
    case Type::C128:
      return sizeof(cuDoubleComplex);
  }
}

// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.

// Packs a descriptor object into a py::bytes structure.
template <typename T>
py::bytes PackDescriptor(const T& descriptor) {
  return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}

// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid size for linalg operation descriptor.");
  }
  return absl::bit_cast<const T*>(opaque);
}

// getrf: LU decomposition

struct GetrfDescriptor {
  Type type;
  int batch, m, n;
};

// Returns the workspace size and a descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
                                               int m, int n) {
  Type type = DtypeToType(dtype);
  auto handle = SolverHandlePool::Borrow();
  int lwork;
  switch (type) {
    case Type::F32:
      ThrowIfErrorStatus(cusolverDnSgetrf_bufferSize(handle.get(), m, n,
                                                     /*A=*/nullptr,
                                                     /*lda=*/m, &lwork));
      break;
    case Type::F64:
      ThrowIfErrorStatus(cusolverDnDgetrf_bufferSize(handle.get(), m, n,
                                                     /*A=*/nullptr,
                                                     /*lda=*/m, &lwork));
      break;
    case Type::C64:
      ThrowIfErrorStatus(cusolverDnCgetrf_bufferSize(handle.get(), m, n,
                                                     /*A=*/nullptr,
                                                     /*lda=*/m, &lwork));
      break;
    case Type::C128:
      ThrowIfErrorStatus(cusolverDnZgetrf_bufferSize(handle.get(), m, n,
                                                     /*A=*/nullptr,
                                                     /*lda=*/m, &lwork));
      break;
  }
  return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
}

void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len) {
  const GetrfDescriptor& d =
      *UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
  auto handle = SolverHandlePool::Borrow(stream);
  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
                               SizeOfType(d.type) * d.batch * d.m * d.n,
                               cudaMemcpyDeviceToDevice, stream));

  void* workspace = buffers[2];
  int* ipiv = static_cast<int*>(buffers[3]);
  int* info = static_cast<int*>(buffers[4]);
  switch (d.type) {
    case Type::F32: {
      float* a = static_cast<float*>(buffers[1]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m,
                                            static_cast<float*>(workspace),
                                            ipiv, info));
        a += d.m * d.n;
        ipiv += std::min(d.m, d.n);
        ++info;
      }
      break;
    }
    case Type::F64: {
      double* a = static_cast<double*>(buffers[1]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m,
                                            static_cast<double*>(workspace),
                                            ipiv, info));
        a += d.m * d.n;
        ipiv += std::min(d.m, d.n);
        ++info;
      }
      break;
    }
    case Type::C64: {
      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m,
                                            static_cast<cuComplex*>(workspace),
                                            ipiv, info));
        a += d.m * d.n;
        ipiv += std::min(d.m, d.n);
        ++info;
      }
      break;
    }
    case Type::C128: {
      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnZgetrf(
            handle.get(), d.m, d.n, a, d.m,
            static_cast<cuDoubleComplex*>(workspace), ipiv, info));
        a += d.m * d.n;
        ipiv += std::min(d.m, d.n);
        ++info;
      }
      break;
    }
  }
}

// Symmetric (Hermitian) eigendecomposition: syevd/heevd

struct SyevdDescriptor {
  Type type;
  cublasFillMode_t uplo;
  int batch, n;
  int lwork;
};

// Returns the workspace size and a descriptor for a syevd operation.
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
                                               bool lower, int b, int n) {
  Type type = DtypeToType(dtype);
  auto handle = SolverHandlePool::Borrow();
  int lwork;
  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
  cublasFillMode_t uplo =
      lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
  switch (type) {
    case Type::F32:
      ThrowIfErrorStatus(cusolverDnSsyevd_bufferSize(
          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
          &lwork));
      break;
    case Type::F64:
      ThrowIfErrorStatus(cusolverDnDsyevd_bufferSize(
          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
          &lwork));
      break;
    case Type::C64:
      ThrowIfErrorStatus(cusolverDnCheevd_bufferSize(
          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
          &lwork));
      break;
    case Type::C128:
      ThrowIfErrorStatus(cusolverDnZheevd_bufferSize(
          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
          &lwork));
      break;
  }
  return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
}

void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len) {
  const SyevdDescriptor& d =
      *UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
  auto handle = SolverHandlePool::Borrow(stream);
  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
                               SizeOfType(d.type) * d.batch * d.n * d.n,
                               cudaMemcpyDeviceToDevice, stream));
  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
  int* info = static_cast<int*>(buffers[3]);
  void* work = buffers[4];
  switch (d.type) {
    case Type::F32: {
      float* a = static_cast<float*>(buffers[1]);
      float* w = static_cast<float*>(buffers[2]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<float*>(work),
                                            d.lwork, info));
        a += d.n * d.n;
        w += d.n;
        ++info;
      }
      break;
    }
    case Type::F64: {
      double* a = static_cast<double*>(buffers[1]);
      double* w = static_cast<double*>(buffers[2]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<double*>(work),
                                            d.lwork, info));
        a += d.n * d.n;
        w += d.n;
        ++info;
      }
      break;
    }
    case Type::C64: {
      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
      float* w = static_cast<float*>(buffers[2]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(
            cusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
                             static_cast<cuComplex*>(work), d.lwork, info));
        a += d.n * d.n;
        w += d.n;
        ++info;
      }
      break;
    }
    case Type::C128: {
      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
      double* w = static_cast<double*>(buffers[2]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnZheevd(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuDoubleComplex*>(work), d.lwork, info));
        a += d.n * d.n;
        w += d.n;
        ++info;
      }
      break;
    }
  }
}

// Singular value decomposition: gesvd

struct GesvdDescriptor {
  Type type;
  int batch, m, n;
  int lwork;
  signed char jobu, jobvt;
};

// Returns the workspace size and a descriptor for a gesvd operation.
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
                                               int m, int n, bool compute_uv,
                                               bool full_matrices) {
  Type type = DtypeToType(dtype);
  auto handle = SolverHandlePool::Borrow();
  int lwork;
  switch (type) {
    case Type::F32:
      ThrowIfErrorStatus(
          cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork));
      break;
    case Type::F64:
      ThrowIfErrorStatus(
          cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork));
      break;
    case Type::C64:
      ThrowIfErrorStatus(
          cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork));
      break;
    case Type::C128:
      ThrowIfErrorStatus(
          cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork));
      break;
  }
  signed char jobu, jobvt;
  if (compute_uv) {
    if (full_matrices) {
      jobu = jobvt = 'A';
    } else {
      jobu = jobvt = 'S';
    }
  } else {
    jobu = jobvt = 'N';
  }
  return {lwork,
          PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
}

// TODO(phawkins): in the batched case, we should consider using the batched
// Jacobi implementation instead.
void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len) {
  const GesvdDescriptor& d =
      *UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
  auto handle = SolverHandlePool::Borrow(stream);
  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
                               SizeOfType(d.type) * d.batch * d.m * d.n,
                               cudaMemcpyDeviceToDevice, stream));
  int* info = static_cast<int*>(buffers[5]);
  void* work = buffers[6];
  switch (d.type) {
    case Type::F32: {
      float* a = static_cast<float*>(buffers[1]);
      float* s = static_cast<float*>(buffers[2]);
      float* u = static_cast<float*>(buffers[3]);
      float* vt = static_cast<float*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnSgesvd(handle.get(), d.jobu, d.jobvt, d.m,
                                            d.n, a, d.m, s, u, d.m, vt, d.n,
                                            static_cast<float*>(work), d.lwork,
                                            /*rwork=*/nullptr, info));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
        u += d.m * d.m;
        vt += d.n * d.n;
        ++info;
      }
      break;
    }
    case Type::F64: {
      double* a = static_cast<double*>(buffers[1]);
      double* s = static_cast<double*>(buffers[2]);
      double* u = static_cast<double*>(buffers[3]);
      double* vt = static_cast<double*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnDgesvd(handle.get(), d.jobu, d.jobvt, d.m,
                                            d.n, a, d.m, s, u, d.m, vt, d.n,
                                            static_cast<double*>(work), d.lwork,
                                            /*rwork=*/nullptr, info));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
        u += d.m * d.m;
        vt += d.n * d.n;
        ++info;
      }
      break;
    }
    case Type::C64: {
      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
      float* s = static_cast<float*>(buffers[2]);
      cuComplex* u = static_cast<cuComplex*>(buffers[3]);
      cuComplex* vt = static_cast<cuComplex*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(
            cusolverDnCgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
                             u, d.m, vt, d.n, static_cast<cuComplex*>(work),
                             d.lwork, /*rwork=*/nullptr, info));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
        u += d.m * d.m;
        vt += d.n * d.n;
        ++info;
      }
      break;
    }
    case Type::C128: {
      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
      double* s = static_cast<double*>(buffers[2]);
      cuDoubleComplex* u = static_cast<cuDoubleComplex*>(buffers[3]);
      cuDoubleComplex* vt = static_cast<cuDoubleComplex*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnZgesvd(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
            static_cast<cuDoubleComplex*>(work), d.lwork,
            /*rwork=*/nullptr, info));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
        u += d.m * d.m;
        vt += d.n * d.n;
        ++info;
      }
      break;
    }
  }
}

template <typename T>
py::capsule EncapsulateFunction(T* fn) {
  return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}

py::dict Registrations() {
  py::dict dict;
  dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
  dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
  dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
  return dict;
}

PYBIND11_MODULE(cusolver_kernels, m) {
  m.def("registrations", &Registrations);
  m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
  m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
  m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
}

}  // namespace
}  // namespace jax
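PackDescriptor/UnpackDescriptor above simply bit-cast a C struct to bytes; the Python side (jaxlib/cusolver.py, next) threads the result through as the CustomCall `opaque` argument. A rough Python analogue of the idea, for illustration only: the real bytes follow the C++ struct layout, padding included, so this sketch is not byte-compatible with the kernels.

import struct

# Hypothetical stand-in for GetrfDescriptor{type, batch, m, n}: four C ints.
def pack_descriptor(type_enum, batch, m, n):
  return struct.pack("iiii", type_enum, batch, m, n)

def unpack_descriptor(opaque):
  if len(opaque) != struct.calcsize("iiii"):
    raise ValueError("Invalid size for linalg operation descriptor.")
  return struct.unpack("iiii", opaque)

print(unpack_descriptor(pack_descriptor(0, 1, 128, 128)))  # (0, 1, 128, 128)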
jaxlib/cusolver.py (new file, 182 lines)
@@ -0,0 +1,182 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from jaxlib import xla_client

try:
  from jaxlib import cusolver_kernels
  for _name, _value in cusolver_kernels.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
  pass

_Shape = xla_client.Shape


def _real_type(dtype):
  """Returns the real equivalent of 'dtype'."""
  if dtype == np.float32:
    return np.float32
  elif dtype == np.float64:
    return np.float64
  elif dtype == np.complex64:
    return np.float32
  elif dtype == np.complex128:
    return np.float64
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))


def getrf(c, a):
  """LU decomposition."""
  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  lwork, opaque = cusolver_kernels.build_getrf_descriptor(
      np.dtype(dtype), b, m, n)
  out = c.CustomCall(
      b"cusolver_getrf",
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(
              dtype, batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
          _Shape.array_shape(dtype, (lwork,), (0,)),
          _Shape.array_shape(
              np.dtype(np.int32), batch_dims + (min(m, n),),
              tuple(range(num_bd, -1, -1))),
          _Shape.array_shape(
              np.dtype(np.int32), batch_dims, tuple(range(num_bd - 1, -1, -1))),
      )),
      operand_shapes_with_layout=(_Shape.array_shape(
          dtype, batch_dims + (m, n),
          (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
      opaque=opaque)
  return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 2),
          c.GetTupleElement(out, 3))


def syevd(c, a, lower=False):
  """Symmetric (Hermitian) eigendecomposition."""

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))

  lwork, opaque = cusolver_kernels.build_syevd_descriptor(
      np.dtype(dtype), lower, b, n)
  eigvals_type = _real_type(dtype)

  out = c.CustomCall(
      b"cusolver_syevd",
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(dtype, dims, layout),
          _Shape.array_shape(
              np.dtype(eigvals_type), batch_dims + (n,),
              tuple(range(num_bd, -1, -1))),
          _Shape.array_shape(
              np.dtype(np.int32), batch_dims,
              tuple(range(num_bd - 1, -1, -1))),
          _Shape.array_shape(dtype, (lwork,), (0,))
      )),
      operand_shapes_with_layout=(
          _Shape.array_shape(dtype, dims, layout),
      ),
      opaque=opaque)
  return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
          c.GetTupleElement(out, 2))


def gesvd(c, a, full_matrices=True, compute_uv=True):
  """Singular value decomposition."""

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  b = 1
  m, n = a_shape.dimensions()
  singular_vals_dtype = _real_type(dtype)

  if m < n:
    lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
        np.dtype(dtype), b, n, m, compute_uv, full_matrices)
    out = c.CustomCall(
        b"cusolver_gesvd",
        operands=(a,),
        shape_with_layout=_Shape.tuple_shape((
            _Shape.array_shape(dtype, (m, n), (1, 0)),
            _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
            _Shape.array_shape(dtype, (n, n), (1, 0)),
            _Shape.array_shape(dtype, (m, m), (1, 0)),
            _Shape.array_shape(np.dtype(np.int32), (), ()),
            _Shape.array_shape(dtype, (lwork,), (0,)),
        )),
        operand_shapes_with_layout=(
            _Shape.array_shape(dtype, (m, n), (1, 0)),
        ),
        opaque=opaque)
    s = c.GetTupleElement(out, 1)
    vt = c.GetTupleElement(out, 2)
    u = c.GetTupleElement(out, 3)
    info = c.GetTupleElement(out, 4)
  else:
    lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
        np.dtype(dtype), b, m, n, compute_uv, full_matrices)

    out = c.CustomCall(
        b"cusolver_gesvd",
        operands=(a,),
        shape_with_layout=_Shape.tuple_shape((
            _Shape.array_shape(dtype, (m, n), (0, 1)),
            _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
            _Shape.array_shape(dtype, (m, m), (0, 1)),
            _Shape.array_shape(dtype, (n, n), (0, 1)),
            _Shape.array_shape(np.dtype(np.int32), (), ()),
            _Shape.array_shape(dtype, (lwork,), (0,)),
        )),
        operand_shapes_with_layout=(
            _Shape.array_shape(dtype, (m, n), (0, 1)),
        ),
        opaque=opaque)
    s = c.GetTupleElement(out, 1)
    u = c.GetTupleElement(out, 2)
    vt = c.GetTupleElement(out, 3)
    info = c.GetTupleElement(out, 4)
  if not full_matrices:
    u = c.Slice(u, (0, 0), (m, min(m, n)))
    vt = c.Slice(vt, (0, 0), (min(m, n), n))
  return s, u, vt, info
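A note on the m < n branch of gesvd above: cuSolver's dense gesvd requires m >= n, so the wrapper appears to hand the kernel the operand in a flipped layout (note the swapped m and n in the descriptor and the (1, 0) minor-to-major order) and then swaps the roles of u and vt in the outputs. This rests on the standard transpose identity for the SVD:

  A^T = U \Sigma V^T  \implies  A = V \Sigma U^T

so the left and right singular vectors simply trade places. This reading is an interpretation of the code, not a claim made by the commit message.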
@@ -166,7 +166,7 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
 register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
 
 
-def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
+def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
              conj_a=False, diag=False):
   b_shape = c.GetShape(b)
   dtype = b_shape.element_type()
@@ -214,7 +214,7 @@ def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
           Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)),
           Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
       ))
 
+jax_trsm = trsm
 
 # ?getrf: LU decomposition
@@ -305,7 +305,7 @@ cdef void lapack_zgetrf(void* out_tuple, void** data) nogil:
 
 register_cpu_custom_call_target(b"lapack_zgetrf", <void*>(lapack_zgetrf))
 
-def jax_getrf(c, a):
+def getrf(c, a):
   assert sizeof(int32_t) == sizeof(int)
 
   a_shape = c.GetShape(a)
@@ -330,7 +330,7 @@ def jax_getrf(c, a):
   else:
     raise NotImplementedError("Unsupported dtype {}".format(dtype))
 
-  return c.CustomCall(
+  out = c.CustomCall(
       fn,
       operands=(
           c.ConstantS32Scalar(b),
@@ -358,8 +358,10 @@ def jax_getrf(c, a):
           batch_dims + (m, n),
           (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
       ))
+  return tuple(c.GetTupleElement(out, i) for i in range(3))
+
+
+def jax_getrf(c, a):
+  return c.Tuple(*getrf(c, a))
 
 # ?potrf: Cholesky decomposition
@@ -429,7 +431,7 @@ cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
 
 register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))
 
-def jax_potrf(c, a, lower=False):
+def potrf(c, a, lower=False):
   assert sizeof(int32_t) == sizeof(int)
 
   a_shape = c.GetShape(a)
@@ -448,7 +450,7 @@ def jax_potrf(c, a, lower=False):
   else:
     raise NotImplementedError("Unsupported dtype {}".format(dtype))
 
-  return c.CustomCall(
+  out = c.CustomCall(
      fn,
      operands=(c.ConstantS32Scalar(int(lower)), c.ConstantS32Scalar(n), a),
      shape_with_layout=Shape.tuple_shape((
@@ -460,6 +462,10 @@ def jax_potrf(c, a, lower=False):
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, (n, n), (0, 1)),
      ))
+  return tuple(c.GetTupleElement(out, i) for i in range(2))
+
+def jax_potrf(c, a, lower=False):
+  return c.Tuple(*potrf(c, a, lower))
 
 
 # ?gesdd: Singular value decomposition
@@ -555,7 +561,7 @@ cdef void lapack_dgesdd(void* out_tuple, void** data) nogil:
     ldvt = min(m, n)
 
   # First perform a workspace query to get the optimal lwork
-  # NB: We perform a workspace query with malloc and free for the work array,
+  # NB: We perform a workspace query with malloc and free for the work array,
   # because it is officially recommended in the LAPACK documentation
   cdef double wkopt = 0
   cdef int lwork = -1
@@ -667,7 +673,7 @@ cdef void lapack_zgesdd(void* out_tuple, void** data) nogil:
 
 register_cpu_custom_call_target(b"lapack_zgesdd", <void*>(lapack_zgesdd))
 
-def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
+def gesdd(c, a, full_matrices=True, compute_uv=True):
   assert sizeof(int32_t) == sizeof(int)
 
   a_shape = c.GetShape(a)
@@ -720,8 +726,11 @@ def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, (m, n), (0, 1)),
      ))
-  return c.Tuple(c.GetTupleElement(out, 1), c.GetTupleElement(out, 2),
-                 c.GetTupleElement(out, 3), c.GetTupleElement(out, 4))
+  return (c.GetTupleElement(out, 1), c.GetTupleElement(out, 2),
+          c.GetTupleElement(out, 3), c.GetTupleElement(out, 4))
+
+def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
+  return c.Tuple(*gesdd(c, a, full_matrices, compute_uv))
 
 
 # syevd: Symmetric eigendecomposition
@@ -861,7 +870,7 @@ cdef void lapack_zheevd(void* out_tuple, void** data) nogil:
 
 register_cpu_custom_call_target(b"lapack_zheevd", <void*>(lapack_zheevd))
 
-def jax_syevd(c, a, lower=False):
+def syevd(c, a, lower=False):
   assert sizeof(int32_t) == sizeof(int)
 
   a_shape = c.GetShape(a)
@@ -928,8 +937,11 @@ def jax_syevd(c, a, lower=False):
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, dims, layout),
      ))
-  return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
-                 c.GetTupleElement(out, 2))
+  return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
+          c.GetTupleElement(out, 2))
+
+def jax_syevd(c, a, lower=False):
+  return c.Tuple(*syevd(c, a, lower))
 
 
 # geev: Nonsymmetric eigendecomposition
@@ -1140,7 +1152,7 @@ register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
 
 
 
-def jax_geev(c, a):
+def geev(c, a):
   assert sizeof(int32_t) == sizeof(int)
 
   a_shape = c.GetShape(a)
|
||||
Shape.array_shape(dtype, dims, layout),
|
||||
))
|
||||
if real:
|
||||
return c.Tuple(
|
||||
c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)),
|
||||
c.GetTupleElement(out, 5), c.GetTupleElement(out, 6),
|
||||
c.GetTupleElement(out, 7))
|
||||
return (c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)),
|
||||
c.GetTupleElement(out, 5), c.GetTupleElement(out, 6),
|
||||
c.GetTupleElement(out, 7))
|
||||
else:
|
||||
return c.Tuple(
|
||||
c.GetTupleElement(out, 2), c.GetTupleElement(out, 3),
|
||||
c.GetTupleElement(out, 4), c.GetTupleElement(out, 5))
|
||||
return (c.GetTupleElement(out, 2), c.GetTupleElement(out, 3),
|
||||
c.GetTupleElement(out, 4), c.GetTupleElement(out, 5))
|
||||
|
||||
def jax_geev(c, a):
|
||||
return c.Tuple(*geev(c, a))
|
||||
|
@@ -163,8 +163,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
       for lower in [False, True]
       for rng in [jtu.rand_default()]))
   # TODO(phawkins): enable when there is an eigendecomposition implementation
-  # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  # for TPU.
+  @jtu.skip_on_devices("tpu")
   def testEigh(self, n, dtype, lower, rng):
     _skip_if_unsupported_type(dtype)
     args_maker = lambda: [rng((n, n), dtype)]
@@ -196,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for rng in [jtu.rand_default()]
      for lower in [True, False]))
   # TODO(phawkins): enable when there is an eigendecomposition implementation
-  # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  # for TPU.
+  @jtu.skip_on_devices("tpu")
   def testEighGrad(self, shape, dtype, rng, lower):
     self.skipTest("Test fails with numeric errors.")
     uplo = "L" if lower else "U"
@@ -224,8 +224,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for lower in [True, False]
      for eps in [1e-4]))
   # TODO(phawkins): enable when there is an eigendecomposition implementation
-  # for GPU/TPU.
-  @jtu.skip_on_devices("gpu", "tpu")
+  # for TPU.
+  @jtu.skip_on_devices("tpu")
   def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
     _skip_if_unsupported_type(dtype)
     # Special case to test for complex eigenvector grad correctness.
@@ -263,7 +263,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.skip_on_devices("tpu")
   def testEighBatching(self, shape, dtype, rng):
     _skip_if_unsupported_type(dtype)
     shape = (10,) + shape
@@ -318,7 +318,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      for full_matrices in [False, True]
      for compute_uv in [False, True]
      for rng in [jtu.rand_default()]))
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.skip_on_devices("tpu")
   def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
     _skip_if_unsupported_type(dtype)
     args_maker = lambda: [rng((m, n), dtype)]
@@ -414,7 +414,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     if not full_matrices and m >= n:
       jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))
 
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.skip_on_devices("tpu")
   def testQrBatching(self):
     shape = (10, 4, 5)
     dtype = np.float32
@@ -476,7 +476,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
 
   # Regression test for incorrect type for eigenvalues of a complex matrix.
-  @jtu.skip_on_devices("gpu", "tpu")
+  @jtu.skip_on_devices("tpu")
   def testIssue669(self):
     def test(x):
       val, vec = np.linalg.eigh(x)
@@ -499,9 +499,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   def testLu(self, shape, dtype, rng):
     _skip_if_unsupported_type(dtype)
     args_maker = lambda: [rng(shape, dtype)]
 
-    self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
-                            check_dtypes=True, tol=1e-3)
+    x, = args_maker()
+    p, l, u = jsp.linalg.lu(x)
+    self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True)
     self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)
 
   # TODO(phawkins): figure out why this test fails on Travis and reenable.
@@ -555,8 +555,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
     _skip_if_unsupported_type(dtype)
     args_maker = lambda: [rng((n, n), dtype)]
 
-    self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
-                            args_maker, check_dtypes=True, tol=1e-3)
+    x, = args_maker()
+    lu, piv = jsp.linalg.lu_factor(x)
+    l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
+    u = onp.triu(lu)
+    for i in range(n):
+      x[[i, piv[i]],] = x[[piv[i], i],]
+    self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
+    # self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
+    #                         args_maker, check_dtypes=True, tol=1e-3)
     self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
 
   @parameterized.named_parameters(jtu.cases_from_list(