Add support for linear algebra ops on GPU using cuSolver:

* LU decomposition
* Symmetric (Hermitian) eigendecomposition
* Singular value decomposition

Make the LU decomposition tests less sensitive to the exact factorization: check that the result is a valid decomposition rather than precisely the one SciPy returns.
Author: Peter Hawkins, 2019-08-02 11:16:15 -04:00
Commit ed3e2308c1 (parent 7c060435bb)
9 changed files with 954 additions and 68 deletions
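
For orientation, here is a minimal sketch (not part of the diff) of how the new kernels are reached from the public API, assuming a CUDA-enabled jaxlib build that includes the cusolver_kernels extension; the matrix contents and sizes are arbitrary:

import numpy as onp
import jax.numpy as np
import jax.scipy.linalg as jsp_linalg

a = onp.random.randn(128, 128).astype(onp.float32)

# LU decomposition; on GPU this now lowers to the cusolver getrf kernel.
p, l, u = jsp_linalg.lu(a)

# Symmetric (Hermitian) eigendecomposition; lowers to cusolver syevd/heevd.
w, v = np.linalg.eigh(a + a.T)

# Singular value decomposition; lowers to cusolver gesvd.
u2, s, vt = np.linalg.svd(a, full_matrices=False)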


@ -14,6 +14,11 @@
# JAX is Autograd and XLA
load(
"@org_tensorflow//tensorflow/core:platform/default/cuda_build_defs.bzl",
"if_cuda_is_configured",
)
licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])
@ -26,7 +31,9 @@ sh_binary(
"//jaxlib",
"//jaxlib:lapack.so",
"//jaxlib:pytree",
],
] + if_cuda_is_configured([
"//jaxlib:cusolver_kernels",
]),
deps = ["@bazel_tools//tools/bash/runfiles"],
)


@ -54,7 +54,11 @@ fi
# new location.
cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib"
if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then
cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib"
fi
cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib"
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \
"${TARGET}/jaxlib"
sed \


@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial
import numpy as onp
from jax.numpy import lax_numpy as np
@ -34,6 +35,7 @@ from jax.core import Primitive
from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule,
_float, _complex, _input_dtype, _broadcasting_select)
from jax.lib import lapack
from jax.lib import cusolver
# traceables
@ -81,6 +83,11 @@ def _T(x): return np.swapaxes(x, -1, -2)
def _H(x): return np.conj(_T(x))
def symmetrize(x): return (x + _H(x)) / 2
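# Compatibility shim: wraps an older jaxlib builder `f` that returns a single
# XLA tuple so that it yields the tuple's `n` elements instead, matching the
# newer jaxlib functions used below.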
def _unpack_tuple(f, n):
def g(c, *args, **kwargs):
t = f(c, *args, **kwargs)
return (c.GetTupleElement(t, i) for i in range(n))
return g
# primitives
@ -123,13 +130,18 @@ def _nan_like(c, operand):
nan = c.Constant(onp.array(onp.nan, dtype=dtype))
return c.Broadcast(nan, shape.dimensions())
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "potrf"):
_cpu_potrf = lapack.potrf
else:
_cpu_potrf = _unpack_tuple(lapack.jax_potrf, 2)
def cholesky_cpu_translation_rule(c, operand):
shape = c.GetShape(operand)
dtype = shape.element_type().type
if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
potrf_output = lapack.jax_potrf(c, operand, lower=True)
result = c.GetTupleElement(potrf_output, 0)
info = c.GetTupleElement(potrf_output, 1)
result, info = _cpu_potrf(c, operand, lower=True)
return c.Select(c.Eq(info, c.ConstantS32Scalar(0)), result,
_nan_like(c, result))
else:
@ -163,15 +175,18 @@ def eig_abstract_eval(operand):
w = vl = vr = operand
return core.AbstractTuple((w, vl, vr))
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "geev"):
_cpu_geev = lapack.geev
else:
_cpu_geev = _unpack_tuple(lapack.jax_geev, 4)
def eig_cpu_translation_rule(c, operand):
shape = c.GetShape(operand)
batch_dims = shape.dimensions()[:-2]
geev_out = lapack.jax_geev(c, operand)
w = c.GetTupleElement(geev_out, 0)
vl = c.GetTupleElement(geev_out, 1)
vr = c.GetTupleElement(geev_out, 2)
ok = c.Eq(c.GetTupleElement(geev_out, 3), c.ConstantS32Scalar(0))
w, vl, vr, info = _cpu_geev(c, operand)
ok = c.Eq(info, c.ConstantS32Scalar(0))
w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
_nan_like(c, w))
vl = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vl,
@ -219,13 +234,11 @@ def eigh_abstract_eval(operand, lower):
v, w = operand, operand
return core.AbstractTuple((v, w))
def eigh_cpu_translation_rule(c, operand, lower):
def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
shape = c.GetShape(operand)
batch_dims = shape.dimensions()[:-2]
syevd_out = lapack.jax_syevd(c, operand, lower=lower)
v = c.GetTupleElement(syevd_out, 0)
w = c.GetTupleElement(syevd_out, 1)
ok = c.Eq(c.GetTupleElement(syevd_out, 2), c.ConstantS32Scalar(0))
v, w, info = syevd_impl(c, operand, lower=lower)
ok = c.Eq(info, c.ConstantS32Scalar(0))
v = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), v,
_nan_like(c, v))
w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w,
@ -267,7 +280,22 @@ eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
xla.backend_specific_translations['cpu'][eigh_p] = eigh_cpu_translation_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "syevd"):
_cpu_syevd = lapack.syevd
else:
_cpu_syevd = _unpack_tuple(lapack.jax_syevd, 3)
xla.backend_specific_translations['cpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, _cpu_syevd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, cusolver.syevd)
batching.primitive_batchers[eigh_p] = eigh_batching_rule
@ -522,14 +550,13 @@ def _lu_batching_rule(batched_args, batch_dims):
x = batching.bdim_at_front(x, bd)
return lu_p.bind(x), 0
def _lu_cpu_translation_rule(c, operand):
def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
shape = c.GetShape(operand)
batch_dims = shape.dimensions()[:-2]
getrf_out = lapack.jax_getrf(c, operand)
lu = c.GetTupleElement(getrf_out, 0)
lu, pivot, info = getrf_impl(c, operand)
# Subtract 1 from the pivot to get 0-based indices.
pivot = c.Sub(c.GetTupleElement(getrf_out, 1), c.ConstantS32Scalar(1))
ok = c.Eq(c.GetTupleElement(getrf_out, 2), c.ConstantS32Scalar(0))
pivot = c.Sub(pivot, c.ConstantS32Scalar(1))
ok = c.Eq(info, c.ConstantS32Scalar(0))
lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu,
_nan_like(c, lu))
return c.Tuple(lu, pivot)
@ -541,7 +568,20 @@ lu_p.def_abstract_eval(_lu_abstract_eval)
xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True)
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule
xla.backend_specific_translations['cpu'][lu_p] = _lu_cpu_translation_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "getrf"):
_cpu_getrf = lapack.getrf
else:
_cpu_getrf = _unpack_tuple(lapack.jax_getrf, 3)
xla.backend_specific_translations['cpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, _cpu_getrf)
if cusolver:
xla.backend_specific_translations['gpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, cusolver.getrf)
def lu_pivots_to_permutation(swaps, m):
@ -681,16 +721,13 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim
return core.pack((s, U, Vt)), core.pack((ds, dU, dV.T))
def svd_cpu_translation_rule(c, operand, full_matrices, compute_uv):
def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv):
shape = c.GetShape(operand)
dtype = shape.element_type().type
if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types:
gesdd_out = lapack.jax_gesdd(c, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
s = c.GetTupleElement(gesdd_out, 0)
u = c.GetTupleElement(gesdd_out, 1)
vt = c.GetTupleElement(gesdd_out, 2)
ok = c.Eq(c.GetTupleElement(gesdd_out, 3), c.ConstantS32Scalar(0))
s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
ok = c.Eq(info, c.ConstantS32Scalar(0))
s = _broadcasting_select(c, c.Reshape(ok, None, (1,)), s,
_nan_like(c, s))
u = _broadcasting_select(c, c.Reshape(ok, None, (1, 1)), u,
@ -711,7 +748,22 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
svd_p = Primitive('svd')
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
xla.translations[svd_p] = svd_translation_rule
xla.backend_specific_translations['cpu'][svd_p] = svd_cpu_translation_rule
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
xla.translations[svd_p] = svd_translation_rule
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if hasattr(lapack, "gesdd"):
_cpu_gesdd = lapack.gesdd
else:
_cpu_gesdd = _unpack_tuple(lapack.jax_gesdd, 4)
xla.backend_specific_translations['cpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, _cpu_gesdd)
# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to
# 0.1.23.
if cusolver:
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, cusolver.gesvd)


@ -47,3 +47,10 @@ try:
from jaxlib import pytree
except ImportError:
pytree = None
# TODO(phawkins): make the import unconditional when the minimum Jaxlib version
# has been increased to 0.1.23.
try:
from jaxlib import cusolver
except ImportError:
cusolver = None


@ -29,7 +29,10 @@ pyx_library(
py_library(
name = "jaxlib",
srcs = ["version.py"],
srcs = [
"cusolver.py",
"version.py"
],
)
tf_pybind_extension(
@ -52,3 +55,29 @@ tf_pybind_extension(
"@pybind11",
],
)
tf_pybind_extension(
name = "cusolver_kernels",
srcs = ["cusolver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
"-Wno-c++98-c++11-compat",
],
features = ["-use_header_modules"],
module_name = "cusolver_kernels",
deps = [
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_cuda//cuda:cudart",
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cusolver",
"@pybind11",
],
)

jaxlib/cusolver.cc (new file, 585 lines)

@ -0,0 +1,585 @@
/* Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace jax {
namespace {
namespace py = pybind11;
void ThrowIfError(cudaError_t error) {
if (error != cudaSuccess) {
throw std::runtime_error("CUDA operation failed");
}
}
void ThrowIfErrorStatus(cusolverStatus_t status) {
switch (status) {
case CUSOLVER_STATUS_SUCCESS:
return;
case CUSOLVER_STATUS_NOT_INITIALIZED:
throw std::runtime_error("cuSolver has not been initialized");
case CUSOLVER_STATUS_ALLOC_FAILED:
throw std::runtime_error("cuSolver allocation failed");
case CUSOLVER_STATUS_INVALID_VALUE:
throw std::runtime_error("cuSolver invalid value error");
case CUSOLVER_STATUS_ARCH_MISMATCH:
throw std::runtime_error("cuSolver architecture mismatch error");
case CUSOLVER_STATUS_MAPPING_ERROR:
throw std::runtime_error("cuSolver mapping error");
case CUSOLVER_STATUS_EXECUTION_FAILED:
throw std::runtime_error("cuSolver execution failed");
case CUSOLVER_STATUS_INTERNAL_ERROR:
throw std::runtime_error("cuSolver internal error");
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
throw std::invalid_argument("cuSolver matrix type not supported error");
case CUSOLVER_STATUS_NOT_SUPPORTED:
throw std::runtime_error("cuSolver not supported error");
case CUSOLVER_STATUS_ZERO_PIVOT:
throw std::runtime_error("cuSolver zero pivot error");
case CUSOLVER_STATUS_INVALID_LICENSE:
throw std::runtime_error("cuSolver invalid license error");
default:
throw std::runtime_error("Unknown cuSolver error");
}
}
// To avoid creating cusolver contexts in the middle of execution, we maintain
// a pool of them.
class SolverHandlePool {
public:
SolverHandlePool() = default;
// RAII class representing a cusolver handle borrowed from the pool. Returns
// the handle to the pool on destruction.
class Handle {
public:
Handle() = default;
~Handle() {
if (pool_) {
pool_->Return(handle_);
}
}
Handle(Handle const&) = delete;
Handle(Handle&& other) {
pool_ = other.pool_;
handle_ = other.handle_;
other.pool_ = nullptr;
other.handle_ = nullptr;
}
Handle& operator=(Handle const&) = delete;
Handle& operator=(Handle&& other) {
pool_ = other.pool_;
handle_ = other.handle_;
other.pool_ = nullptr;
other.handle_ = nullptr;
return *this;
}
cusolverDnHandle_t get() { return handle_; }
private:
friend class SolverHandlePool;
Handle(SolverHandlePool* pool, cusolverDnHandle_t handle)
: pool_(pool), handle_(handle) {}
SolverHandlePool* pool_ = nullptr;
cusolverDnHandle_t handle_ = nullptr;
};
// Borrows a handle from the pool. If 'stream' is non-null, sets the stream
// associated with the handle.
static Handle Borrow(cudaStream_t stream = nullptr);
private:
static SolverHandlePool* Instance();
void Return(cusolverDnHandle_t handle);
absl::Mutex mu_;
std::vector<cusolverDnHandle_t> handles_ GUARDED_BY(mu_);
};
/*static*/ SolverHandlePool* SolverHandlePool::Instance() {
static auto* pool = new SolverHandlePool;
return pool;
}
/*static*/ SolverHandlePool::Handle SolverHandlePool::Borrow(
cudaStream_t stream) {
SolverHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusolverDnHandle_t handle;
if (pool->handles_.empty()) {
ThrowIfErrorStatus(cusolverDnCreate(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
}
if (stream) {
ThrowIfErrorStatus(cusolverDnSetStream(handle, stream));
}
return Handle(pool, handle);
}
void SolverHandlePool::Return(cusolverDnHandle_t handle) {
absl::MutexLock lock(&mu_);
handles_.push_back(handle);
}
// Set of types known to Cusolver.
enum class Type {
F32,
F64,
C64,
C128,
};
// Converts a NumPy dtype to a Type.
Type DtypeToType(const py::dtype& np_type) {
static auto* types = new absl::flat_hash_map<std::pair<char, int>, Type>({
{{'f', 4}, Type::F32},
{{'f', 8}, Type::F64},
{{'c', 8}, Type::C64},
{{'c', 16}, Type::C128},
});
auto it = types->find({np_type.kind(), np_type.itemsize()});
if (it == types->end()) {
throw std::invalid_argument(
absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
}
return it->second;
}
int SizeOfType(Type type) {
switch (type) {
case Type::F32:
return sizeof(float);
case Type::F64:
return sizeof(double);
case Type::C64:
return sizeof(cuComplex);
case Type::C128:
return sizeof(cuDoubleComplex);
}
}
// Descriptor objects are opaque host-side objects used to pass data from JAX
// to the custom kernel launched by XLA. Currently simply treat host-side
// structures as byte-strings; this is not portable across architectures. If
// portability is needed, we could switch to using a representation such as
// protocol buffers or flatbuffers.
// Packs a descriptor object into a py::bytes structure.
template <typename T>
py::bytes PackDescriptor(const T& descriptor) {
return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
}
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid size for linalg operation descriptor.");
}
return absl::bit_cast<const T*>(opaque);
}
// getrf: LU decomposition
struct GetrfDescriptor {
Type type;
int batch, m, n;
};
// Returns the workspace size and a descriptor for a getrf operation.
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
int m, int n) {
Type type = DtypeToType(dtype);
auto handle = SolverHandlePool::Borrow();
int lwork;
switch (type) {
case Type::F32:
ThrowIfErrorStatus(cusolverDnSgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork));
break;
case Type::F64:
ThrowIfErrorStatus(cusolverDnDgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork));
break;
case Type::C64:
ThrowIfErrorStatus(cusolverDnCgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork));
break;
case Type::C128:
ThrowIfErrorStatus(cusolverDnZgetrf_bufferSize(handle.get(), m, n,
/*A=*/nullptr,
/*lda=*/m, &lwork));
break;
}
return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
}
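// Kernel entry point for getrf. Buffer layout: buffers[0] = a (input),
// buffers[1] = LU output (a is copied here and factored in place),
// buffers[2] = workspace, buffers[3] = pivots, buffers[4] = info, with the
// factorization applied to each of the d.batch matrices in turn.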
void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GetrfDescriptor& d =
*UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
auto handle = SolverHandlePool::Borrow(stream);
ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream));
void* workspace = buffers[2];
int* ipiv = static_cast<int*>(buffers[3]);
int* info = static_cast<int*>(buffers[4]);
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<float*>(workspace),
ipiv, info));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<double*>(workspace),
ipiv, info));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case Type::C64: {
cuComplex* a = static_cast<cuComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m,
static_cast<cuComplex*>(workspace),
ipiv, info));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
case Type::C128: {
cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnZgetrf(
handle.get(), d.m, d.n, a, d.m,
static_cast<cuDoubleComplex*>(workspace), ipiv, info));
a += d.m * d.n;
ipiv += std::min(d.m, d.n);
++info;
}
break;
}
}
}
// Symmetric (Hermitian) eigendecomposition: syevd/heevd
struct SyevdDescriptor {
Type type;
cublasFillMode_t uplo;
int batch, n;
int lwork;
};
// Returns the workspace size and a descriptor for a syevd operation.
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
Type type = DtypeToType(dtype);
auto handle = SolverHandlePool::Borrow();
int lwork;
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
cublasFillMode_t uplo =
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
switch (type) {
case Type::F32:
ThrowIfErrorStatus(cusolverDnSsyevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork));
break;
case Type::F64:
ThrowIfErrorStatus(cusolverDnDsyevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork));
break;
case Type::C64:
ThrowIfErrorStatus(cusolverDnCheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork));
break;
case Type::C128:
ThrowIfErrorStatus(cusolverDnZheevd_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
&lwork));
break;
}
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
}
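// Kernel entry point for syevd/heevd. Buffer layout: buffers[0] = a (input),
// buffers[1] = eigenvectors (a is copied here and overwritten),
// buffers[2] = eigenvalues, buffers[3] = info, buffers[4] = workspace; each
// of the d.batch matrices is processed in turn.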
void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const SyevdDescriptor& d =
*UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
auto handle = SolverHandlePool::Borrow(stream);
ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.n * d.n,
cudaMemcpyDeviceToDevice, stream));
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a,
d.n, w, static_cast<float*>(work),
d.lwork, info));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a,
d.n, w, static_cast<double*>(work),
d.lwork, info));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case Type::C64: {
cuComplex* a = static_cast<cuComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
cusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<cuComplex*>(work), d.lwork, info));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
case Type::C128: {
cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnZheevd(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<cuDoubleComplex*>(work), d.lwork, info));
a += d.n * d.n;
w += d.n;
++info;
}
break;
}
}
}
// Singular value decomposition: gesvd
struct GesvdDescriptor {
Type type;
int batch, m, n;
int lwork;
signed char jobu, jobvt;
};
// Returns the workspace size and a descriptor for a gesvd operation.
std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
int m, int n, bool compute_uv,
bool full_matrices) {
Type type = DtypeToType(dtype);
auto handle = SolverHandlePool::Borrow();
int lwork;
switch (type) {
case Type::F32:
ThrowIfErrorStatus(
cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork));
break;
case Type::F64:
ThrowIfErrorStatus(
cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork));
break;
case Type::C64:
ThrowIfErrorStatus(
cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork));
break;
case Type::C128:
ThrowIfErrorStatus(
cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork));
break;
}
signed char jobu, jobvt;
if (compute_uv) {
if (full_matrices) {
jobu = jobvt = 'A';
} else {
jobu = jobvt = 'S';
}
} else {
jobu = jobvt = 'N';
}
return {lwork,
PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
}
// TODO(phawkins): in the batched case, we should consider using the batched
// Jacobi implementation instead.
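// Kernel entry point for gesvd. Buffer layout: buffers[0] = a (input),
// buffers[1] = scratch copy of a, buffers[2] = singular values,
// buffers[3] = u, buffers[4] = vt, buffers[5] = info, buffers[6] = workspace;
// each of the d.batch matrices is processed in turn.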
void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const GesvdDescriptor& d =
*UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
auto handle = SolverHandlePool::Borrow(stream);
ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
SizeOfType(d.type) * d.batch * d.m * d.n,
cudaMemcpyDeviceToDevice, stream));
int* info = static_cast<int*>(buffers[5]);
void* work = buffers[6];
switch (d.type) {
case Type::F32: {
float* a = static_cast<float*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
float* u = static_cast<float*>(buffers[3]);
float* vt = static_cast<float*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnSgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<float*>(work), d.lwork,
/*rwork=*/nullptr, info));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case Type::F64: {
double* a = static_cast<double*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
double* u = static_cast<double*>(buffers[3]);
double* vt = static_cast<double*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnDgesvd(handle.get(), d.jobu, d.jobvt, d.m,
d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<double*>(work), d.lwork,
/*rwork=*/nullptr, info));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case Type::C64: {
cuComplex* a = static_cast<cuComplex*>(buffers[1]);
float* s = static_cast<float*>(buffers[2]);
cuComplex* u = static_cast<cuComplex*>(buffers[3]);
cuComplex* vt = static_cast<cuComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(
cusolverDnCgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
u, d.m, vt, d.n, static_cast<cuComplex*>(work),
d.lwork, /*rwork=*/nullptr, info));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
case Type::C128: {
cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
double* s = static_cast<double*>(buffers[2]);
cuDoubleComplex* u = static_cast<cuDoubleComplex*>(buffers[3]);
cuDoubleComplex* vt = static_cast<cuDoubleComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
ThrowIfErrorStatus(cusolverDnZgesvd(
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
static_cast<cuDoubleComplex*>(work), d.lwork,
/*rwork=*/nullptr, info));
a += d.m * d.n;
s += std::min(d.m, d.n);
u += d.m * d.m;
vt += d.n * d.n;
++info;
}
break;
}
}
}
template <typename T>
py::capsule EncapsulateFunction(T* fn) {
return py::capsule(absl::bit_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}
py::dict Registrations() {
py::dict dict;
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
return dict;
}
PYBIND11_MODULE(cusolver_kernels, m) {
m.def("registrations", &Registrations);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
}
} // namespace
} // namespace jax

jaxlib/cusolver.py (new file, 182 lines)

@ -0,0 +1,182 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from jaxlib import xla_client
try:
from jaxlib import cusolver_kernels
for _name, _value in cusolver_kernels.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
except ImportError:
pass
_Shape = xla_client.Shape
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
if dtype == np.float32:
return np.float32
elif dtype == np.float64:
return np.float64
elif dtype == np.complex64:
return np.float32
elif dtype == np.complex128:
return np.float64
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
def getrf(c, a):
"""LU decomposition."""
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
lwork, opaque = cusolver_kernels.build_getrf_descriptor(
np.dtype(dtype), b, m, n)
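  # The custom call's result tuple mirrors the kernel's buffer layout:
  # (LU factors, scratch workspace, pivots, info), with the batch dimensions
  # flattened into the single count b. The workspace entry is dropped below.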
out = c.CustomCall(
b"cusolver_getrf",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(dtype, (lwork,), (0,)),
_Shape.array_shape(
np.dtype(np.int32), batch_dims + (min(m, n),),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(
np.dtype(np.int32), batch_dims, tuple(range(num_bd - 1, -1, -1))),
)),
operand_shapes_with_layout=(_Shape.array_shape(
dtype, batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 2),
c.GetTupleElement(out, 3))
def syevd(c, a, lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
lwork, opaque = cusolver_kernels.build_syevd_descriptor(
np.dtype(dtype), lower, b, n)
eigvals_type = _real_type(dtype)
out = c.CustomCall(
b"cusolver_syevd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, dims, layout),
_Shape.array_shape(
np.dtype(eigvals_type), batch_dims + (n,),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(
np.dtype(np.int32), batch_dims,
tuple(range(num_bd - 1, -1, -1))),
_Shape.array_shape(dtype, (lwork,), (0,))
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, dims, layout),
),
opaque=opaque)
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
def gesvd(c, a, full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
b = 1
m, n = a_shape.dimensions()
singular_vals_dtype = _real_type(dtype)
if m < n:
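    # cusolver's gesvd requires m >= n, so for wide matrices we factor the
    # transpose instead: m and n are swapped in the descriptor and transposed
    # layouts are used, which also swaps the roles of the u and vt results.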
lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
out = c.CustomCall(
b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, (m, n), (1, 0)),
_Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
_Shape.array_shape(dtype, (n, n), (1, 0)),
_Shape.array_shape(dtype, (m, m), (1, 0)),
_Shape.array_shape(np.dtype(np.int32), (), ()),
_Shape.array_shape(dtype, (lwork,), (0,)),
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, (m, n), (1, 0)),
),
opaque=opaque)
s = c.GetTupleElement(out, 1)
vt = c.GetTupleElement(out, 2)
u = c.GetTupleElement(out, 3)
info = c.GetTupleElement(out, 4)
else:
lwork, opaque = cusolver_kernels.build_gesvd_descriptor(
np.dtype(dtype), b, m, n, compute_uv, full_matrices)
out = c.CustomCall(
b"cusolver_gesvd",
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, (m, n), (0, 1)),
_Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)),
_Shape.array_shape(dtype, (m, m), (0, 1)),
_Shape.array_shape(dtype, (n, n), (0, 1)),
_Shape.array_shape(np.dtype(np.int32), (), ()),
_Shape.array_shape(dtype, (lwork,), (0,)),
)),
operand_shapes_with_layout=(
_Shape.array_shape(dtype, (m, n), (0, 1)),
),
opaque=opaque)
s = c.GetTupleElement(out, 1)
u = c.GetTupleElement(out, 2)
vt = c.GetTupleElement(out, 3)
info = c.GetTupleElement(out, 4)
if not full_matrices:
u = c.Slice(u, (0, 0), (m, min(m, n)))
vt = c.Slice(vt, (0, 0), (min(m, n), n))
return s, u, vt, info


@ -166,7 +166,7 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
b_shape = c.GetShape(b)
dtype = b_shape.element_type()
@ -214,7 +214,7 @@ def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)),
Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
))
jax_trsm = trsm
# ?getrf: LU decomposition
@ -305,7 +305,7 @@ cdef void lapack_zgetrf(void* out_tuple, void** data) nogil:
register_cpu_custom_call_target(b"lapack_zgetrf", <void*>(lapack_zgetrf))
def jax_getrf(c, a):
def getrf(c, a):
assert sizeof(int32_t) == sizeof(int)
a_shape = c.GetShape(a)
@ -330,7 +330,7 @@ def jax_getrf(c, a):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
return c.CustomCall(
out = c.CustomCall(
fn,
operands=(
c.ConstantS32Scalar(b),
@ -358,8 +358,10 @@ def jax_getrf(c, a):
batch_dims + (m, n),
(num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
))
return tuple(c.GetTupleElement(out, i) for i in range(3))
def jax_getrf(c, a):
return c.Tuple(*getrf(c, a))
# ?potrf: Cholesky decomposition
@ -429,7 +431,7 @@ cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))
def jax_potrf(c, a, lower=False):
def potrf(c, a, lower=False):
assert sizeof(int32_t) == sizeof(int)
a_shape = c.GetShape(a)
@ -448,7 +450,7 @@ def jax_potrf(c, a, lower=False):
else:
raise NotImplementedError("Unsupported dtype {}".format(dtype))
return c.CustomCall(
out = c.CustomCall(
fn,
operands=(c.ConstantS32Scalar(int(lower)), c.ConstantS32Scalar(n), a),
shape_with_layout=Shape.tuple_shape((
@ -460,6 +462,10 @@ def jax_potrf(c, a, lower=False):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, (n, n), (0, 1)),
))
return tuple(c.GetTupleElement(out, i) for i in range(2))
def jax_potrf(c, a, lower=False):
return c.Tuple(*potrf(c, a, lower))
# ?gesdd: Singular value decomposition
@ -555,7 +561,7 @@ cdef void lapack_dgesdd(void* out_tuple, void** data) nogil:
ldvt = min(m, n)
# First perform a workspace query to get the optimal lwork
# NB: We perform a workspace query with malloc and free for the work array,
# because it is officially recommended in the LAPACK documentation
cdef double wkopt = 0
cdef int lwork = -1
@ -667,7 +673,7 @@ cdef void lapack_zgesdd(void* out_tuple, void** data) nogil:
register_cpu_custom_call_target(b"lapack_zgesdd", <void*>(lapack_zgesdd))
def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
def gesdd(c, a, full_matrices=True, compute_uv=True):
assert sizeof(int32_t) == sizeof(int)
a_shape = c.GetShape(a)
@ -720,8 +726,11 @@ def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, (m, n), (0, 1)),
))
return c.Tuple(c.GetTupleElement(out, 1), c.GetTupleElement(out, 2),
c.GetTupleElement(out, 3), c.GetTupleElement(out, 4))
return (c.GetTupleElement(out, 1), c.GetTupleElement(out, 2),
c.GetTupleElement(out, 3), c.GetTupleElement(out, 4))
def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
return c.Tuple(*gesdd(c, a, full_matrices, compute_uv))
# syevd: Symmetric eigendecomposition
@ -861,7 +870,7 @@ cdef void lapack_zheevd(void* out_tuple, void** data) nogil:
register_cpu_custom_call_target(b"lapack_zheevd", <void*>(lapack_zheevd))
def jax_syevd(c, a, lower=False):
def syevd(c, a, lower=False):
assert sizeof(int32_t) == sizeof(int)
a_shape = c.GetShape(a)
@ -928,8 +937,11 @@ def jax_syevd(c, a, lower=False):
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, dims, layout),
))
return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1),
c.GetTupleElement(out, 2))
def jax_syevd(c, a, lower=False):
return c.Tuple(*syevd(c, a, lower))
# geev: Nonsymmetric eigendecomposition
@ -1140,7 +1152,7 @@ register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
def jax_geev(c, a):
def geev(c, a):
assert sizeof(int32_t) == sizeof(int)
a_shape = c.GetShape(a)
@ -1212,11 +1224,12 @@ def jax_geev(c, a):
Shape.array_shape(dtype, dims, layout),
))
if real:
return c.Tuple(
c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)),
c.GetTupleElement(out, 5), c.GetTupleElement(out, 6),
c.GetTupleElement(out, 7))
return (c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)),
c.GetTupleElement(out, 5), c.GetTupleElement(out, 6),
c.GetTupleElement(out, 7))
else:
return c.Tuple(
c.GetTupleElement(out, 2), c.GetTupleElement(out, 3),
c.GetTupleElement(out, 4), c.GetTupleElement(out, 5))
return (c.GetTupleElement(out, 2), c.GetTupleElement(out, 3),
c.GetTupleElement(out, 4), c.GetTupleElement(out, 5))
def jax_geev(c, a):
return c.Tuple(*geev(c, a))


@ -163,8 +163,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for lower in [False, True]
for rng in [jtu.rand_default()]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
# for TPU.
@jtu.skip_on_devices("tpu")
def testEigh(self, n, dtype, lower, rng):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
@ -196,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for rng in [jtu.rand_default()]
for lower in [True, False]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
# for TPU.
@jtu.skip_on_devices("tpu")
def testEighGrad(self, shape, dtype, rng, lower):
self.skipTest("Test fails with numeric errors.")
uplo = "L" if lower else "U"
@ -224,8 +224,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for lower in [True, False]
for eps in [1e-4]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
# for TPU.
@jtu.skip_on_devices("tpu")
def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
_skip_if_unsupported_type(dtype)
# Special case to test for complex eigenvector grad correctness.
@ -263,7 +263,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for shape in [(1, 1), (4, 4), (5, 5)]
for dtype in float_types + complex_types
for rng in [jtu.rand_default()]))
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_devices("tpu")
def testEighBatching(self, shape, dtype, rng):
_skip_if_unsupported_type(dtype)
shape = (10,) + shape
@ -318,7 +318,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
for full_matrices in [False, True]
for compute_uv in [False, True]
for rng in [jtu.rand_default()]))
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_devices("tpu")
def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((m, n), dtype)]
@ -414,7 +414,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
if not full_matrices and m >= n:
jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_devices("tpu")
def testQrBatching(self):
shape = (10, 4, 5)
dtype = np.float32
@ -476,7 +476,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
# Regression test for incorrect type for eigenvalues of a complex matrix.
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_devices("tpu")
def testIssue669(self):
def test(x):
val, vec = np.linalg.eigh(x)
@ -499,9 +499,9 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testLu(self, shape, dtype, rng):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
check_dtypes=True, tol=1e-3)
x, = args_maker()
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True)
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)
# TODO(phawkins): figure out why this test fails on Travis and reenable.
@ -555,8 +555,15 @@ class ScipyLinalgTest(jtu.JaxTestCase):
_skip_if_unsupported_type(dtype)
args_maker = lambda: [rng((n, n), dtype)]
self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
args_maker, check_dtypes=True, tol=1e-3)
x, = args_maker()
lu, piv = jsp.linalg.lu_factor(x)
l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
u = onp.triu(lu)
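    # piv follows the LAPACK convention: at step i, rows i and piv[i] were
    # interchanged. Applying the same swaps to x should give a matrix equal to
    # l @ u, which checks that we got a valid factorization without requiring
    # the exact pivoting SciPy happens to choose.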
for i in range(n):
x[[i, piv[i]],] = x[[piv[i], i],]
self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
# self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
# args_maker, check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)
@parameterized.named_parameters(jtu.cases_from_list(