diff --git a/build/BUILD.bazel b/build/BUILD.bazel index 7c21b62f8..f3f70e7dc 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -14,6 +14,11 @@ # JAX is Autograd and XLA +load( + "@org_tensorflow//tensorflow/core:platform/default/cuda_build_defs.bzl", + "if_cuda_is_configured", +) + licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) @@ -26,7 +31,9 @@ sh_binary( "//jaxlib", "//jaxlib:lapack.so", "//jaxlib:pytree", - ], + ] + if_cuda_is_configured([ + "//jaxlib:cusolver_kernels", + ]), deps = ["@bazel_tools//tools/bash/runfiles"], ) diff --git a/build/install_xla_in_source_tree.sh b/build/install_xla_in_source_tree.sh index e4b18fdc7..ed3938e7e 100755 --- a/build/install_xla_in_source_tree.sh +++ b/build/install_xla_in_source_tree.sh @@ -54,7 +54,11 @@ fi # new location. cp -f "$(rlocation __main__/jaxlib/lapack.so)" "${TARGET}/jaxlib" cp -f "$(rlocation __main__/jaxlib/pytree.so)" "${TARGET}/jaxlib" +if [[ -x "$(rlocation __main__/jaxlib/cusolver_kernels.so)" ]]; then + cp -f "$(rlocation __main__/jaxlib/cusolver_kernels.so)" "${TARGET}/jaxlib" +fi cp -f "$(rlocation __main__/jaxlib/version.py)" "${TARGET}/jaxlib" +cp -f "$(rlocation __main__/jaxlib/cusolver.py)" "${TARGET}/jaxlib" cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so)" \ "${TARGET}/jaxlib" sed \ diff --git a/jax/lax_linalg.py b/jax/lax_linalg.py index 60b3566d3..aa2cb6184 100644 --- a/jax/lax_linalg.py +++ b/jax/lax_linalg.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from functools import partial import numpy as onp from jax.numpy import lax_numpy as np @@ -34,6 +35,7 @@ from jax.core import Primitive from jax.lax import (standard_primitive, standard_unop, binop_dtype_rule, _float, _complex, _input_dtype, _broadcasting_select) from jax.lib import lapack +from jax.lib import cusolver # traceables @@ -81,6 +83,11 @@ def _T(x): return np.swapaxes(x, -1, -2) def _H(x): return np.conj(_T(x)) def symmetrize(x): return (x + _H(x)) / 2 +def _unpack_tuple(f, n): + def g(c, *args, **kwargs): + t = f(c, *args, **kwargs) + return (c.GetTupleElement(t, i) for i in range(n)) + return g # primitives @@ -123,13 +130,18 @@ def _nan_like(c, operand): nan = c.Constant(onp.array(onp.nan, dtype=dtype)) return c.Broadcast(nan, shape.dimensions()) +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. +if hasattr(lapack, "potrf"): + _cpu_potrf = lapack.potrf +else: + _cpu_potrf = _unpack_tuple(lapack.jax_potrf, 2) + def cholesky_cpu_translation_rule(c, operand): shape = c.GetShape(operand) dtype = shape.element_type().type if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types: - potrf_output = lapack.jax_potrf(c, operand, lower=True) - result = c.GetTupleElement(potrf_output, 0) - info = c.GetTupleElement(potrf_output, 1) + result, info = _cpu_potrf(c, operand, lower=True) return c.Select(c.Eq(info, c.ConstantS32Scalar(0)), result, _nan_like(c, result)) else: @@ -163,15 +175,18 @@ def eig_abstract_eval(operand): w = vl = vr = operand return core.AbstractTuple((w, vl, vr)) +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. 
+if hasattr(lapack, "geev"): + _cpu_geev = lapack.geev +else: + _cpu_geev = _unpack_tuple(lapack.jax_geev, 4) def eig_cpu_translation_rule(c, operand): shape = c.GetShape(operand) batch_dims = shape.dimensions()[:-2] - geev_out = lapack.jax_geev(c, operand) - w = c.GetTupleElement(geev_out, 0) - vl = c.GetTupleElement(geev_out, 1) - vr = c.GetTupleElement(geev_out, 2) - ok = c.Eq(c.GetTupleElement(geev_out, 3), c.ConstantS32Scalar(0)) + w, vl, vr, info = _cpu_geev(c, operand) + ok = c.Eq(info, c.ConstantS32Scalar(0)) w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w, _nan_like(c, w)) vl = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), vl, @@ -219,13 +234,11 @@ def eigh_abstract_eval(operand, lower): v, w = operand, operand return core.AbstractTuple((v, w)) -def eigh_cpu_translation_rule(c, operand, lower): +def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower): shape = c.GetShape(operand) batch_dims = shape.dimensions()[:-2] - syevd_out = lapack.jax_syevd(c, operand, lower=lower) - v = c.GetTupleElement(syevd_out, 0) - w = c.GetTupleElement(syevd_out, 1) - ok = c.Eq(c.GetTupleElement(syevd_out, 2), c.ConstantS32Scalar(0)) + v, w, info = syevd_impl(c, operand, lower=lower) + ok = c.Eq(info, c.ConstantS32Scalar(0)) v = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), v, _nan_like(c, v)) w = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1,)), w, @@ -267,7 +280,22 @@ eigh_p.def_impl(eigh_impl) eigh_p.def_abstract_eval(eigh_abstract_eval) xla.translations[eigh_p] = eigh_translation_rule ad.primitive_jvps[eigh_p] = eigh_jvp_rule -xla.backend_specific_translations['cpu'][eigh_p] = eigh_cpu_translation_rule + +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. +if hasattr(lapack, "syevd"): + _cpu_syevd = lapack.syevd +else: + _cpu_syevd = _unpack_tuple(lapack.jax_syevd, 3) + +xla.backend_specific_translations['cpu'][eigh_p] = partial( + _eigh_cpu_gpu_translation_rule, _cpu_syevd) + +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. +if cusolver: + xla.backend_specific_translations['gpu'][eigh_p] = partial( + _eigh_cpu_gpu_translation_rule, cusolver.syevd) batching.primitive_batchers[eigh_p] = eigh_batching_rule @@ -522,14 +550,13 @@ def _lu_batching_rule(batched_args, batch_dims): x = batching.bdim_at_front(x, bd) return lu_p.bind(x), 0 -def _lu_cpu_translation_rule(c, operand): +def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand): shape = c.GetShape(operand) batch_dims = shape.dimensions()[:-2] - getrf_out = lapack.jax_getrf(c, operand) - lu = c.GetTupleElement(getrf_out, 0) + lu, pivot, info = getrf_impl(c, operand) # Subtract 1 from the pivot to get 0-based indices. - pivot = c.Sub(c.GetTupleElement(getrf_out, 1), c.ConstantS32Scalar(1)) - ok = c.Eq(c.GetTupleElement(getrf_out, 2), c.ConstantS32Scalar(0)) + pivot = c.Sub(pivot, c.ConstantS32Scalar(1)) + ok = c.Eq(info, c.ConstantS32Scalar(0)) lu = _broadcasting_select(c, c.Reshape(ok, None, batch_dims + (1, 1)), lu, _nan_like(c, lu)) return c.Tuple(lu, pivot) @@ -541,7 +568,20 @@ lu_p.def_abstract_eval(_lu_abstract_eval) xla.translations[lu_p] = xla.lower_fun(_lu_python, instantiate=True) ad.primitive_jvps[lu_p] = _lu_jvp_rule batching.primitive_batchers[lu_p] = _lu_batching_rule -xla.backend_specific_translations['cpu'][lu_p] = _lu_cpu_translation_rule + +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. 
+if hasattr(lapack, "getrf"): + _cpu_getrf = lapack.getrf +else: + _cpu_getrf = _unpack_tuple(lapack.jax_getrf, 3) + +xla.backend_specific_translations['cpu'][lu_p] = partial( + _lu_cpu_gpu_translation_rule, _cpu_getrf) + +if cusolver: + xla.backend_specific_translations['gpu'][lu_p] = partial( + _lu_cpu_gpu_translation_rule, cusolver.getrf) def lu_pivots_to_permutation(swaps, m): @@ -681,16 +721,13 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv): dV = dV + np.dot(np.eye(n) - np.dot(V, Vt), np.dot(np.conj(dA).T, U)) / s_dim return core.pack((s, U, Vt)), core.pack((ds, dU, dV.T)) -def svd_cpu_translation_rule(c, operand, full_matrices, compute_uv): +def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute_uv): shape = c.GetShape(operand) dtype = shape.element_type().type if len(shape.dimensions()) == 2 and dtype in _cpu_lapack_types: - gesdd_out = lapack.jax_gesdd(c, operand, full_matrices=full_matrices, - compute_uv=compute_uv) - s = c.GetTupleElement(gesdd_out, 0) - u = c.GetTupleElement(gesdd_out, 1) - vt = c.GetTupleElement(gesdd_out, 2) - ok = c.Eq(c.GetTupleElement(gesdd_out, 3), c.ConstantS32Scalar(0)) + s, u, vt, info = gesvd_impl(c, operand, full_matrices=full_matrices, + compute_uv=compute_uv) + ok = c.Eq(info, c.ConstantS32Scalar(0)) s = _broadcasting_select(c, c.Reshape(ok, None, (1,)), s, _nan_like(c, s)) u = _broadcasting_select(c, c.Reshape(ok, None, (1, 1)), u, @@ -711,7 +748,22 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv): svd_p = Primitive('svd') svd_p.def_impl(svd_impl) svd_p.def_abstract_eval(svd_abstract_eval) -xla.translations[svd_p] = svd_translation_rule -xla.backend_specific_translations['cpu'][svd_p] = svd_cpu_translation_rule ad.primitive_jvps[svd_p] = svd_jvp_rule batching.primitive_batchers[svd_p] = svd_batching_rule +xla.translations[svd_p] = svd_translation_rule + +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. +if hasattr(lapack, "gesdd"): + _cpu_gesdd = lapack.gesdd +else: + _cpu_gesdd = _unpack_tuple(lapack.jax_gesdd, 4) + +xla.backend_specific_translations['cpu'][svd_p] = partial( + _svd_cpu_gpu_translation_rule, _cpu_gesdd) + +# TODO(phawkins): remove if-condition after increasing minimum Jaxlib version to +# 0.1.23. +if cusolver: + xla.backend_specific_translations['gpu'][svd_p] = partial( + _svd_cpu_gpu_translation_rule, cusolver.gesvd) diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 4447ae304..0d476e066 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -47,3 +47,10 @@ try: from jaxlib import pytree except ImportError: pytree = None + +# TODO(phawkins): make the import unconditional when the minimum Jaxlib version +# has been increased to 0.1.23. 
+try: + from jaxlib import cusolver +except ImportError: + cusolver = None diff --git a/jaxlib/BUILD b/jaxlib/BUILD index f50a2edee..8de6d92dc 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -29,7 +29,10 @@ pyx_library( py_library( name = "jaxlib", - srcs = ["version.py"], + srcs = [ + "cusolver.py", + "version.py" + ], ) tf_pybind_extension( @@ -52,3 +55,29 @@ tf_pybind_extension( "@pybind11", ], ) + +tf_pybind_extension( + name = "cusolver_kernels", + srcs = ["cusolver.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", + ], + features = ["-use_header_modules"], + module_name = "cusolver_kernels", + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cudart", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cusolver", + "@pybind11", + ], +) \ No newline at end of file diff --git a/jaxlib/cusolver.cc b/jaxlib/cusolver.cc new file mode 100644 index 000000000..0bd3a0e4a --- /dev/null +++ b/jaxlib/cusolver.cc @@ -0,0 +1,585 @@ +/* Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+#include "absl/base/casts.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
+#include "third_party/gpus/cuda/include/cusolverDn.h"
+#include "include/pybind11/numpy.h"
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/stl.h"
+
+namespace jax {
+namespace {
+
+namespace py = pybind11;
+
+void ThrowIfError(cudaError_t error) {
+  if (error != cudaSuccess) {
+    throw std::runtime_error("CUDA operation failed");
+  }
+}
+
+void ThrowIfErrorStatus(cusolverStatus_t status) {
+  switch (status) {
+    case CUSOLVER_STATUS_SUCCESS:
+      return;
+    case CUSOLVER_STATUS_NOT_INITIALIZED:
+      throw std::runtime_error("cuSolver has not been initialized");
+    case CUSOLVER_STATUS_ALLOC_FAILED:
+      throw std::runtime_error("cuSolver allocation failed");
+    case CUSOLVER_STATUS_INVALID_VALUE:
+      throw std::runtime_error("cuSolver invalid value error");
+    case CUSOLVER_STATUS_ARCH_MISMATCH:
+      throw std::runtime_error("cuSolver architecture mismatch error");
+    case CUSOLVER_STATUS_MAPPING_ERROR:
+      throw std::runtime_error("cuSolver mapping error");
+    case CUSOLVER_STATUS_EXECUTION_FAILED:
+      throw std::runtime_error("cuSolver execution failed");
+    case CUSOLVER_STATUS_INTERNAL_ERROR:
+      throw std::runtime_error("cuSolver internal error");
+    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
+      throw std::invalid_argument("cuSolver matrix type not supported error");
+    case CUSOLVER_STATUS_NOT_SUPPORTED:
+      throw std::runtime_error("cuSolver not supported error");
+    case CUSOLVER_STATUS_ZERO_PIVOT:
+      throw std::runtime_error("cuSolver zero pivot error");
+    case CUSOLVER_STATUS_INVALID_LICENSE:
+      throw std::runtime_error("cuSolver invalid license error");
+    default:
+      throw std::runtime_error("Unknown cuSolver error");
+  }
+}
+
+// To avoid creating cusolver contexts in the middle of execution, we maintain
+// a pool of them.
+class SolverHandlePool {
+ public:
+  SolverHandlePool() = default;
+
+  // RAII class representing a cusolver handle borrowed from the pool. Returns
+  // the handle to the pool on destruction.
+  class Handle {
+   public:
+    Handle() = default;
+    ~Handle() {
+      if (pool_) {
+        pool_->Return(handle_);
+      }
+    }
+
+    Handle(Handle const&) = delete;
+    Handle(Handle&& other) {
+      pool_ = other.pool_;
+      handle_ = other.handle_;
+      other.pool_ = nullptr;
+      other.handle_ = nullptr;
+    }
+    Handle& operator=(Handle const&) = delete;
+    Handle& operator=(Handle&& other) {
+      pool_ = other.pool_;
+      handle_ = other.handle_;
+      other.pool_ = nullptr;
+      other.handle_ = nullptr;
+      return *this;
+    }
+
+    cusolverDnHandle_t get() { return handle_; }
+
+   private:
+    friend class SolverHandlePool;
+    Handle(SolverHandlePool* pool, cusolverDnHandle_t handle)
+        : pool_(pool), handle_(handle) {}
+    SolverHandlePool* pool_ = nullptr;
+    cusolverDnHandle_t handle_ = nullptr;
+  };
+
+  // Borrows a handle from the pool. If 'stream' is non-null, sets the stream
+  // associated with the handle.
+  static Handle Borrow(cudaStream_t stream = nullptr);
+
+ private:
+  static SolverHandlePool* Instance();
+
+  void Return(cusolverDnHandle_t handle);
+
+  absl::Mutex mu_;
+  std::vector<cusolverDnHandle_t> handles_ GUARDED_BY(mu_);
+};
+
+/*static*/ SolverHandlePool* SolverHandlePool::Instance() {
+  static auto* pool = new SolverHandlePool;
+  return pool;
+}
+
+/*static*/ SolverHandlePool::Handle SolverHandlePool::Borrow(
+    cudaStream_t stream) {
+  SolverHandlePool* pool = Instance();
+  absl::MutexLock lock(&pool->mu_);
+  cusolverDnHandle_t handle;
+  if (pool->handles_.empty()) {
+    ThrowIfErrorStatus(cusolverDnCreate(&handle));
+  } else {
+    handle = pool->handles_.back();
+    pool->handles_.pop_back();
+  }
+  if (stream) {
+    ThrowIfErrorStatus(cusolverDnSetStream(handle, stream));
+  }
+  return Handle(pool, handle);
+}
+
+void SolverHandlePool::Return(cusolverDnHandle_t handle) {
+  absl::MutexLock lock(&mu_);
+  handles_.push_back(handle);
+}
+
+// Set of types known to Cusolver.
+enum class Type {
+  F32,
+  F64,
+  C64,
+  C128,
+};
+
+// Converts a NumPy dtype to a Type.
+Type DtypeToType(const py::dtype& np_type) {
+  static auto* types = new absl::flat_hash_map<std::pair<char, int>, Type>({
+      {{'f', 4}, Type::F32},
+      {{'f', 8}, Type::F64},
+      {{'c', 8}, Type::C64},
+      {{'c', 16}, Type::C128},
+  });
+  auto it = types->find({np_type.kind(), np_type.itemsize()});
+  if (it == types->end()) {
+    throw std::invalid_argument(
+        absl::StrFormat("Unsupported dtype %s", py::repr(np_type)));
+  }
+  return it->second;
+}
+
+int SizeOfType(Type type) {
+  switch (type) {
+    case Type::F32:
+      return sizeof(float);
+    case Type::F64:
+      return sizeof(double);
+    case Type::C64:
+      return sizeof(cuComplex);
+    case Type::C128:
+      return sizeof(cuDoubleComplex);
+  }
+}
+
+// Descriptor objects are opaque host-side objects used to pass data from JAX
+// to the custom kernel launched by XLA. Currently simply treat host-side
+// structures as byte-strings; this is not portable across architectures. If
+// portability is needed, we could switch to using a representation such as
+// protocol buffers or flatbuffers.
+
+// Packs a descriptor object into a py::bytes structure.
+template <typename T>
+py::bytes PackDescriptor(const T& descriptor) {
+  return py::bytes(absl::bit_cast<const char*>(&descriptor), sizeof(T));
+}
+
+// Unpacks a descriptor object from a byte string.
+template <typename T>
+const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
+  if (opaque_len != sizeof(T)) {
+    throw std::runtime_error("Invalid size for linalg operation descriptor.");
+  }
+  return absl::bit_cast<const T*>(opaque);
+}
+
+// getrf: LU decomposition
+
+struct GetrfDescriptor {
+  Type type;
+  int batch, m, n;
+};
+
+// Returns the workspace size and a descriptor for a getrf operation.
+std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
+                                               int m, int n) {
+  Type type = DtypeToType(dtype);
+  auto handle = SolverHandlePool::Borrow();
+  int lwork;
+  switch (type) {
+    case Type::F32:
+      ThrowIfErrorStatus(cusolverDnSgetrf_bufferSize(handle.get(), m, n,
+                                                     /*A=*/nullptr,
+                                                     /*lda=*/m, &lwork));
+      break;
+    case Type::F64:
+      ThrowIfErrorStatus(cusolverDnDgetrf_bufferSize(handle.get(), m, n,
+                                                     /*A=*/nullptr,
+                                                     /*lda=*/m, &lwork));
+      break;
+    case Type::C64:
+      ThrowIfErrorStatus(cusolverDnCgetrf_bufferSize(handle.get(), m, n,
+                                                     /*A=*/nullptr,
+                                                     /*lda=*/m, &lwork));
+      break;
+    case Type::C128:
+      ThrowIfErrorStatus(cusolverDnZgetrf_bufferSize(handle.get(), m, n,
+                                                     /*A=*/nullptr,
+                                                     /*lda=*/m, &lwork));
+      break;
+  }
+  return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n})};
+}
+
+void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
+           size_t opaque_len) {
+  const GetrfDescriptor& d =
+      *UnpackDescriptor<GetrfDescriptor>(opaque, opaque_len);
+  auto handle = SolverHandlePool::Borrow(stream);
+  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
+                               SizeOfType(d.type) * d.batch * d.m * d.n,
+                               cudaMemcpyDeviceToDevice, stream));
+
+  void* workspace = buffers[2];
+  int* ipiv = static_cast<int*>(buffers[3]);
+  int* info = static_cast<int*>(buffers[4]);
+  switch (d.type) {
+    case Type::F32: {
+      float* a = static_cast<float*>(buffers[1]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnSgetrf(handle.get(), d.m, d.n, a, d.m,
+                                            static_cast<float*>(workspace),
+                                            ipiv, info));
+        a += d.m * d.n;
+        ipiv += std::min(d.m, d.n);
+        ++info;
+      }
+      break;
+    }
+    case Type::F64: {
+      double* a = static_cast<double*>(buffers[1]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnDgetrf(handle.get(), d.m, d.n, a, d.m,
+                                            static_cast<double*>(workspace),
+                                            ipiv, info));
+        a += d.m * d.n;
+        ipiv += std::min(d.m, d.n);
+        ++info;
+      }
+      break;
+    }
+    case Type::C64: {
+      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnCgetrf(handle.get(), d.m, d.n, a, d.m,
+                                            static_cast<cuComplex*>(workspace),
+                                            ipiv, info));
+        a += d.m * d.n;
+        ipiv += std::min(d.m, d.n);
+        ++info;
+      }
+      break;
+    }
+    case Type::C128: {
+      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnZgetrf(
+            handle.get(), d.m, d.n, a, d.m,
+            static_cast<cuDoubleComplex*>(workspace), ipiv, info));
+        a += d.m * d.n;
+        ipiv += std::min(d.m, d.n);
+        ++info;
+      }
+      break;
+    }
+  }
+}
+
+// Symmetric (Hermitian) eigendecomposition: syevd/heevd
+
+struct SyevdDescriptor {
+  Type type;
+  cublasFillMode_t uplo;
+  int batch, n;
+  int lwork;
+};
+
+// Returns the workspace size and a descriptor for a syevd operation.
+std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
+                                               bool lower, int b, int n) {
+  Type type = DtypeToType(dtype);
+  auto handle = SolverHandlePool::Borrow();
+  int lwork;
+  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
+  cublasFillMode_t uplo =
+      lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
+  switch (type) {
+    case Type::F32:
+      ThrowIfErrorStatus(cusolverDnSsyevd_bufferSize(
+          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
+          &lwork));
+      break;
+    case Type::F64:
+      ThrowIfErrorStatus(cusolverDnDsyevd_bufferSize(
+          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
+          &lwork));
+      break;
+    case Type::C64:
+      ThrowIfErrorStatus(cusolverDnCheevd_bufferSize(
+          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
+          &lwork));
+      break;
+    case Type::C128:
+      ThrowIfErrorStatus(cusolverDnZheevd_bufferSize(
+          handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr,
+          &lwork));
+      break;
+  }
+  return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
+}
+
+void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
+           size_t opaque_len) {
+  const SyevdDescriptor& d =
+      *UnpackDescriptor<SyevdDescriptor>(opaque, opaque_len);
+  auto handle = SolverHandlePool::Borrow(stream);
+  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
+                               SizeOfType(d.type) * d.batch * d.n * d.n,
+                               cudaMemcpyDeviceToDevice, stream));
+  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
+  int* info = static_cast<int*>(buffers[3]);
+  void* work = buffers[4];
+  switch (d.type) {
+    case Type::F32: {
+      float* a = static_cast<float*>(buffers[1]);
+      float* w = static_cast<float*>(buffers[2]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a,
+                                            d.n, w, static_cast<float*>(work),
+                                            d.lwork, info));
+        a += d.n * d.n;
+        w += d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::F64: {
+      double* a = static_cast<double*>(buffers[1]);
+      double* w = static_cast<double*>(buffers[2]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a,
+                                            d.n, w, static_cast<double*>(work),
+                                            d.lwork, info));
+        a += d.n * d.n;
+        w += d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::C64: {
+      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
+      float* w = static_cast<float*>(buffers[2]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(
+            cusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
+                             static_cast<cuComplex*>(work), d.lwork, info));
+        a += d.n * d.n;
+        w += d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::C128: {
+      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
+      double* w = static_cast<double*>(buffers[2]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnZheevd(
+            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
+            static_cast<cuDoubleComplex*>(work), d.lwork, info));
+        a += d.n * d.n;
+        w += d.n;
+        ++info;
+      }
+      break;
+    }
+  }
+}
+
+// Singular value decomposition: gesvd
+
+struct GesvdDescriptor {
+  Type type;
+  int batch, m, n;
+  int lwork;
+  signed char jobu, jobvt;
+};
+
+// Returns the workspace size and a descriptor for a gesvd operation.
+std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
+                                               int m, int n, bool compute_uv,
+                                               bool full_matrices) {
+  Type type = DtypeToType(dtype);
+  auto handle = SolverHandlePool::Borrow();
+  int lwork;
+  switch (type) {
+    case Type::F32:
+      ThrowIfErrorStatus(
+          cusolverDnSgesvd_bufferSize(handle.get(), m, n, &lwork));
+      break;
+    case Type::F64:
+      ThrowIfErrorStatus(
+          cusolverDnDgesvd_bufferSize(handle.get(), m, n, &lwork));
+      break;
+    case Type::C64:
+      ThrowIfErrorStatus(
+          cusolverDnCgesvd_bufferSize(handle.get(), m, n, &lwork));
+      break;
+    case Type::C128:
+      ThrowIfErrorStatus(
+          cusolverDnZgesvd_bufferSize(handle.get(), m, n, &lwork));
+      break;
+  }
+  signed char jobu, jobvt;
+  if (compute_uv) {
+    if (full_matrices) {
+      jobu = jobvt = 'A';
+    } else {
+      jobu = jobvt = 'S';
+    }
+  } else {
+    jobu = jobvt = 'N';
+  }
+  return {lwork,
+          PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})};
+}
+
+// TODO(phawkins): in the batched case, we should consider using the batched
+// Jacobi implementation instead.
+void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
+           size_t opaque_len) {
+  const GesvdDescriptor& d =
+      *UnpackDescriptor<GesvdDescriptor>(opaque, opaque_len);
+  auto handle = SolverHandlePool::Borrow(stream);
+  ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
+                               SizeOfType(d.type) * d.batch * d.m * d.n,
+                               cudaMemcpyDeviceToDevice, stream));
+  int* info = static_cast<int*>(buffers[5]);
+  void* work = buffers[6];
+  switch (d.type) {
+    case Type::F32: {
+      float* a = static_cast<float*>(buffers[1]);
+      float* s = static_cast<float*>(buffers[2]);
+      float* u = static_cast<float*>(buffers[3]);
+      float* vt = static_cast<float*>(buffers[4]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnSgesvd(handle.get(), d.jobu, d.jobvt, d.m,
+                                            d.n, a, d.m, s, u, d.m, vt, d.n,
+                                            static_cast<float*>(work), d.lwork,
+                                            /*rwork=*/nullptr, info));
+        a += d.m * d.n;
+        s += std::min(d.m, d.n);
+        u += d.m * d.m;
+        vt += d.n * d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::F64: {
+      double* a = static_cast<double*>(buffers[1]);
+      double* s = static_cast<double*>(buffers[2]);
+      double* u = static_cast<double*>(buffers[3]);
+      double* vt = static_cast<double*>(buffers[4]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnDgesvd(handle.get(), d.jobu, d.jobvt, d.m,
+                                            d.n, a, d.m, s, u, d.m, vt, d.n,
+                                            static_cast<double*>(work),
+                                            d.lwork,
+                                            /*rwork=*/nullptr, info));
+        a += d.m * d.n;
+        s += std::min(d.m, d.n);
+        u += d.m * d.m;
+        vt += d.n * d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::C64: {
+      cuComplex* a = static_cast<cuComplex*>(buffers[1]);
+      float* s = static_cast<float*>(buffers[2]);
+      cuComplex* u = static_cast<cuComplex*>(buffers[3]);
+      cuComplex* vt = static_cast<cuComplex*>(buffers[4]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(
+            cusolverDnCgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
+                             u, d.m, vt, d.n, static_cast<cuComplex*>(work),
+                             d.lwork, /*rwork=*/nullptr, info));
+        a += d.m * d.n;
+        s += std::min(d.m, d.n);
+        u += d.m * d.m;
+        vt += d.n * d.n;
+        ++info;
+      }
+      break;
+    }
+    case Type::C128: {
+      cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
+      double* s = static_cast<double*>(buffers[2]);
+      cuDoubleComplex* u = static_cast<cuDoubleComplex*>(buffers[3]);
+      cuDoubleComplex* vt = static_cast<cuDoubleComplex*>(buffers[4]);
+      for (int i = 0; i < d.batch; ++i) {
+        ThrowIfErrorStatus(cusolverDnZgesvd(
+            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
+            static_cast<cuDoubleComplex*>(work), d.lwork,
+            /*rwork=*/nullptr, info));
+        a += d.m * d.n;
+        s += std::min(d.m, d.n);
+        u += d.m * d.m;
+        vt += d.n * d.n;
+        ++info;
+      }
+      break;
+    }
+  }
+}
+
+template <typename T>
+py::capsule EncapsulateFunction(T* fn) {
+  return
py::capsule(absl::bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +py::dict Registrations() { + py::dict dict; + dict["cusolver_getrf"] = EncapsulateFunction(Getrf); + dict["cusolver_syevd"] = EncapsulateFunction(Syevd); + dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd); + return dict; +} + +PYBIND11_MODULE(cusolver_kernels, m) { + m.def("registrations", &Registrations); + m.def("build_getrf_descriptor", &BuildGetrfDescriptor); + m.def("build_syevd_descriptor", &BuildSyevdDescriptor); + m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); +} + +} // namespace +} // namespace jax diff --git a/jaxlib/cusolver.py b/jaxlib/cusolver.py new file mode 100644 index 000000000..1b79f4be8 --- /dev/null +++ b/jaxlib/cusolver.py @@ -0,0 +1,182 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from jaxlib import xla_client + +try: + from jaxlib import cusolver_kernels + for _name, _value in cusolver_kernels.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="gpu") +except ImportError: + pass + +_Shape = xla_client.Shape + + +def _real_type(dtype): + """Returns the real equivalent of 'dtype'.""" + if dtype == np.float32: + return np.float32 + elif dtype == np.float64: + return np.float64 + elif dtype == np.complex64: + return np.float32 + elif dtype == np.complex128: + return np.float64 + else: + raise NotImplementedError("Unsupported dtype {}".format(dtype)) + + +def getrf(c, a): + """LU decomposition.""" + a_shape = c.GetShape(a) + dtype = a_shape.element_type() + dims = a_shape.dimensions() + assert len(dims) >= 2 + m, n = dims[-2:] + batch_dims = tuple(dims[:-2]) + num_bd = len(batch_dims) + b = 1 + for d in batch_dims: + b *= d + + lwork, opaque = cusolver_kernels.build_getrf_descriptor( + np.dtype(dtype), b, m, n) + out = c.CustomCall( + b"cusolver_getrf", + operands=(a,), + shape_with_layout=_Shape.tuple_shape(( + _Shape.array_shape( + dtype, batch_dims + (m, n), + (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), + _Shape.array_shape(dtype, (lwork,), (0,)), + _Shape.array_shape( + np.dtype(np.int32), batch_dims + (min(m, n),), + tuple(range(num_bd, -1, -1))), + _Shape.array_shape( + np.dtype(np.int32), batch_dims, tuple(range(num_bd - 1, -1, -1))), + )), + operand_shapes_with_layout=(_Shape.array_shape( + dtype, batch_dims + (m, n), + (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),), + opaque=opaque) + return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 2), + c.GetTupleElement(out, 3)) + + +def syevd(c, a, lower=False): + """Symmetric (Hermitian) eigendecomposition.""" + + a_shape = c.GetShape(a) + dtype = a_shape.element_type() + dims = a_shape.dimensions() + assert len(dims) >= 2 + m, n = dims[-2:] + assert m == n + batch_dims = tuple(dims[:-2]) + num_bd = len(batch_dims) + b = 1 + for d in batch_dims: + b *= d + layout = (num_bd, num_bd + 1) + tuple(range(num_bd 
- 1, -1, -1)) + + lwork, opaque = cusolver_kernels.build_syevd_descriptor( + np.dtype(dtype), lower, b, n) + eigvals_type = _real_type(dtype) + + out = c.CustomCall( + b"cusolver_syevd", + operands=(a,), + shape_with_layout=_Shape.tuple_shape(( + _Shape.array_shape(dtype, dims, layout), + _Shape.array_shape( + np.dtype(eigvals_type), batch_dims + (n,), + tuple(range(num_bd, -1, -1))), + _Shape.array_shape( + np.dtype(np.int32), batch_dims, + tuple(range(num_bd - 1, -1, -1))), + _Shape.array_shape(dtype, (lwork,), (0,)) + )), + operand_shapes_with_layout=( + _Shape.array_shape(dtype, dims, layout), + ), + opaque=opaque) + return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1), + c.GetTupleElement(out, 2)) + + +def gesvd(c, a, full_matrices=True, compute_uv=True): + """Singular value decomposition.""" + + a_shape = c.GetShape(a) + dtype = a_shape.element_type() + b = 1 + m, n = a_shape.dimensions() + singular_vals_dtype = _real_type(dtype) + + if m < n: + lwork, opaque = cusolver_kernels.build_gesvd_descriptor( + np.dtype(dtype), b, n, m, compute_uv, full_matrices) + out = c.CustomCall( + b"cusolver_gesvd", + operands=(a,), + shape_with_layout=_Shape.tuple_shape(( + _Shape.array_shape(dtype, (m, n), (1, 0)), + _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)), + _Shape.array_shape(dtype, (n, n), (1, 0)), + _Shape.array_shape(dtype, (m, m), (1, 0)), + _Shape.array_shape(np.dtype(np.int32), (), ()), + _Shape.array_shape(dtype, (lwork,), (0,)), + )), + operand_shapes_with_layout=( + _Shape.array_shape(dtype, (m, n), (1, 0)), + ), + opaque=opaque) + s = c.GetTupleElement(out, 1) + vt = c.GetTupleElement(out, 2) + u = c.GetTupleElement(out, 3) + info = c.GetTupleElement(out, 4) + else: + lwork, opaque = cusolver_kernels.build_gesvd_descriptor( + np.dtype(dtype), b, m, n, compute_uv, full_matrices) + + out = c.CustomCall( + b"cusolver_gesvd", + operands=(a,), + shape_with_layout=_Shape.tuple_shape(( + _Shape.array_shape(dtype, (m, n), (0, 1)), + _Shape.array_shape(np.dtype(singular_vals_dtype), (min(m, n),), (0,)), + _Shape.array_shape(dtype, (m, m), (0, 1)), + _Shape.array_shape(dtype, (n, n), (0, 1)), + _Shape.array_shape(np.dtype(np.int32), (), ()), + _Shape.array_shape(dtype, (lwork,), (0,)), + )), + operand_shapes_with_layout=( + _Shape.array_shape(dtype, (m, n), (0, 1)), + ), + opaque=opaque) + s = c.GetTupleElement(out, 1) + u = c.GetTupleElement(out, 2) + vt = c.GetTupleElement(out, 3) + info = c.GetTupleElement(out, 4) + if not full_matrices: + u = c.Slice(u, (0, 0), (m, min(m, n))) + vt = c.Slice(vt, (0, 0), (min(m, n), n)) + return s, u, vt, info diff --git a/jaxlib/lapack.pyx b/jaxlib/lapack.pyx index 22743fae6..98763e580 100644 --- a/jaxlib/lapack.pyx +++ b/jaxlib/lapack.pyx @@ -166,7 +166,7 @@ cdef void blas_ztrsm(void* out, void** data) nogil: register_cpu_custom_call_target(b"blas_ztrsm", (blas_ztrsm)) -def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False, +def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False, conj_a=False, diag=False): b_shape = c.GetShape(b) dtype = b_shape.element_type() @@ -214,7 +214,7 @@ def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False, Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)), Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)), )) - +jax_trsm = trsm # ?getrf: LU decomposition @@ -305,7 +305,7 @@ cdef void lapack_zgetrf(void* out_tuple, void** data) nogil: register_cpu_custom_call_target(b"lapack_zgetrf", (lapack_zgetrf)) -def jax_getrf(c, a): +def getrf(c, 
a): assert sizeof(int32_t) == sizeof(int) a_shape = c.GetShape(a) @@ -330,7 +330,7 @@ def jax_getrf(c, a): else: raise NotImplementedError("Unsupported dtype {}".format(dtype)) - return c.CustomCall( + out = c.CustomCall( fn, operands=( c.ConstantS32Scalar(b), @@ -358,8 +358,10 @@ def jax_getrf(c, a): batch_dims + (m, n), (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))), )) + return tuple(c.GetTupleElement(out, i) for i in range(3)) - +def jax_getrf(c, a): + return c.Tuple(*getrf(c, a)) # ?potrf: Cholesky decomposition @@ -429,7 +431,7 @@ cdef void lapack_zpotrf(void* out_tuple, void** data) nogil: register_cpu_custom_call_target(b"lapack_zpotrf", (lapack_zpotrf)) -def jax_potrf(c, a, lower=False): +def potrf(c, a, lower=False): assert sizeof(int32_t) == sizeof(int) a_shape = c.GetShape(a) @@ -448,7 +450,7 @@ def jax_potrf(c, a, lower=False): else: raise NotImplementedError("Unsupported dtype {}".format(dtype)) - return c.CustomCall( + out = c.CustomCall( fn, operands=(c.ConstantS32Scalar(int(lower)), c.ConstantS32Scalar(n), a), shape_with_layout=Shape.tuple_shape(( @@ -460,6 +462,10 @@ def jax_potrf(c, a, lower=False): Shape.array_shape(np.dtype(np.int32), (), ()), Shape.array_shape(dtype, (n, n), (0, 1)), )) + return tuple(c.GetTupleElement(out, i) for i in range(2)) + +def jax_potrf(c, a, lower=False): + return c.Tuple(*potrf(c, a, lower)) # ?gesdd: Singular value decomposition @@ -555,7 +561,7 @@ cdef void lapack_dgesdd(void* out_tuple, void** data) nogil: ldvt = min(m, n) # First perform a workspace query to get the optimal lwork - # NB: We perform a workspace query with malloc and free for the work array, + # NB: We perform a workspace query with malloc and free for the work array, # because it is officially recommended in the LAPACK documentation cdef double wkopt = 0 cdef int lwork = -1 @@ -667,7 +673,7 @@ cdef void lapack_zgesdd(void* out_tuple, void** data) nogil: register_cpu_custom_call_target(b"lapack_zgesdd", (lapack_zgesdd)) -def jax_gesdd(c, a, full_matrices=True, compute_uv=True): +def gesdd(c, a, full_matrices=True, compute_uv=True): assert sizeof(int32_t) == sizeof(int) a_shape = c.GetShape(a) @@ -720,8 +726,11 @@ def jax_gesdd(c, a, full_matrices=True, compute_uv=True): Shape.array_shape(np.dtype(np.int32), (), ()), Shape.array_shape(dtype, (m, n), (0, 1)), )) - return c.Tuple(c.GetTupleElement(out, 1), c.GetTupleElement(out, 2), - c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)) + return (c.GetTupleElement(out, 1), c.GetTupleElement(out, 2), + c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)) + +def jax_gesdd(c, a, full_matrices=True, compute_uv=True): + return c.Tuple(*gesdd(c, a, full_matrices, compute_uv)) # syevd: Symmetric eigendecomposition @@ -861,7 +870,7 @@ cdef void lapack_zheevd(void* out_tuple, void** data) nogil: register_cpu_custom_call_target(b"lapack_zheevd", (lapack_zheevd)) -def jax_syevd(c, a, lower=False): +def syevd(c, a, lower=False): assert sizeof(int32_t) == sizeof(int) a_shape = c.GetShape(a) @@ -928,8 +937,11 @@ def jax_syevd(c, a, lower=False): Shape.array_shape(np.dtype(np.int32), (), ()), Shape.array_shape(dtype, dims, layout), )) - return c.Tuple(c.GetTupleElement(out, 0), c.GetTupleElement(out, 1), - c.GetTupleElement(out, 2)) + return (c.GetTupleElement(out, 0), c.GetTupleElement(out, 1), + c.GetTupleElement(out, 2)) + +def jax_syevd(c, a, lower=False): + return c.Tuple(*syevd(c, a, lower)) # geev: Nonsymmetric eigendecomposition @@ -1140,7 +1152,7 @@ register_cpu_custom_call_target(b"lapack_zgeev", (lapack_zgeev)) -def 
jax_geev(c, a): +def geev(c, a): assert sizeof(int32_t) == sizeof(int) a_shape = c.GetShape(a) @@ -1212,11 +1224,12 @@ def jax_geev(c, a): Shape.array_shape(dtype, dims, layout), )) if real: - return c.Tuple( - c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)), - c.GetTupleElement(out, 5), c.GetTupleElement(out, 6), - c.GetTupleElement(out, 7)) + return (c.Complex(c.GetTupleElement(out, 3), c.GetTupleElement(out, 4)), + c.GetTupleElement(out, 5), c.GetTupleElement(out, 6), + c.GetTupleElement(out, 7)) else: - return c.Tuple( - c.GetTupleElement(out, 2), c.GetTupleElement(out, 3), - c.GetTupleElement(out, 4), c.GetTupleElement(out, 5)) + return (c.GetTupleElement(out, 2), c.GetTupleElement(out, 3), + c.GetTupleElement(out, 4), c.GetTupleElement(out, 5)) + +def jax_geev(c, a): + return c.Tuple(*geev(c, a)) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d20cdaa5d..109c16cb3 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -163,8 +163,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): for lower in [False, True] for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + # for TPU. + @jtu.skip_on_devices("tpu") def testEigh(self, n, dtype, lower, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((n, n), dtype)] @@ -196,8 +196,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): for rng in [jtu.rand_default()] for lower in [True, False])) # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + # for TPU. + @jtu.skip_on_devices("tpu") def testEighGrad(self, shape, dtype, rng, lower): self.skipTest("Test fails with numeric errors.") uplo = "L" if lower else "U" @@ -224,8 +224,8 @@ class NumpyLinalgTest(jtu.JaxTestCase): for lower in [True, False] for eps in [1e-4])) # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.skip_on_devices("gpu", "tpu") + # for TPU. + @jtu.skip_on_devices("tpu") def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps): _skip_if_unsupported_type(dtype) # Special case to test for complex eigenvector grad correctness. @@ -263,7 +263,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): for shape in [(1, 1), (4, 4), (5, 5)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.skip_on_devices("tpu") def testEighBatching(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) shape = (10,) + shape @@ -318,7 +318,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): for full_matrices in [False, True] for compute_uv in [False, True] for rng in [jtu.rand_default()])) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.skip_on_devices("tpu") def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((m, n), dtype)] @@ -414,7 +414,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): if not full_matrices and m >= n: jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,)) - @jtu.skip_on_devices("gpu", "tpu") + @jtu.skip_on_devices("tpu") def testQrBatching(self): shape = (10, 4, 5) dtype = np.float32 @@ -476,7 +476,7 @@ class NumpyLinalgTest(jtu.JaxTestCase): self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True) # Regression test for incorrect type for eigenvalues of a complex matrix. 
- @jtu.skip_on_devices("gpu", "tpu") + @jtu.skip_on_devices("tpu") def testIssue669(self): def test(x): val, vec = np.linalg.eigh(x) @@ -499,9 +499,9 @@ class ScipyLinalgTest(jtu.JaxTestCase): def testLu(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker, - check_dtypes=True, tol=1e-3) + x, = args_maker() + p, l, u = jsp.linalg.lu(x) + self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True) self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True) # TODO(phawkins): figure out why this test fails on Travis and reenable. @@ -555,8 +555,15 @@ class ScipyLinalgTest(jtu.JaxTestCase): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((n, n), dtype)] - self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor, - args_maker, check_dtypes=True, tol=1e-3) + x, = args_maker() + lu, piv = jsp.linalg.lu_factor(x) + l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype) + u = onp.triu(lu) + for i in range(n): + x[[i, piv[i]],] = x[[piv[i], i],] + self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3) + # self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor, + # args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list(