rocm_jax/jaxlib/lapack.pyx
Peter Hawkins c3add0c750 Enable some tests that now pass.
Fix some dtypes in the lapack eig implementation to fix test failures.
2019-05-29 10:23:27 -04:00

1223 lines
41 KiB
Cython

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# distutils: language = c++
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
# via CustomCall.
from __future__ import print_function
from libc.stdlib cimport malloc, free
from libc.stdint cimport int32_t
from libc.string cimport memcpy
from libcpp.string cimport string
from cpython.pycapsule cimport PyCapsule_New
from scipy.linalg.cython_blas cimport strsm, dtrsm, ctrsm, ztrsm
from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, zgetrf
from scipy.linalg.cython_lapack cimport spotrf, dpotrf, cpotrf, zpotrf
from scipy.linalg.cython_lapack cimport sgesdd, dgesdd, cgesdd, zgesdd
from scipy.linalg.cython_lapack cimport ssyevd, dsyevd, cheevd, zheevd
from scipy.linalg.cython_lapack cimport sgeev, dgeev, cgeev, zgeev
import numpy as np
from jaxlib import xla_client
Shape = xla_client.Shape
cdef register_cpu_custom_call_target(fn_name, void* fn):
  # Wraps the raw C function pointer `fn` in a PyCapsule tagged
  # "xla._CPU_CUSTOM_CALL_TARGET" and registers that capsule with xla_client
  # under the name `fn_name`.
  cdef const char* capsule_tag = "xla._CPU_CUSTOM_CALL_TARGET"
  capsule = PyCapsule_New(fn, capsule_tag, NULL)
  xla_client.register_cpu_custom_call_target(fn_name, capsule)
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
cdef void blas_strsm(void* out, void** data) nogil:
  # CustomCall kernel wrapping BLAS strsm (float32 triangular solve).
  #
  # data[0..5] point at int32 scalars: left_side, lower, trans_a, diag, m, n.
  # data[6] points at alpha, data[7] at the matrix a and data[8] at the
  # right-hand side b. b is copied into `out` and strsm is run on that copy.
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef float* alpha = <float*>(data[6])
  cdef float* a = <float*>(data[7])
  cdef float* b = <float*>(data[8])

  cdef float* x = <float*>(out)
  if x != b:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(x, b, <size_t>m * <size_t>n * sizeof(float))

  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  # trans_a encoding: 0 -> 'N', 1 -> 'T', 2 -> 'C'.
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  # The leading dimension of a follows the side flag: m for 'L', n for 'R'.
  cdef int lda = m if left_side else n
  cdef int ldb = m
  strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
register_cpu_custom_call_target(b"blas_strsm", <void*>(blas_strsm))
cdef void blas_dtrsm(void* out, void** data) nogil:
  # CustomCall kernel wrapping BLAS dtrsm (float64 triangular solve).
  #
  # data[0..5] point at int32 scalars: left_side, lower, trans_a, diag, m, n.
  # data[6] points at alpha, data[7] at the matrix a and data[8] at the
  # right-hand side b. b is copied into `out` and dtrsm is run on that copy.
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef double* alpha = <double*>(data[6])
  cdef double* a = <double*>(data[7])
  cdef double* b = <double*>(data[8])

  cdef double* x = <double*>(out)
  if x != b:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(x, b, <size_t>m * <size_t>n * sizeof(double))

  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  # trans_a encoding: 0 -> 'N', 1 -> 'T', 2 -> 'C'.
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  # The leading dimension of a follows the side flag: m for 'L', n for 'R'.
  cdef int lda = m if left_side else n
  cdef int ldb = m
  dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
register_cpu_custom_call_target(b"blas_dtrsm", <void*>(blas_dtrsm))
cdef void blas_ctrsm(void* out, void** data) nogil:
  # CustomCall kernel wrapping BLAS ctrsm (complex64 triangular solve).
  #
  # data[0..5] point at int32 scalars: left_side, lower, trans_a, diag, m, n.
  # data[6] points at alpha, data[7] at the matrix a and data[8] at the
  # right-hand side b. b is copied into `out` and ctrsm is run on that copy.
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef float complex* alpha = <float complex*>(data[6])
  cdef float complex* a = <float complex*>(data[7])
  cdef float complex* b = <float complex*>(data[8])

  cdef float complex* x = <float complex*>(out)
  if x != b:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(x, b, <size_t>m * <size_t>n * sizeof(float complex))

  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  # trans_a encoding: 0 -> 'N', 1 -> 'T', 2 -> 'C'.
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  # The leading dimension of a follows the side flag: m for 'L', n for 'R'.
  cdef int lda = m if left_side else n
  cdef int ldb = m
  ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
register_cpu_custom_call_target(b"blas_ctrsm", <void*>(blas_ctrsm))
cdef void blas_ztrsm(void* out, void** data) nogil:
  # CustomCall kernel wrapping BLAS ztrsm (complex128 triangular solve).
  #
  # data[0..5] point at int32 scalars: left_side, lower, trans_a, diag, m, n.
  # data[6] points at alpha, data[7] at the matrix a and data[8] at the
  # right-hand side b. b is copied into `out` and ztrsm is run on that copy.
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef double complex* alpha = <double complex*>(data[6])
  cdef double complex* a = <double complex*>(data[7])
  cdef double complex* b = <double complex*>(data[8])

  cdef double complex* x = <double complex*>(out)
  if x != b:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(x, b, <size_t>m * <size_t>n * sizeof(double complex))

  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  # trans_a encoding: 0 -> 'N', 1 -> 'T', 2 -> 'C'.
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  # The leading dimension of a follows the side flag: m for 'L', n for 'R'.
  cdef int lda = m if left_side else n
  cdef int ldb = m
  ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
def jax_trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
             conj_a=False, diag=False):
  """Emits a CustomCall to the BLAS ?trsm kernel matching the dtype of `b`.

  `b` must be two-dimensional with shape (m, n); `a` must be square with
  side m when `left_side` is set and n otherwise, and share b's dtype.

  Raises:
    ValueError: if the shape or dtype of `a` does not match `b`.
    NotImplementedError: for dtypes other than float32, float64, complex64
      and complex128, or when `conj_a` is requested without `trans_a`.
  """
  b_shape = c.GetShape(b)
  dtype = b_shape.element_type()
  m, n = b_shape.dimensions()
  side_len = m if left_side else n

  a_shape = c.GetShape(a)
  if not ((side_len, side_len) == a_shape.dimensions()
          and a_shape.element_type() == dtype):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_shape, b_shape))

  kernels = (
      (np.float32, b"blas_strsm"),
      (np.float64, b"blas_dtrsm"),
      (np.complex64, b"blas_ctrsm"),
      (np.complex128, b"blas_ztrsm"),
  )
  for np_type, target in kernels:
    if dtype == np_type:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

  # Kernel encoding of the transpose mode: 0 = none, 1 = transpose,
  # 2 = conjugate transpose.
  if not trans_a:
    trans_code = 0
  elif conj_a:
    trans_code = 2
  else:
    trans_code = 1

  operands = (
      c.ConstantS32Scalar(int(left_side)),
      c.ConstantS32Scalar(int(lower)),
      c.ConstantS32Scalar(trans_code),
      c.ConstantS32Scalar(int(diag)),
      c.ConstantS32Scalar(m),
      c.ConstantS32Scalar(n),
      alpha, a, b)

  # Six int32 scalar flags/sizes, then alpha, a and b.
  scalar_shapes = tuple(
      [Shape.array_shape(np.dtype(np.int32), (), ()) for _ in range(6)])
  operand_shapes = scalar_shapes + (
      Shape.array_shape(dtype, (), ()),
      Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)),
      Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
  )
  return c.CustomCall(
      fn,
      operands=operands,
      shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
      operand_shapes_with_layout=operand_shapes)
# ?getrf: LU decomposition
cdef void lapack_sgetrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK sgetrf (float32 LU decomposition),
  # applied to each of b consecutive m x n matrices.
  #
  # data: b, m, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer factored in place, [1] pivot indices
  # (min(m, n) per matrix), [2] one info code per matrix.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Element counts are computed in size_t: `b * m * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>m * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(float))

  for i in range(b):
    sgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += matrix_elems
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgetrf", <void*>(lapack_sgetrf))
cdef void lapack_dgetrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK dgetrf (float64 LU decomposition),
  # applied to each of b consecutive m x n matrices.
  #
  # data: b, m, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer factored in place, [1] pivot indices
  # (min(m, n) per matrix), [2] one info code per matrix.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Element counts are computed in size_t: `b * m * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>m * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(double))

  for i in range(b):
    dgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += matrix_elems
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgetrf", <void*>(lapack_dgetrf))
cdef void lapack_cgetrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK cgetrf (complex64 LU decomposition),
  # applied to each of b consecutive m x n matrices.
  #
  # data: b, m, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer factored in place, [1] pivot indices
  # (min(m, n) per matrix), [2] one info code per matrix.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Element counts are computed in size_t: `b * m * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>m * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(float complex))

  for i in range(b):
    cgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += matrix_elems
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgetrf", <void*>(lapack_cgetrf))
cdef void lapack_zgetrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK zgetrf (complex128 LU decomposition),
  # applied to each of b consecutive m x n matrices.
  #
  # data: b, m, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer factored in place, [1] pivot indices
  # (min(m, n) per matrix), [2] one info code per matrix.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Element counts are computed in size_t: `b * m * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>m * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(double complex))

  for i in range(b):
    zgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += matrix_elems
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgetrf", <void*>(lapack_zgetrf))
def jax_getrf(c, a):
  """Emits a CustomCall that LU-factorizes `a` with the LAPACK ?getrf kernels.

  `a` must have at least two dimensions. Leading dimensions are treated as
  batch dimensions and collapsed into a single batch count for the kernel.
  The CustomCall result is a 3-tuple: the factored matrices, int32 pivots
  (min(m, n) per matrix) and an int32 info code per matrix.

  Raises:
    NotImplementedError: for dtypes other than float32, float64, complex64
      and complex128.
  """
  assert sizeof(int32_t) == sizeof(int)

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)

  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  kernels = (
      (np.float32, b"lapack_sgetrf"),
      (np.float64, b"lapack_dgetrf"),
      (np.complex64, b"lapack_cgetrf"),
      (np.complex128, b"lapack_zgetrf"),
  )
  for np_type, target in kernels:
    if dtype == np_type:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # Layouts are minor-to-major. The matrix layout lists the row dimension
  # first (column-major matrices); batch dimensions follow in reverse order.
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  matrix_layout = (num_bd, num_bd + 1) + batch_layout
  pivot_layout = tuple(range(num_bd, -1, -1))
  matrix_dims = batch_dims + (m, n)

  operands = (
      c.ConstantS32Scalar(batch_size),
      c.ConstantS32Scalar(m),
      c.ConstantS32Scalar(n),
      a)
  result_shape = Shape.tuple_shape((
      Shape.array_shape(dtype, matrix_dims, matrix_layout),
      Shape.array_shape(np.dtype(np.int32), batch_dims + (min(m, n),),
                        pivot_layout),
      Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout),
  ))
  operand_shapes = (
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(dtype, matrix_dims, matrix_layout),
  )
  return c.CustomCall(
      fn,
      operands=operands,
      shape_with_layout=result_shape,
      operand_shapes_with_layout=operand_shapes)
# ?potrf: Cholesky decomposition
cdef void lapack_spotrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK spotrf (float32 Cholesky decomposition)
  # for a single n x n matrix.
  #
  # data: lower, n (int32 scalars), then the input matrix.
  # out_tuple: [0] matrix buffer factored in place, [1] info code.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const float* a_in = <float*>(data[2])
  cdef char uplo = 'L' if lower else 'U'

  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    # Widen to size_t before multiplying: `n * n` evaluated in 32-bit int
    # arithmetic overflows for n > 46340.
    memcpy(a_out, a_in, <size_t>n * <size_t>n * sizeof(float))

  spotrf(&uplo, &n, a_out, &n, info)
register_cpu_custom_call_target(b"lapack_spotrf", <void*>(lapack_spotrf))
cdef void lapack_dpotrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK dpotrf (float64 Cholesky decomposition)
  # for a single n x n matrix.
  #
  # data: lower, n (int32 scalars), then the input matrix.
  # out_tuple: [0] matrix buffer factored in place, [1] info code.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const double* a_in = <double*>(data[2])
  cdef char uplo = 'L' if lower else 'U'

  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    # Widen to size_t before multiplying: `n * n` evaluated in 32-bit int
    # arithmetic overflows for n > 46340.
    memcpy(a_out, a_in, <size_t>n * <size_t>n * sizeof(double))

  dpotrf(&uplo, &n, a_out, &n, info)
register_cpu_custom_call_target(b"lapack_dpotrf", <void*>(lapack_dpotrf))
cdef void lapack_cpotrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK cpotrf (complex64 Cholesky
  # decomposition) for a single n x n matrix.
  #
  # data: lower, n (int32 scalars), then the input matrix.
  # out_tuple: [0] matrix buffer factored in place, [1] info code.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const float complex* a_in = <float complex*>(data[2])
  cdef char uplo = 'L' if lower else 'U'

  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    # Widen to size_t before multiplying: `n * n` evaluated in 32-bit int
    # arithmetic overflows for n > 46340.
    memcpy(a_out, a_in, <size_t>n * <size_t>n * sizeof(float complex))

  cpotrf(&uplo, &n, a_out, &n, info)
register_cpu_custom_call_target(b"lapack_cpotrf", <void*>(lapack_cpotrf))
cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK zpotrf (complex128 Cholesky
  # decomposition) for a single n x n matrix.
  #
  # data: lower, n (int32 scalars), then the input matrix.
  # out_tuple: [0] matrix buffer factored in place, [1] info code.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const double complex* a_in = <double complex*>(data[2])
  cdef char uplo = 'L' if lower else 'U'

  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    # Widen to size_t before multiplying: `n * n` evaluated in 32-bit int
    # arithmetic overflows for n > 46340.
    memcpy(a_out, a_in, <size_t>n * <size_t>n * sizeof(double complex))

  zpotrf(&uplo, &n, a_out, &n, info)
register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))
def jax_potrf(c, a, lower=False):
  """Emits a CustomCall that Cholesky-factorizes `a` with LAPACK ?potrf.

  `a` must be a square two-dimensional array. The CustomCall result is a
  2-tuple of the factored (n, n) matrix and an int32 scalar info code.

  Raises:
    ValueError: if `a` is not square.
    NotImplementedError: for dtypes other than float32, float64, complex64
      and complex128.
  """
  assert sizeof(int32_t) == sizeof(int)

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  m, n = a_shape.dimensions()
  if m != n:
    raise ValueError("potrf expects a square matrix, got {}".format(a_shape))

  kernels = (
      (np.float32, b"lapack_spotrf"),
      (np.float64, b"lapack_dpotrf"),
      (np.complex64, b"lapack_cpotrf"),
      (np.complex128, b"lapack_zpotrf"),
  )
  for np_type, target in kernels:
    if dtype == np_type:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  operands = (c.ConstantS32Scalar(int(lower)), c.ConstantS32Scalar(n), a)
  result_shape = Shape.tuple_shape((
      Shape.array_shape(dtype, (n, n), (0, 1)),
      Shape.array_shape(np.dtype(np.int32), (), ()),
  ))
  operand_shapes = (
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(dtype, (n, n), (0, 1)),
  )
  return c.CustomCall(
      fn,
      operands=operands,
      shape_with_layout=result_shape,
      operand_shapes_with_layout=operand_shapes)
# ?gesdd: Singular value decomposition
cdef int gesdd_iwork_size(int m, int n) nogil:
  # Length of the integer workspace passed to ?gesdd: 8 * min(m, n).
  cdef int smaller = m if m < n else n
  return 8 * smaller
cdef int cgesdd_rwork_size(int m, int n, int compute_uv) nogil:
  # Length of the real workspace (rwork) passed to cgesdd/zgesdd.
  # With mn = min(m, n) and mx = max(m, n):
  #   compute_uv == 0: 7 * mn
  #   otherwise:       max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn)
  cdef int smaller = m if m < n else n
  cdef int larger, quadratic_bound, mixed_bound
  if compute_uv == 0:
    return 7 * smaller
  larger = n if m < n else m
  quadratic_bound = 5 * smaller * smaller + 5 * smaller
  mixed_bound = 2 * larger * smaller + 2 * smaller * smaller + smaller
  if mixed_bound > quadratic_bound:
    return mixed_bound
  return quadratic_bound
cdef void lapack_sgesdd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK sgesdd (float32 singular value
  # decomposition) for a single m x n matrix.
  #
  # data: full_matrices, compute_uv, m, n (int32 scalars), then the matrix.
  # out_tuple: [0] working copy of the matrix, [1] singular values, [2] u,
  # [3] vt, [4] info code, [5] integer workspace.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int m = (<int32_t*>(data[2]))[0]
  cdef int n = (<int32_t*>(data[3]))[0]
  cdef float* a_in = <float*>(data[4])

  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* s = <float*>(out[1])
  cdef float* u = <float*>(out[2])
  cdef float* vt = <float*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])

  if a_out != a_in:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(a_out, a_in, <size_t>m * <size_t>n * sizeof(float))

  # Select the job code: 'N' when u/vt are not requested, otherwise 'S'
  # for reduced matrices and 'A' for full matrices.
  cdef char jobz = 'A'
  if job_opt_compute_uv == 0:
    jobz = 'N'
  else:
    if job_opt_full_matrices == 0:
      jobz = 'S'

  cdef int lda = m
  cdef int ldu = m
  cdef int ldvt = n
  if job_opt_full_matrices == 0:
    ldvt = min(m, n)

  # First perform a workspace query (lwork = -1) to get the optimal lwork.
  # NB: We perform a workspace query with malloc and free for the work array,
  # because it is officially recommended in the LAPACK documentation.
  cdef float wkopt = 0
  cdef int lwork = -1
  sgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, &wkopt, &lwork, iwork, info)
  lwork = <int> wkopt

  # Now compute the actual SVD.
  # NOTE(review): the malloc result is not checked before it is used.
  cdef float* work = <float *> malloc(lwork * sizeof(float))
  sgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info)
  free(work)
register_cpu_custom_call_target(b"lapack_sgesdd", <void*>(lapack_sgesdd))
cdef void lapack_dgesdd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK dgesdd (float64 singular value
  # decomposition) for a single m x n matrix.
  #
  # data: full_matrices, compute_uv, m, n (int32 scalars), then the matrix.
  # out_tuple: [0] working copy of the matrix, [1] singular values, [2] u,
  # [3] vt, [4] info code, [5] integer workspace.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int m = (<int32_t*>(data[2]))[0]
  cdef int n = (<int32_t*>(data[3]))[0]
  cdef double* a_in = <double*>(data[4])

  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* s = <double*>(out[1])
  cdef double* u = <double*>(out[2])
  cdef double* vt = <double*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])

  if a_out != a_in:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(a_out, a_in, <size_t>m * <size_t>n * sizeof(double))

  # Select the job code: 'N' when u/vt are not requested, otherwise 'S'
  # for reduced matrices and 'A' for full matrices.
  cdef char jobz = 'A'
  if job_opt_compute_uv == 0:
    jobz = 'N'
  else:
    if job_opt_full_matrices == 0:
      jobz = 'S'

  cdef int lda = m
  cdef int ldu = m
  cdef int ldvt = n
  if job_opt_full_matrices == 0:
    ldvt = min(m, n)

  # First perform a workspace query (lwork = -1) to get the optimal lwork.
  # NB: We perform a workspace query with malloc and free for the work array,
  # because it is officially recommended in the LAPACK documentation.
  cdef double wkopt = 0
  cdef int lwork = -1
  dgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, &wkopt, &lwork, iwork, info)
  lwork = <int> wkopt

  # Now compute the actual SVD.
  # NOTE(review): the malloc result is not checked before it is used.
  cdef double* work = <double *> malloc(lwork * sizeof(double))
  dgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info)
  free(work)
register_cpu_custom_call_target(b"lapack_dgesdd", <void*>(lapack_dgesdd))
cdef void lapack_cgesdd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK cgesdd (complex64 singular value
  # decomposition) for a single m x n matrix.
  #
  # data: full_matrices, compute_uv, m, n (int32 scalars), then the matrix.
  # out_tuple: [0] working copy of the matrix, [1] float32 singular values,
  # [2] u, [3] vt, [4] info code, [5] integer workspace, [6] real workspace.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int m = (<int32_t*>(data[2]))[0]
  cdef int n = (<int32_t*>(data[3]))[0]
  cdef float complex* a_in = <float complex*>(data[4])

  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float* s = <float*>(out[1])
  cdef float complex* u = <float complex*>(out[2])
  cdef float complex* vt = <float complex*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef float* rwork = <float*>(out[6])

  if a_out != a_in:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(a_out, a_in, <size_t>m * <size_t>n * sizeof(float complex))

  # Select the job code: 'N' when u/vt are not requested, otherwise 'S'
  # for reduced matrices and 'A' for full matrices.
  cdef char jobz = 'A'
  if job_opt_compute_uv == 0:
    jobz = 'N'
  else:
    if job_opt_full_matrices == 0:
      jobz = 'S'

  cdef int lda = m
  cdef int ldu = m
  cdef int ldvt = n
  if job_opt_full_matrices == 0:
    ldvt = min(m, n)

  # First perform a workspace query (lwork = -1) to get the optimal lwork.
  # NB: We perform a workspace query with malloc and free for the work array,
  # because it is officially recommended in the LAPACK documentation.
  cdef float complex wkopt = 0
  cdef int lwork = -1
  cgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, &wkopt, &lwork, rwork, iwork, info)
  # The optimal size is read from the real part of the query result.
  lwork = <int>(wkopt.real)

  # Now compute the actual SVD.
  # NOTE(review): the malloc result is not checked before it is used.
  cdef float complex* work = <float complex*> malloc(lwork * sizeof(float complex))
  cgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, iwork, info)
  free(work)
register_cpu_custom_call_target(b"lapack_cgesdd", <void*>(lapack_cgesdd))
cdef void lapack_zgesdd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK zgesdd (complex128 singular value
  # decomposition) for a single m x n matrix.
  #
  # data: full_matrices, compute_uv, m, n (int32 scalars), then the matrix.
  # out_tuple: [0] working copy of the matrix, [1] float64 singular values,
  # [2] u, [3] vt, [4] info code, [5] integer workspace, [6] real workspace.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int m = (<int32_t*>(data[2]))[0]
  cdef int n = (<int32_t*>(data[3]))[0]
  cdef double complex* a_in = <double complex*>(data[4])

  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double* s = <double*>(out[1])
  cdef double complex* u = <double complex*>(out[2])
  cdef double complex* vt = <double complex*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef double* rwork = <double*>(out[6])

  if a_out != a_in:
    # Widen to size_t before multiplying: `m * n` evaluated in 32-bit int
    # arithmetic overflows for matrices with more than 2**31 - 1 elements.
    memcpy(a_out, a_in, <size_t>m * <size_t>n * sizeof(double complex))

  # Select the job code: 'N' when u/vt are not requested, otherwise 'S'
  # for reduced matrices and 'A' for full matrices.
  cdef char jobz = 'A'
  if job_opt_compute_uv == 0:
    jobz = 'N'
  else:
    if job_opt_full_matrices == 0:
      jobz = 'S'

  cdef int lda = m
  cdef int ldu = m
  cdef int ldvt = n
  if job_opt_full_matrices == 0:
    ldvt = min(m, n)

  # First perform a workspace query (lwork = -1) to get the optimal lwork.
  # NB: We perform a workspace query with malloc and free for the work array,
  # because it is officially recommended in the LAPACK documentation.
  cdef double complex wkopt = 0
  cdef int lwork = -1
  zgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, &wkopt, &lwork, rwork, iwork, info)
  # The optimal size is read from the real part of the query result.
  lwork = <int>(wkopt.real)

  # Now compute the actual SVD.
  # NOTE(review): the malloc result is not checked before it is used.
  cdef double complex* work = <double complex*> malloc(lwork * sizeof(double complex))
  zgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, iwork, info)
  free(work)
register_cpu_custom_call_target(b"lapack_zgesdd", <void*>(lapack_zgesdd))
def jax_gesdd(c, a, full_matrices=True, compute_uv=True):
  """Emits a CustomCall computing the SVD of `a` with LAPACK ?gesdd.

  `a` must be two-dimensional with shape (m, n). The kernel output also
  carries a working copy of the matrix and workspace buffers; only tuple
  elements 1 through 4 (singular values, u, vt, info) are returned.

  Raises:
    NotImplementedError: for dtypes other than float32, float64, complex64
      and complex128.
  """
  assert sizeof(int32_t) == sizeof(int)

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  m, n = a_shape.dimensions()

  if dtype == np.float32:
    fn, singular_vals_dtype, needs_rwork = b"lapack_sgesdd", np.float32, False
  elif dtype == np.float64:
    fn, singular_vals_dtype, needs_rwork = b"lapack_dgesdd", np.float64, False
  elif dtype == np.complex64:
    fn, singular_vals_dtype, needs_rwork = b"lapack_cgesdd", np.float32, True
  elif dtype == np.complex128:
    fn, singular_vals_dtype, needs_rwork = b"lapack_zgesdd", np.float64, True
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # Every kernel takes an integer workspace; the complex kernels take an
  # additional real workspace whose dtype matches the singular values.
  workspace = (Shape.array_shape(np.dtype(np.int32),
                                 (gesdd_iwork_size(m, n),), (0,)),)
  if needs_rwork:
    workspace = workspace + (
        Shape.array_shape(np.dtype(singular_vals_dtype),
                          (cgesdd_rwork_size(m, n, int(compute_uv)),),
                          (0,)),)

  operands = (
      c.ConstantS32Scalar(int(full_matrices)),
      c.ConstantS32Scalar(int(compute_uv)),
      c.ConstantS32Scalar(m),
      c.ConstantS32Scalar(n),
      a)

  rank = min(m, n)
  u_cols = m if full_matrices else rank
  vt_rows = n if full_matrices else rank
  result_shape = Shape.tuple_shape((
      Shape.array_shape(dtype, (m, n), (0, 1)),
      Shape.array_shape(np.dtype(singular_vals_dtype), (rank,), (0,)),
      Shape.array_shape(dtype, (m, u_cols), (0, 1)),
      Shape.array_shape(dtype, (vt_rows, n), (0, 1)),
      Shape.array_shape(np.dtype(np.int32), (), ())) + workspace)
  operand_shapes = (
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(dtype, (m, n), (0, 1)),
  )
  out = c.CustomCall(
      fn,
      operands=operands,
      shape_with_layout=result_shape,
      operand_shapes_with_layout=operand_shapes)

  # Drop element 0 (the working copy of a) and the trailing workspaces.
  return c.Tuple(*[c.GetTupleElement(out, idx) for idx in range(1, 5)])
# ?syevd / ?heevd: Symmetric (Hermitian) eigendecomposition
# Workspace sizes, taken from the LAPACK documentation.
cdef int syevd_work_size(int n) nogil:
  # Length of the `work` array passed to ssyevd/dsyevd: 1 + 6*n + 2*n*n.
  cdef int quadratic = 2 * n * n
  cdef int linear = 6 * n
  return quadratic + linear + 1
cdef int syevd_iwork_size(int n) nogil:
  # Length of the `iwork` array passed to the ?syevd/?heevd kernels: 3 + 5*n.
  cdef int linear = 5 * n
  return linear + 3
cdef void lapack_ssyevd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK ssyevd (float32 symmetric
  # eigendecomposition, jobz = 'V'), applied to b consecutive n x n matrices.
  #
  # data: lower, b, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer the routine works on in place,
  # [1] eigenvalues (n per matrix), [2] one info code per matrix,
  # [3] work and [4] iwork scratch buffers shared by all batch entries.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* w_out = <float*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef float* work = <float*>(out[3])
  cdef int* iwork = <int*>(out[4])
  # Element counts are computed in size_t: `b * n * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>n * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(float))

  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'

  cdef int lwork = syevd_work_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    ssyevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork,
           info_out)
    a_out += matrix_elems
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_ssyevd", <void*>(lapack_ssyevd))
cdef void lapack_dsyevd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK dsyevd (float64 symmetric
  # eigendecomposition, jobz = 'V'), applied to b consecutive n x n matrices.
  #
  # data: lower, b, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer the routine works on in place,
  # [1] eigenvalues (n per matrix), [2] one info code per matrix,
  # [3] work and [4] iwork scratch buffers shared by all batch entries.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* w_out = <double*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef double* work = <double*>(out[3])
  cdef int* iwork = <int*>(out[4])
  # Element counts are computed in size_t: `b * n * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>n * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(double))

  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'

  cdef int lwork = syevd_work_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    dsyevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork,
           info_out)
    a_out += matrix_elems
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_dsyevd", <void*>(lapack_dsyevd))
# Workspace sizes, taken from the LAPACK documentation.
cdef int heevd_work_size(int n) nogil:
  # Length of the complex `work` array passed to cheevd/zheevd:
  # 1 + 2*n + n*n.
  cdef int quadratic = n * n
  cdef int linear = 2 * n
  return quadratic + linear + 1
cdef int heevd_rwork_size(int n) nogil:
  # Length of the real `rwork` array passed to cheevd/zheevd:
  # 1 + 5*n + 2*n*n.
  cdef int quadratic = 2 * n * n
  cdef int linear = 5 * n
  return quadratic + linear + 1
cdef void lapack_cheevd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK cheevd (complex64 Hermitian
  # eigendecomposition, jobz = 'V'), applied to b consecutive n x n matrices.
  #
  # data: lower, b, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer the routine works on in place,
  # [1] float32 eigenvalues (n per matrix), [2] one info code per matrix,
  # [3] work, [4] rwork and [5] iwork scratch buffers shared by all batch
  # entries.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float* w_out = <float*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef float complex* work = <float complex*>(out[3])
  cdef float* rwork = <float*>(out[4])
  cdef int* iwork = <int*>(out[5])
  # Element counts are computed in size_t: `b * n * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>n * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(float complex))

  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'

  cdef int lwork = heevd_work_size(n)
  cdef int lrwork = heevd_rwork_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    cheevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork,
           iwork, &liwork, info_out)
    a_out += matrix_elems
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_cheevd", <void*>(lapack_cheevd))
cdef void lapack_zheevd(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK zheevd (complex128 Hermitian
  # eigendecomposition, jobz = 'V'), applied to b consecutive n x n matrices.
  #
  # data: lower, b, n (int32 scalars), then the input matrices.
  # out_tuple: [0] matrix buffer the routine works on in place,
  # [1] float64 eigenvalues (n per matrix), [2] one info code per matrix,
  # [3] work, [4] rwork and [5] iwork scratch buffers shared by all batch
  # entries.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])

  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double* w_out = <double*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef double complex* work = <double complex*>(out[3])
  cdef double* rwork = <double*>(out[4])
  cdef int* iwork = <int*>(out[5])
  # Element counts are computed in size_t: `b * n * n` in 32-bit int
  # arithmetic overflows once the batch holds more than 2**31 - 1 elements.
  cdef size_t matrix_elems = <size_t>n * <size_t>n
  cdef int i
  if a_out != a_in:
    memcpy(a_out, a_in, <size_t>b * matrix_elems * sizeof(double complex))

  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'

  cdef int lwork = heevd_work_size(n)
  cdef int lrwork = heevd_rwork_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    zheevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork,
           iwork, &liwork, info_out)
    a_out += matrix_elems
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_zheevd", <void*>(lapack_zheevd))
def jax_syevd(c, a, lower=False):
  """Emits a CustomCall running LAPACK ?syevd / ?heevd on `a`.

  `a` must have at least two dimensions with square trailing dimensions;
  leading dimensions are treated as batch dimensions. Returns a tuple of
  kernel outputs 0 through 2: the matrix buffer the routine operated on in
  place, the eigenvalues and the per-matrix info codes. The trailing
  workspace outputs are dropped.

  Raises:
    NotImplementedError: for dtypes other than float32, float64, complex64
      and complex128.
  """
  assert sizeof(int32_t) == sizeof(int)

  a_shape = c.GetShape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)

  batch_size = 1
  for extent in batch_dims:
    batch_size *= extent

  # Minor-to-major layouts: column-major matrices, batch dimensions reversed.
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  layout = (num_bd, num_bd + 1) + batch_layout

  if dtype == np.float32:
    fn, eigvals_type, is_complex = b"lapack_ssyevd", np.float32, False
  elif dtype == np.float64:
    fn, eigvals_type, is_complex = b"lapack_dsyevd", np.float64, False
  elif dtype == np.complex64:
    fn, eigvals_type, is_complex = b"lapack_cheevd", np.float32, True
  elif dtype == np.complex128:
    fn, eigvals_type, is_complex = b"lapack_zheevd", np.float64, True
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # The real kernels take (work, iwork); the complex kernels take
  # (work, rwork, iwork), with rwork in the eigenvalue dtype.
  if is_complex:
    workspace = (
        Shape.array_shape(dtype, (heevd_work_size(n),), (0,)),
        Shape.array_shape(np.dtype(eigvals_type),
                          (heevd_rwork_size(n),), (0,)),
        Shape.array_shape(np.dtype(np.int32),
                          (syevd_iwork_size(n),), (0,)))
  else:
    workspace = (
        Shape.array_shape(dtype, (syevd_work_size(n),), (0,)),
        Shape.array_shape(np.dtype(np.int32),
                          (syevd_iwork_size(n),), (0,)))

  operands = (
      c.ConstantS32Scalar(1 if lower else 0),
      c.ConstantS32Scalar(batch_size),
      c.ConstantS32Scalar(n),
      a)
  result_shape = Shape.tuple_shape((
      Shape.array_shape(dtype, dims, layout),
      Shape.array_shape(np.dtype(eigvals_type), batch_dims + (n,),
                        tuple(range(num_bd, -1, -1))),
      Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout))
      + workspace)
  operand_shapes = (
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(np.dtype(np.int32), (), ()),
      Shape.array_shape(dtype, dims, layout),
  )
  out = c.CustomCall(
      fn,
      operands=operands,
      shape_with_layout=result_shape,
      operand_shapes_with_layout=operand_shapes)
  return c.Tuple(*[c.GetTupleElement(out, idx) for idx in range(3)])
# geev: Nonsymmetric eigendecomposition
# LAPACK uses a packed representation to represent a mixture of real
# eigenvectors and complex conjugate pairs. This helper unpacks the
# representation into regular complex matrices.
cdef void _unpack_float_eigenvectors(
    int n, const float* im_eigenvalues, const float* packed,
    float complex* unpacked) nogil:
  # Expands the packed real eigenvector matrix `packed` (n columns of n
  # floats) into the complex matrix `unpacked`.
  #
  # Column j whose eigenvalue has zero imaginary part is copied with a zero
  # imaginary part. Otherwise columns j and j + 1 are read as the real and
  # imaginary parts of a conjugate pair: column j becomes re + i*im and
  # column j + 1 becomes re - i*im.
  cdef float re, im, w_im
  cdef int j, k
  # Offsets are kept in Py_ssize_t so that j * n cannot overflow 32 bits.
  cdef Py_ssize_t col, next_col
  j = 0
  while j < n:
    w_im = im_eigenvalues[j]
    col = <Py_ssize_t>j * n
    # `w_im != w_im` holds only for NaN. A NaN, or a non-zero imaginary part
    # in the last column, has no partner column; treating it as a pair would
    # read and write one column past the end of both buffers, so such
    # columns are copied as real instead.
    if w_im == 0. or w_im != w_im or j + 1 >= n:
      for k in range(n):
        unpacked[col + k].real = packed[col + k]
        unpacked[col + k].imag = 0.
      j += 1
    else:
      next_col = col + n
      for k in range(n):
        re = packed[col + k]
        im = packed[next_col + k]
        unpacked[col + k].real = unpacked[next_col + k].real = re
        unpacked[col + k].imag = im
        unpacked[next_col + k].imag = -im
      j += 2
cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
  # CustomCall kernel wrapping LAPACK sgeev (float32 nonsymmetric
  # eigendecomposition with left and right eigenvectors), applied to b
  # consecutive n x n matrices.
  #
  # data: b, n (int32 scalars), then the input matrices.
  # out_tuple: [0] a_work, [1] vl_work, [2] vr_work are single-matrix scratch
  # buffers reused for every batch entry; [3] real and [4] imaginary parts of
  # the eigenvalues, [5] left and [6] right eigenvectors as complex matrices,
  # [7] one info code per matrix.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const float* a_in = <float*>(data[2])

  cdef void** out = <void**>(out_tuple)
  cdef float* a_work = <float*>(out[0])
  cdef float* vl_work = <float*>(out[1])
  cdef float* vr_work = <float*>(out[2])
  cdef float* wr_out = <float*>(out[3])
  cdef float* wi_out = <float*>(out[4])
  cdef float complex* vl_out = <float complex*>(out[5])
  cdef float complex* vr_out = <float complex*>(out[6])
  cdef int* info_out = <int*>(out[7])
  # Computed in size_t: `n * n` in 32-bit int arithmetic overflows for
  # n > 46340.
  cdef size_t matrix_elems = <size_t>n * <size_t>n
  cdef int i

  cdef char jobvlr = 'V'
  # Workspace query: with lwork = -1 sgeev only reports the optimal size.
  cdef float work_query
  cdef int lwork = -1
  sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
        vr_work, &n, &work_query, &lwork, info_out)
  lwork = <int>(work_query)
  # NOTE(review): the malloc result is not checked before it is used.
  cdef float* work = <float*> malloc(lwork * sizeof(float))

  for i in range(b):
    # The input is copied into scratch before each call; sgeev is given
    # a_work rather than the input buffer.
    memcpy(a_work, a_in, matrix_elems * sizeof(float))
    sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
          vr_work, &n, work, &lwork, info_out)
    _unpack_float_eigenvectors(n, wi_out, vl_work, vl_out)
    _unpack_float_eigenvectors(n, wi_out, vr_work, vr_out)

    a_in += matrix_elems
    wr_out += n
    wi_out += n
    vl_out += matrix_elems
    vr_out += matrix_elems
    info_out += 1
  free(work)
register_cpu_custom_call_target(b"lapack_sgeev", <void*>(lapack_sgeev))
cdef void _unpack_double_eigenvectors(
    int n, const double* im_eigenvalues, const double* packed,
    double complex* unpacked) nogil:
  # Expands the packed real eigenvector matrix `packed` (n columns of n
  # doubles) into the complex matrix `unpacked`.
  #
  # Column j whose eigenvalue has zero imaginary part is copied with a zero
  # imaginary part. Otherwise columns j and j + 1 are read as the real and
  # imaginary parts of a conjugate pair: column j becomes re + i*im and
  # column j + 1 becomes re - i*im.
  cdef double re, im, w_im
  cdef int j, k
  # Offsets are kept in Py_ssize_t so that j * n cannot overflow 32 bits.
  cdef Py_ssize_t col, next_col
  j = 0
  while j < n:
    w_im = im_eigenvalues[j]
    col = <Py_ssize_t>j * n
    # `w_im != w_im` holds only for NaN. A NaN, or a non-zero imaginary part
    # in the last column, has no partner column; treating it as a pair would
    # read and write one column past the end of both buffers, so such
    # columns are copied as real instead.
    if w_im == 0. or w_im != w_im or j + 1 >= n:
      for k in range(n):
        unpacked[col + k].real = packed[col + k]
        unpacked[col + k].imag = 0.
      j += 1
    else:
      next_col = col + n
      for k in range(n):
        re = packed[col + k]
        im = packed[next_col + k]
        unpacked[col + k].real = unpacked[next_col + k].real = re
        unpacked[col + k].imag = im
        unpacked[next_col + k].imag = -im
      j += 2
cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
  # CustomCall shim running LAPACK dgeev (eigenvalues plus left and right
  # eigenvectors of a general real matrix) over a batch of float64 matrices.
  #
  # Operands (`data`):
  #   data[0]: int32 batch count b.
  #   data[1]: int32 matrix order n.
  #   data[2]: b input matrices of n * n doubles each.
  #
  # Outputs (`out_tuple`, an array of buffer pointers):
  #   out[0]: n * n scratch copy of the current matrix (dgeev overwrites it).
  #   out[1], out[2]: n * n scratch for the packed left / right eigenvectors.
  #   out[3], out[4]: real / imaginary eigenvalue parts, n per batch element.
  #   out[5], out[6]: complex left / right eigenvectors, n * n per element.
  #   out[7]: LAPACK info status, one int per batch element.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const double* a_in = <double*>(data[2])

  cdef void** out = <void**>(out_tuple)
  cdef double* a_work = <double*>(out[0])
  cdef double* vl_work = <double*>(out[1])
  cdef double* vr_work = <double*>(out[2])
  cdef double* wr_out = <double*>(out[3])
  cdef double* wi_out = <double*>(out[4])
  cdef double complex* vl_out = <double complex*>(out[5])
  cdef double complex* vr_out = <double complex*>(out[6])
  cdef int* info_out = <int*>(out[7])

  cdef char jobvlr = 'V'
  cdef double work_query
  cdef int lwork = -1
  cdef int query_info
  cdef int i
  cdef double* work

  # Workspace-size query (lwork == -1). Its status goes to a local variable:
  # with an empty batch `info_out` has no elements to write to.
  dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
        vr_work, &n, &work_query, &lwork, &query_info)
  lwork = <int>(work_query)
  work = <double*> malloc(lwork * sizeof(double))
  if work == NULL:
    # Report the allocation failure through a non-zero info for every batch
    # element instead of handing LAPACK a NULL workspace.
    # NOTE(review): confirm callers treat any non-zero info as failure.
    for i in range(b):
      info_out[i] = -1
    return

  for i in range(b):
    # Work on a copy so the operand buffer is left untouched.
    memcpy(a_work, a_in, n * n * sizeof(double))
    dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
          vr_work, &n, work, &lwork, info_out)
    _unpack_double_eigenvectors(n, wi_out, vl_work, vl_out)
    _unpack_double_eigenvectors(n, wi_out, vr_work, vr_out)

    # Advance every per-batch pointer to the next matrix.
    a_in += n * n
    wr_out += n
    wi_out += n
    vl_out += n * n
    vr_out += n * n
    info_out += 1
  free(work)


register_cpu_custom_call_target(b"lapack_dgeev", <void*>(lapack_dgeev))
cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
  # CustomCall shim running LAPACK cgeev (eigenvalues plus left and right
  # eigenvectors of a general complex matrix) over a batch of complex64
  # matrices. The eigenvectors are written straight into the output buffers.
  #
  # Operands (`data`):
  #   data[0]: int32 batch count b.
  #   data[1]: int32 matrix order n.
  #   data[2]: b input matrices of n * n complex floats each.
  #
  # Outputs (`out_tuple`, an array of buffer pointers):
  #   out[0]: n * n scratch copy of the current matrix (cgeev overwrites it).
  #   out[1]: real-valued scratch passed as cgeev's rwork argument.
  #   out[2]: complex eigenvalues, n per batch element.
  #   out[3], out[4]: left / right eigenvectors, n * n per batch element.
  #   out[5]: LAPACK info status, one int per batch element.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const float complex* a_in = <float complex*>(data[2])

  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_work = <float complex*>(out[0])
  cdef float* r_work = <float*>(out[1])
  cdef float complex* w_out = <float complex*>(out[2])
  cdef float complex* vl_out = <float complex*>(out[3])
  cdef float complex* vr_out = <float complex*>(out[4])
  cdef int* info_out = <int*>(out[5])

  cdef char jobvlr = 'V'
  cdef float complex work_query
  cdef int lwork = -1
  cdef int query_info
  cdef int i
  cdef float complex* work

  # Workspace-size query (lwork == -1). Its status goes to a local variable:
  # with an empty batch `info_out` has no elements to write to.
  cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
        vr_out, &n, &work_query, &lwork, r_work, &query_info)
  lwork = <int>(work_query.real)
  work = <float complex*>malloc(lwork * sizeof(float complex))
  if work == NULL:
    # Report the allocation failure through a non-zero info for every batch
    # element instead of handing LAPACK a NULL workspace.
    # NOTE(review): confirm callers treat any non-zero info as failure.
    for i in range(b):
      info_out[i] = -1
    return

  for i in range(b):
    # Work on a copy so the operand buffer is left untouched.
    memcpy(a_work, a_in, n * n * sizeof(float complex))
    cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
          work, &lwork, r_work, info_out)

    # Advance every per-batch pointer to the next matrix.
    a_in += n * n
    w_out += n
    vl_out += n * n
    vr_out += n * n
    info_out += 1
  free(work)


register_cpu_custom_call_target(b"lapack_cgeev", <void*>(lapack_cgeev))
cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
  # CustomCall shim running LAPACK zgeev (eigenvalues plus left and right
  # eigenvectors of a general complex matrix) over a batch of complex128
  # matrices. The eigenvectors are written straight into the output buffers.
  #
  # Operands (`data`):
  #   data[0]: int32 batch count b.
  #   data[1]: int32 matrix order n.
  #   data[2]: b input matrices of n * n complex doubles each.
  #
  # Outputs (`out_tuple`, an array of buffer pointers):
  #   out[0]: n * n scratch copy of the current matrix (zgeev overwrites it).
  #   out[1]: real-valued scratch passed as zgeev's rwork argument.
  #   out[2]: complex eigenvalues, n per batch element.
  #   out[3], out[4]: left / right eigenvectors, n * n per batch element.
  #   out[5]: LAPACK info status, one int per batch element.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef const double complex* a_in = <double complex*>(data[2])

  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_work = <double complex*>(out[0])
  cdef double* r_work = <double*>(out[1])
  cdef double complex* w_out = <double complex*>(out[2])
  cdef double complex* vl_out = <double complex*>(out[3])
  cdef double complex* vr_out = <double complex*>(out[4])
  cdef int* info_out = <int*>(out[5])

  cdef char jobvlr = 'V'
  cdef double complex work_query
  cdef int lwork = -1
  cdef int query_info
  cdef int i
  cdef double complex* work

  # Workspace-size query (lwork == -1). Its status goes to a local variable:
  # with an empty batch `info_out` has no elements to write to.
  zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
        vr_out, &n, &work_query, &lwork, r_work, &query_info)
  lwork = <int>(work_query.real)
  work = <double complex*>malloc(lwork * sizeof(double complex))
  if work == NULL:
    # Report the allocation failure through a non-zero info for every batch
    # element instead of handing LAPACK a NULL workspace.
    # NOTE(review): confirm callers treat any non-zero info as failure.
    for i in range(b):
      info_out[i] = -1
    return

  for i in range(b):
    # Work on a copy so the operand buffer is left untouched.
    memcpy(a_work, a_in, n * n * sizeof(double complex))
    zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
          work, &lwork, r_work, info_out)

    # Advance every per-batch pointer to the next matrix.
    a_in += n * n
    w_out += n
    vl_out += n * n
    vr_out += n * n
    info_out += 1
  free(work)


register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
def jax_geev(c, a):
  """Emits a CustomCall to the LAPACK ?geev kernel matching the dtype of `a`.

  Args:
    c: the XLA computation builder used to emit operations.
    a: operand whose two trailing dimensions form square matrices; any
      leading dimensions are treated as batch dimensions.

  Returns:
    A tuple op `(w, vl, vr, info)`: complex eigenvalues, left eigenvectors,
    right eigenvectors, and the per-matrix LAPACK info status.

  Raises:
    NotImplementedError: if the element type is not float32, float64,
      complex64 or complex128.
  """
  # The kernels read the scalar operands through int32_t pointers into `int`.
  assert sizeof(int32_t) == sizeof(int)

  operand_shape = c.GetShape(a)
  dtype = operand_shape.element_type()
  dims = operand_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n

  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = 1
  for extent in batch_dims:
    batch_count *= extent

  # Minor-to-major orders: batch dimensions are row-major, while each matrix
  # is stored column-major as LAPACK expects.
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  matrix_layout = (num_bd, num_bd + 1) + batch_layout
  vector_layout = tuple(range(num_bd, -1, -1))

  # (operand dtype, kernel name, is real, real dtype, eigenvector dtype).
  # Matched with `==`, in the same order as an if/elif chain would test them.
  kernels = (
      (np.float32, b"lapack_sgeev", True, np.float32, np.complex64),
      (np.float64, b"lapack_dgeev", True, np.float64, np.complex128),
      (np.complex64, b"lapack_cgeev", False, np.float32, np.complex64),
      (np.complex128, b"lapack_zgeev", False, np.float64, np.complex128),
  )
  for candidate, fn, real, real_type, eigvecs_type in kernels:
    if dtype == candidate:
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  if real:
    # Scratch copy of the matrix plus packed left/right eigenvector buffers;
    # eigenvalues come back as separate real and imaginary arrays.
    workspaces = tuple([
        Shape.array_shape(np.dtype(real_type), (n, n), (0, 1))
        for _ in range(3)])
    eigvals = tuple([
        Shape.array_shape(np.dtype(real_type), batch_dims + (n,),
                          vector_layout)
        for _ in range(2)])
  else:
    # Scratch copy of the matrix plus the real-valued rwork buffer;
    # eigenvalues come back as a single complex array.
    workspaces = (
        Shape.array_shape(np.dtype(eigvecs_type), (n, n), (0, 1)),
        Shape.array_shape(np.dtype(real_type), (2 * n,), (0,)))
    eigvals = (
        Shape.array_shape(np.dtype(eigvecs_type), batch_dims + (n,),
                          vector_layout),)

  eigvecs_and_info = (
      Shape.array_shape(np.dtype(eigvecs_type), dims, matrix_layout),
      Shape.array_shape(np.dtype(eigvecs_type), dims, matrix_layout),
      Shape.array_shape(np.dtype(np.int32), batch_dims, batch_layout))

  out = c.CustomCall(
      fn,
      operands=(c.ConstantS32Scalar(batch_count), c.ConstantS32Scalar(n), a),
      shape_with_layout=Shape.tuple_shape(
          workspaces + eigvals + eigvecs_and_info),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, dims, matrix_layout),
      ))

  # Result tuple order: workspaces, then eigenvalues, then vl, vr and info.
  first_eigval = len(workspaces)
  if real:
    w = c.Complex(c.GetTupleElement(out, first_eigval),
                  c.GetTupleElement(out, first_eigval + 1))
  else:
    w = c.GetTupleElement(out, first_eigval)
  first_eigvec = first_eigval + len(eigvals)
  return c.Tuple(
      w,
      c.GetTupleElement(out, first_eigvec),
      c.GetTupleElement(out, first_eigvec + 1),
      c.GetTupleElement(out, first_eigvec + 2))