rocm_jax/jaxlib/lapack.pyx
Peter Hawkins 7f4e115a6a [XLA:Python] Validate shapes in Python bindings to avoid crashes.
[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32.

Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib.

PiperOrigin-RevId: 337367394
2020-10-15 13:10:04 -07:00

1815 lines
60 KiB
Cython

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: language_level=2
# distutils: language = c++
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
# via CustomCallWithLayout.
from __future__ import print_function
# Import std::isnan so NaN checks can be made from nogil code.
cdef extern from "<cmath>" namespace "std":
  bint isnan(float x) nogil
  bint isnan(double x) nogil
from libc.stdlib cimport malloc, free
from libc.stdint cimport int32_t, int64_t, uint8_t
from libc.string cimport memcpy
from libcpp cimport bool as bool_t
from libcpp.string cimport string
from cpython.pycapsule cimport PyCapsule_New
from scipy.linalg.cython_blas cimport strsm, dtrsm, ctrsm, ztrsm
from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, zgetrf
from scipy.linalg.cython_lapack cimport sgeqrf, dgeqrf, cgeqrf, zgeqrf
from scipy.linalg.cython_lapack cimport sorgqr, dorgqr, cungqr, zungqr
from scipy.linalg.cython_lapack cimport spotrf, dpotrf, cpotrf, zpotrf
from scipy.linalg.cython_lapack cimport sgesdd, dgesdd, cgesdd, zgesdd
from scipy.linalg.cython_lapack cimport ssyevd, dsyevd, cheevd, zheevd
from scipy.linalg.cython_lapack cimport sgeev, dgeev, cgeev, zgeev
import numpy as np
from jaxlib import xla_client
# Convenience aliases for the XLA client bindings used throughout this file.
_ops = xla_client.ops
Shape = xla_client.Shape
# Largest value representable by the LAPACK integer type (int32); used to
# clamp workspace sizes that are computed in int64.
cdef int _int32_max = 0x7FFFFFFF;
cdef register_cpu_custom_call_target(fn_name, void* fn):
  """Registers `fn` as an XLA CPU custom-call target named `fn_name`."""
  cdef const char* name = "xla._CUSTOM_CALL_TARGET"
  xla_client.register_custom_call_target(fn_name, PyCapsule_New(fn, name, NULL))
def _constant_s32_scalar(c, x):
  """Embeds `x` into builder `c` as an int32 scalar constant."""
  return _ops.Constant(c, np.int32(x))
# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
return getattr(c, "_builder", c)
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
# Batched single-precision triangular solve (BLAS strsm).
# data = [left_side, lower, trans_a, diag, m, n, batch, alpha, a, b];
# `out` receives the solution X, computed in place over a copy of b.
cdef void blas_strsm(void* out, void** data) nogil:
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef float* alpha = <float*>(data[7])
  cdef float* a = <float*>(data[8])
  cdef float* b = <float*>(data[9])
  cdef float* x = <float*>(out)
  # trsm solves in place, so copy b into the output unless XLA aliased them.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  # trans_a: 0 = no transpose, 1 = transpose, 2 = conjugate transpose.
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  cdef int lda = m if left_side else n
  cdef int ldb = m
  # Per-batch strides, computed in int64 to avoid int32 overflow.
  cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
  for _ in range(batch):
    strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_plus
    a += a_plus
register_cpu_custom_call_target(b"blas_strsm", <void*>(blas_strsm))
# Batched double-precision triangular solve (BLAS dtrsm); see blas_strsm for
# the operand layout.
cdef void blas_dtrsm(void* out, void** data) nogil:
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef double* alpha = <double*>(data[7])
  cdef double* a = <double*>(data[8])
  cdef double* b = <double*>(data[9])
  cdef double* x = <double*>(out)
  # trsm solves in place, so copy b into the output unless XLA aliased them.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  cdef int lda = m if left_side else n
  cdef int ldb = m
  # Per-batch strides, computed in int64 to avoid int32 overflow.
  cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
  for _ in range(batch):
    dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_plus
    a += a_plus
register_cpu_custom_call_target(b"blas_dtrsm", <void*>(blas_dtrsm))
# Batched complex64 triangular solve (BLAS ctrsm); see blas_strsm for the
# operand layout.
cdef void blas_ctrsm(void* out, void** data) nogil:
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef float complex* alpha = <float complex*>(data[7])
  cdef float complex* a = <float complex*>(data[8])
  cdef float complex* b = <float complex*>(data[9])
  cdef float complex* x = <float complex*>(out)
  # trsm solves in place, so copy b into the output unless XLA aliased them.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  cdef int lda = m if left_side else n
  cdef int ldb = m
  # Per-batch strides, computed in int64 to avoid int32 overflow.
  cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
  for _ in range(batch):
    ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_plus
    a += a_plus
register_cpu_custom_call_target(b"blas_ctrsm", <void*>(blas_ctrsm))
# Batched complex128 triangular solve (BLAS ztrsm); see blas_strsm for the
# operand layout.
cdef void blas_ztrsm(void* out, void** data) nogil:
  cdef int32_t left_side = (<int32_t*>(data[0]))[0]
  cdef int32_t lower = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_a = (<int32_t*>(data[2]))[0]
  cdef int32_t diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef double complex* alpha = <double complex*>(data[7])
  cdef double complex* a = <double complex*>(data[8])
  cdef double complex* b = <double complex*>(data[9])
  cdef double complex* x = <double complex*>(out)
  # trsm solves in place, so copy b into the output unless XLA aliased them.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
  cdef char cside = 'L' if left_side else 'R'
  cdef char cuplo = 'L' if lower else 'U'
  cdef char ctransa = 'N'
  if trans_a == 1:
    ctransa = 'T'
  elif trans_a == 2:
    ctransa = 'C'
  cdef char cdiag = 'U' if diag else 'N'
  cdef int lda = m if left_side else n
  cdef int ldb = m
  # Per-batch strides, computed in int64 to avoid int32 overflow.
  cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
  for _ in range(batch):
    ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_plus
    a += a_plus
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
         conj_a=False, diag=False):
  """Emits a batched triangular-solve custom call into builder `c`.

  Solves op(a) @ x = alpha * b (or x @ op(a) = alpha * b when not
  `left_side`), where `a` is triangular. Returns the XlaOp for x, which has
  the same shape as `b`. Leading dimensions of `a`/`b` are batch dimensions.
  """
  c = _unpack_builder(c)
  a_shape = c.get_shape(a)
  b_shape = c.get_shape(b)
  dtype = b_shape.element_type()
  dims = b_shape.dimensions()
  m, n = dims[-2:]
  # The triangular operand is k x k, on the side the solve is performed.
  k = m if left_side else n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  num_b = 1
  for d in batch_dims:
    num_b *= d
  if batch_dims + (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype:
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_shape, b_shape))
  if dtype == np.float32:
    fn = b"blas_strsm"
  elif dtype == np.float64:
    fn = b"blas_dtrsm"
  elif dtype == np.complex64:
    fn = b"blas_ctrsm"
  elif dtype == np.complex128:
    fn = b"blas_ztrsm"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  # BLAS offers 'C' (conjugate transpose) but no conjugate-only option.
  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")
  # Column-major in the trailing two dims, row-major over batch dims, to
  # match Fortran BLAS expectations.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  return _ops.CustomCallWithLayout(
      c, fn,
      operands=(
        _constant_s32_scalar(c, int(left_side)),
        _constant_s32_scalar(c, int(lower)),
        # 0 = none, 1 = transpose, 2 = conjugate transpose.
        _constant_s32_scalar(c, (2 if conj_a else 1) if trans_a else 0),
        _constant_s32_scalar(c, int(diag)),
        _constant_s32_scalar(c, m),
        _constant_s32_scalar(c, n),
        _constant_s32_scalar(c, num_b),
        alpha, a, b),
      shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), layout),
      operand_shapes_with_layout=(
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(np.dtype(np.int32), (), ()),
        Shape.array_shape(dtype, (), ()),
        Shape.array_shape(dtype, a_shape.dimensions(), layout),
        Shape.array_shape(dtype, b_shape.dimensions(), layout),
      ))
# Backward-compatible alias for older callers.
jax_trsm = trsm
# ?getrf: LU decomposition
# Batched float32 LU decomposition (LAPACK sgetrf).
# data = [b, m, n, a]; out_tuple = (a_out, ipiv, info).
cdef void lapack_sgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # getrf factors in place; copy the input unless XLA aliased the buffers.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
  for i in range(b):
    sgetrf(&m, &n, a_out, &m, ipiv, info)
    # Advance in int64 to avoid int32 overflow of m * n for large matrices,
    # matching the int64 arithmetic used for the memcpy size above.
    a_out += <int64_t>(m) * <int64_t>(n)
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgetrf", <void*>(lapack_sgetrf))
# Batched float64 LU decomposition (LAPACK dgetrf).
# data = [b, m, n, a]; out_tuple = (a_out, ipiv, info).
cdef void lapack_dgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
  for i in range(b):
    dgetrf(&m, &n, a_out, &m, ipiv, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgetrf", <void*>(lapack_dgetrf))
# Batched complex64 LU decomposition (LAPACK cgetrf).
# data = [b, m, n, a]; out_tuple = (a_out, ipiv, info).
cdef void lapack_cgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
  for i in range(b):
    cgetrf(&m, &n, a_out, &m, ipiv, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgetrf", <void*>(lapack_cgetrf))
# Batched complex128 LU decomposition (LAPACK zgetrf).
# data = [b, m, n, a]; out_tuple = (a_out, ipiv, info).
cdef void lapack_zgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
  for i in range(b):
    zgetrf(&m, &n, a_out, &m, ipiv, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgetrf", <void*>(lapack_zgetrf))
def getrf(c, a):
  """Emits a batched LU-decomposition (?getrf) custom call into builder `c`.

  Returns (lu, ipiv, info) XlaOps: the packed LU factors with the same shape
  as `a`, int32 pivot indices of length min(m, n), and per-matrix int32
  LAPACK status codes.
  """
  c = _unpack_builder(c)
  # The kernels reinterpret int32 buffers as C int.
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  if dtype == np.float32:
    fn = b"lapack_sgetrf"
  elif dtype == np.float64:
    fn = b"lapack_dgetrf"
  elif dtype == np.complex64:
    fn = b"lapack_cgetrf"
  elif dtype == np.complex128:
    fn = b"lapack_zgetrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
        _constant_s32_scalar(c, b),
        _constant_s32_scalar(c, m),
        _constant_s32_scalar(c, n),
        a),
      shape_with_layout=Shape.tuple_shape((
          # Column-major matrices, row-major batch dims, for Fortran LAPACK.
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(
              np.dtype(np.int32),
              batch_dims + (min(m, n),),
              tuple(range(num_bd, -1, -1))),
          Shape.array_shape(np.dtype(np.int32), batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
      )),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(3))
# ?geqrf: QR decomposition
cdef int lapack_sgeqrf_workspace(int m, int n):
  # Queries sgeqrf (lwork == -1) for its optimal workspace size; returns -1
  # if the query itself fails.
  cdef float work
  cdef int lwork = -1
  cdef int info
  sgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1
# Batched float32 QR decomposition (LAPACK sgeqrf).
# data = [b, m, n, lwork, a]; out_tuple = (a_out, tau, info, work).
cdef void lapack_sgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const float* a_in = <float*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* tau = <float*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef float* work = <float*>(out[3])
  # geqrf factors in place; copy the input unless XLA aliased the buffers.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
  for i in range(b):
    sgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices,
    # matching the int64 arithmetic used for the memcpy size above.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgeqrf", <void*>(lapack_sgeqrf))
cdef int lapack_dgeqrf_workspace(int m, int n):
  # Queries dgeqrf (lwork == -1) for its optimal workspace size; returns -1
  # if the query itself fails.
  cdef double work
  cdef int lwork = -1
  cdef int info
  dgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1
# Batched float64 QR decomposition (LAPACK dgeqrf).
# data = [b, m, n, lwork, a]; out_tuple = (a_out, tau, info, work).
cdef void lapack_dgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const double* a_in = <double*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* tau = <double*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef double* work = <double*>(out[3])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
  for i in range(b):
    dgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgeqrf", <void*>(lapack_dgeqrf))
cdef int lapack_cgeqrf_workspace(int m, int n):
  # Queries cgeqrf (lwork == -1) for its optimal workspace size; the size is
  # returned in the real part of `work`. Returns -1 if the query fails.
  cdef float complex work
  cdef int lwork = -1
  cdef int info
  cgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1
# Batched complex64 QR decomposition (LAPACK cgeqrf).
# data = [b, m, n, lwork, a]; out_tuple = (a_out, tau, info, work).
cdef void lapack_cgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const float complex* a_in = <float complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float complex* tau = <float complex*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef float complex* work = <float complex*>(out[3])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
  for i in range(b):
    cgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgeqrf", <void*>(lapack_cgeqrf))
cdef int lapack_zgeqrf_workspace(int m, int n):
  # Queries zgeqrf (lwork == -1) for its optimal workspace size; the size is
  # returned in the real part of `work`. Returns -1 if the query fails.
  cdef double complex work
  cdef int lwork = -1
  cdef int info
  zgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1
# Batched complex128 QR decomposition (LAPACK zgeqrf).
# data = [b, m, n, lwork, a]; out_tuple = (a_out, tau, info, work).
cdef void lapack_zgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const double complex* a_in = <double complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double complex* tau = <double complex*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef double complex* work = <double complex*>(out[3])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
  for i in range(b):
    zgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgeqrf", <void*>(lapack_zgeqrf))
def geqrf(c, a):
  """Emits a batched QR-decomposition (?geqrf) custom call into builder `c`.

  Returns (a_out, tau, info) XlaOps: the packed QR factors with the same
  shape as `a`, the Householder scalars tau of length min(m, n), and
  per-matrix int32 LAPACK status codes. The workspace output of the custom
  call is dropped.
  """
  c = _unpack_builder(c)
  # The kernels reinterpret int32 buffers as C int.
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Workspace size is queried at trace time and baked into the computation.
  if dtype == np.float32:
    fn = b"lapack_sgeqrf"
    lwork = lapack_sgeqrf_workspace(m, n)
  elif dtype == np.float64:
    fn = b"lapack_dgeqrf"
    lwork = lapack_dgeqrf_workspace(m, n)
  elif dtype == np.complex64:
    fn = b"lapack_cgeqrf"
    lwork = lapack_cgeqrf_workspace(m, n)
  elif dtype == np.complex128:
    fn = b"lapack_zgeqrf"
    lwork = lapack_zgeqrf_workspace(m, n)
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
        _constant_s32_scalar(c, b),
        _constant_s32_scalar(c, m),
        _constant_s32_scalar(c, n),
        _constant_s32_scalar(c, lwork),
        a,
      ),
      shape_with_layout=Shape.tuple_shape((
          # Column-major matrices, row-major batch dims, for Fortran LAPACK.
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(
              np.dtype(dtype),
              batch_dims + (min(m, n),),
              tuple(range(num_bd, -1, -1))),
          Shape.array_shape(np.dtype(np.int32), batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
          # Scratch workspace shared across the batch loop.
          Shape.array_shape(np.dtype(dtype), (lwork,), (0,)),
      )),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
      ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(3))
# ?orgqr: product of elementary Householder reflectors:
cdef int lapack_sorgqr_workspace(int m, int n, int k):
  # Queries sorgqr (lwork == -1) for its optimal workspace size; returns -1
  # if the query itself fails.
  cdef float work
  cdef int lwork = -1
  cdef int info
  sorgqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1
# Batched float32 assembly of Q from Householder reflectors (LAPACK sorgqr).
# data = [b, m, n, k, lwork, a, tau]; out_tuple = (a_out, info, work).
cdef void lapack_sorgqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const float* a_in = <float*>(data[5])
  cdef float* tau = <float*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef float* work = <float*>(out[2])
  # orgqr builds Q in place; copy the input unless XLA aliased the buffers.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
  for i in range(b):
    sorgqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices,
    # matching the int64 arithmetic used for the memcpy size above.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_sorgqr", <void*>(lapack_sorgqr))
cdef int lapack_dorgqr_workspace(int m, int n, int k):
  # Queries dorgqr (lwork == -1) for its optimal workspace size; returns -1
  # if the query itself fails.
  cdef double work
  cdef int lwork = -1
  cdef int info
  dorgqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1
# Batched float64 assembly of Q from Householder reflectors (LAPACK dorgqr).
# data = [b, m, n, k, lwork, a, tau]; out_tuple = (a_out, info, work).
cdef void lapack_dorgqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const double* a_in = <double*>(data[5])
  cdef double* tau = <double*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef double* work = <double*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
  for i in range(b):
    dorgqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_dorgqr", <void*>(lapack_dorgqr))
cdef int lapack_cungqr_workspace(int m, int n, int k):
  # Queries cungqr (lwork == -1) for its optimal workspace size; the size is
  # returned in the real part of `work`. Returns -1 if the query fails.
  cdef float complex work
  cdef int lwork = -1
  cdef int info
  cungqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1
# Batched complex64 assembly of Q from Householder reflectors (LAPACK cungqr).
# data = [b, m, n, k, lwork, a, tau]; out_tuple = (a_out, info, work).
cdef void lapack_cungqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const float complex* a_in = <float complex*>(data[5])
  cdef float complex* tau = <float complex*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef float complex* work = <float complex*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
  for i in range(b):
    cungqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_cungqr", <void*>(lapack_cungqr))
cdef int lapack_zungqr_workspace(int m, int n, int k):
  # Queries zungqr (lwork == -1) for its optimal workspace size; the size is
  # returned in the real part of `work`. Returns -1 if the query fails.
  cdef double complex work
  cdef int lwork = -1
  cdef int info
  zungqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1
# Batched complex128 assembly of Q from Householder reflectors (LAPACK zungqr).
# data = [b, m, n, k, lwork, a, tau]; out_tuple = (a_out, info, work).
cdef void lapack_zungqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const double complex* a_in = <double complex*>(data[5])
  cdef double complex* tau = <double complex*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef double complex* work = <double complex*>(out[2])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
  for i in range(b):
    zungqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    # int64 stride avoids int32 overflow of m * n for large matrices.
    a_out += <int64_t>(m) * <int64_t>(n)
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_zungqr", <void*>(lapack_zungqr))
def orgqr(c, a, tau):
  """Emits a batched ?orgqr/?ungqr custom call into builder `c`.

  Forms the explicit Q matrix from the Householder reflectors produced by
  geqrf (`a`) and their scalars (`tau`). Returns (q, info) XlaOps, where `q`
  has the same shape as `a` and `info` holds per-matrix int32 LAPACK status
  codes. The workspace output of the custom call is dropped.
  """
  c = _unpack_builder(c)
  # The kernels reinterpret int32 buffers as C int.
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  tau_dims = c.get_shape(tau).dimensions()
  assert tau_dims[:-1] == dims[:-2]
  # Number of elementary reflectors.
  k = tau_dims[-1]
  # Workspace size is queried at trace time and baked into the computation.
  if dtype == np.float32:
    fn = b"lapack_sorgqr"
    lwork = lapack_sorgqr_workspace(m, n, k)
  elif dtype == np.float64:
    fn = b"lapack_dorgqr"
    lwork = lapack_dorgqr_workspace(m, n, k)
  elif dtype == np.complex64:
    fn = b"lapack_cungqr"
    lwork = lapack_cungqr_workspace(m, n, k)
  elif dtype == np.complex128:
    fn = b"lapack_zungqr"
    lwork = lapack_zungqr_workspace(m, n, k)
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
        _constant_s32_scalar(c, b),
        _constant_s32_scalar(c, m),
        _constant_s32_scalar(c, n),
        _constant_s32_scalar(c, k),
        _constant_s32_scalar(c, lwork),
        a,
        tau,
      ),
      shape_with_layout=Shape.tuple_shape((
          # Column-major matrices, row-major batch dims, for Fortran LAPACK.
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(np.dtype(np.int32), batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
          # Scratch workspace shared across the batch loop.
          Shape.array_shape(dtype, (lwork,), (0,)),
      )),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(
              dtype,
              batch_dims + (m, n),
              (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(
              dtype,
              batch_dims + (k,),
              tuple(range(num_bd, -1, -1))),
      ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(2))
# ?potrf: Cholesky decomposition
# Batched float32 Cholesky decomposition (LAPACK spotrf).
# data = [lower, b, n, a]; out_tuple = (a_out, info).
cdef void lapack_spotrf(void* out_tuple, void** data) nogil:
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])
  cdef char uplo = 'L' if lower else 'U'
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* info = <int*>(out[1])
  # potrf factors in place; copy the input unless XLA aliased the buffers.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(float))
  for i in range(b):
    spotrf(&uplo, &n, a_out, &n, info)
    # int64 stride avoids int32 overflow of n * n.
    a_out += <int64_t>(n) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_spotrf", <void*>(lapack_spotrf))
# Batched float64 Cholesky decomposition (LAPACK dpotrf); see lapack_spotrf
# for the operand layout.
cdef void lapack_dpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])
  cdef char uplo = 'L' if lower else 'U'
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(double))
  for i in range(b):
    dpotrf(&uplo, &n, a_out, &n, info)
    a_out += <int64_t>(n) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_dpotrf", <void*>(lapack_dpotrf))
# Batched complex64 Cholesky decomposition (LAPACK cpotrf); see lapack_spotrf
# for the operand layout.
cdef void lapack_cpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])
  cdef char uplo = 'L' if lower else 'U'
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(float complex))
  for i in range(b):
    cpotrf(&uplo, &n, a_out, &n, info)
    a_out += <int64_t>(n) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_cpotrf", <void*>(lapack_cpotrf))
# Batched complex128 Cholesky decomposition (LAPACK zpotrf); see
# lapack_spotrf for the operand layout.
cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])
  cdef char uplo = 'L' if lower else 'U'
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* info = <int*>(out[1])
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(double complex))
  for i in range(b):
    zpotrf(&uplo, &n, a_out, &n, info)
    a_out += <int64_t>(n) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))
def potrf(c, a, lower=False):
  """Emits a batched Cholesky-decomposition (?potrf) custom call into `c`.

  `a` must be a (batch of) square matrices. Returns (factor, info) XlaOps:
  the Cholesky factor with the same shape as `a`, and per-matrix int32
  LAPACK status codes.
  """
  c = _unpack_builder(c)
  # The kernels reinterpret int32 buffers as C int.
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  m, n = dims[-2:]
  if m != n:
    raise ValueError("potrf expects a square matrix, got {}".format(a_shape))
  if dtype == np.float32:
    fn = b"lapack_spotrf"
  elif dtype == np.float64:
    fn = b"lapack_dpotrf"
  elif dtype == np.complex64:
    fn = b"lapack_cpotrf"
  elif dtype == np.complex128:
    fn = b"lapack_zpotrf"
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major matrices, row-major batch dims, for Fortran LAPACK.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(_constant_s32_scalar(c, int(lower)),
                _constant_s32_scalar(c, b), _constant_s32_scalar(c, n), a),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, dims, layout),
          Shape.array_shape(
              np.dtype(np.int32), batch_dims, tuple(range(num_bd - 1, -1, -1))),
      )),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, dims, layout),
      ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(2))
# ?gesdd: Singular value decomposition
cdef int gesdd_iwork_size(int64_t m, int64_t n) nogil:
  # Integer workspace size for ?gesdd: 8 * min(m, n), computed in int64 and
  # clamped because the LAPACK integer type is int32.
  cdef int64_t size = 8 * min(m, n)
  return min(_int32_max, size)
cdef int cgesdd_rwork_size(int64_t m, int64_t n, int compute_uv) nogil:
  # Real-valued workspace size for cgesdd/zgesdd, per the LAPACK
  # documentation. All arithmetic is done in int64 and the result is clamped
  # so the value passed through LAPACK's int32 interface cannot overflow.
  cdef int64_t mn = min(m, n)
  if compute_uv == 0:
    # jobz == 'N': RWORK must be at least 7*min(m, n). Clamp here too; the
    # unclamped int64 product could overflow the int return type.
    return min(_int32_max, 7 * mn)
  cdef int64_t mx = max(m, n)
  # Avoid integer overflow; the LAPACK integer type is int32.
  return min(_int32_max,
             max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn))
cdef char gesdd_jobz(bool_t job_opt_compute_uv,
                     bool_t job_opt_full_matrices) nogil:
  # Maps the (compute_uv, full_matrices) options onto the LAPACK ?gesdd JOBZ
  # code: 'N' = singular values only, 'S' = thin U/VT, 'A' = full U/VT.
  if job_opt_compute_uv == 0:
    return 'N'
  if job_opt_full_matrices == 0:
    return 'S'
  return 'A'
# Returns the optimal lwork for sgesdd via a LAPACK workspace query
# (lwork == -1 asks sgesdd to write the optimal size into `work`).
cdef int sgesdd_work_size(int m, int n, bool_t job_opt_compute_uv,
                          bool_t job_opt_full_matrices):
  cdef float work
  cdef int lwork = -1
  cdef int info
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  # Array arguments may be NULL during a size-only query; only `work` and
  # `info` are written.
  sgesdd(&jobz, &m, &n, NULL, &m, NULL, NULL, &m, NULL, &ldvt, &work,
         &lwork, NULL, &info)
  # -1 signals that the workspace query itself failed.
  return <int>(work) if info == 0 else -1
# Batched float32 SVD via LAPACK sgesdd. XLA CPU custom-call entry point.
cdef void lapack_sgesdd(void* out_tuple, void** data) nogil:
  # Operands: full_matrices flag, compute_uv flag, batch count b, dimensions
  # m and n, precomputed workspace size lwork, and the input matrices.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int b = (<int32_t*>(data[2]))[0]
  cdef int m = (<int32_t*>(data[3]))[0]
  cdef int n = (<int32_t*>(data[4]))[0]
  cdef int lwork = (<int32_t*>(data[5]))[0]
  cdef float* a_in = <float*>(data[6])
  # Outputs: overwritten A, singular values s, U, VT, per-batch info codes,
  # plus integer and floating-point scratch buffers allocated by XLA.
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* s = <float*>(out[1])
  cdef float* u = <float*>(out[2])
  cdef float* vt = <float*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef float* work = <float*>(out[6])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  cdef int lda = m
  cdef int ldu = m
  cdef int tdu = min(m, n) if job_opt_full_matrices == 0 else m
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  for i in range(b):
    sgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork,
           iwork, info)
    # Advance per-batch pointers with int64 strides so products such as m*n
    # cannot overflow 32-bit int arithmetic, matching the memcpy size
    # computation above and the potrf kernels.
    a_out += <int64_t>(m) * <int64_t>(n)
    s += min(m, n)
    u += <int64_t>(m) * <int64_t>(tdu)
    vt += <int64_t>(ldvt) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgesdd", <void*>(lapack_sgesdd))
# Returns the optimal lwork for dgesdd via a LAPACK workspace query
# (lwork == -1 asks dgesdd to write the optimal size into `work`).
cdef int dgesdd_work_size(int m, int n, bool_t job_opt_compute_uv,
                          bool_t job_opt_full_matrices):
  cdef double work
  cdef int lwork = -1
  cdef int info
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  # Array arguments may be NULL during a size-only query.
  dgesdd(&jobz, &m, &n, NULL, &m, NULL, NULL, &m, NULL, &ldvt, &work,
         &lwork, NULL, &info)
  # -1 signals that the workspace query itself failed.
  return <int>(work) if info == 0 else -1
# Batched float64 SVD via LAPACK dgesdd. XLA CPU custom-call entry point.
cdef void lapack_dgesdd(void* out_tuple, void** data) nogil:
  # Operands: full_matrices flag, compute_uv flag, batch count b, dimensions
  # m and n, precomputed workspace size lwork, and the input matrices.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int b = (<int32_t*>(data[2]))[0]
  cdef int m = (<int32_t*>(data[3]))[0]
  cdef int n = (<int32_t*>(data[4]))[0]
  cdef int lwork = (<int32_t*>(data[5]))[0]
  cdef double* a_in = <double*>(data[6])
  # Outputs: overwritten A, singular values s, U, VT, per-batch info codes,
  # plus integer and floating-point scratch buffers allocated by XLA.
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* s = <double*>(out[1])
  cdef double* u = <double*>(out[2])
  cdef double* vt = <double*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef double* work = <double*>(out[6])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  cdef int lda = m
  cdef int ldu = m
  cdef int tdu = min(m, n) if job_opt_full_matrices == 0 else m
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  for i in range(b):
    dgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork,
           iwork, info)
    # int64 strides avoid 32-bit overflow of the per-batch pointer advances,
    # matching the memcpy size computation above and the potrf kernels.
    a_out += <int64_t>(m) * <int64_t>(n)
    s += min(m, n)
    u += <int64_t>(m) * <int64_t>(tdu)
    vt += <int64_t>(ldvt) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgesdd", <void*>(lapack_dgesdd))
# Returns the optimal lwork for cgesdd via a LAPACK workspace query
# (lwork == -1 asks cgesdd to write the optimal size into `work`).
cdef int cgesdd_work_size(int m, int n, bool_t job_opt_compute_uv,
                          bool_t job_opt_full_matrices):
  cdef float complex work
  cdef int lwork = -1
  cdef int info
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  # Array arguments (including RWORK and IWORK) may be NULL in a size query.
  cgesdd(&jobz, &m, &n, NULL, &m, NULL, NULL, &m, NULL, &ldvt, &work,
         &lwork, NULL, NULL, &info)
  # The optimal size is returned in the real part; -1 on query failure.
  return <int>(work.real) if info == 0 else -1
# Batched complex64 SVD via LAPACK cgesdd. XLA CPU custom-call entry point.
cdef void lapack_cgesdd(void* out_tuple, void** data) nogil:
  # Operands: full_matrices flag, compute_uv flag, batch count b, dimensions
  # m and n, precomputed workspace size lwork, and the input matrices.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int b = (<int32_t*>(data[2]))[0]
  cdef int m = (<int32_t*>(data[3]))[0]
  cdef int n = (<int32_t*>(data[4]))[0]
  cdef int lwork = (<int32_t*>(data[5]))[0]
  cdef float complex* a_in = <float complex*>(data[6])
  # Outputs: overwritten A, real singular values s, U, VT, per-batch info
  # codes, plus integer, real, and complex scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float* s = <float*>(out[1])
  cdef float complex* u = <float complex*>(out[2])
  cdef float complex* vt = <float complex*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef float* rwork = <float*>(out[6])
  cdef float complex* work = <float complex*>(out[7])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  cdef int lda = m
  cdef int ldu = m
  cdef int tdu = min(m, n) if job_opt_full_matrices == 0 else m
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  for i in range(b):
    cgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork,
           rwork, iwork, info)
    # int64 strides avoid 32-bit overflow of the per-batch pointer advances,
    # matching the memcpy size computation above and the potrf kernels.
    a_out += <int64_t>(m) * <int64_t>(n)
    s += min(m, n)
    u += <int64_t>(m) * <int64_t>(tdu)
    vt += <int64_t>(ldvt) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgesdd", <void*>(lapack_cgesdd))
# Returns the optimal lwork for zgesdd via a LAPACK workspace query
# (lwork == -1 asks zgesdd to write the optimal size into `work`).
cdef int zgesdd_work_size(int m, int n, bool_t job_opt_compute_uv,
                          bool_t job_opt_full_matrices):
  cdef double complex work
  cdef int lwork = -1
  cdef int info
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  # Array arguments (including RWORK and IWORK) may be NULL in a size query.
  zgesdd(&jobz, &m, &n, NULL, &m, NULL, NULL, &m, NULL, &ldvt, &work,
         &lwork, NULL, NULL, &info)
  # The optimal size is returned in the real part; -1 on query failure.
  return <int>(work.real) if info == 0 else -1
# Batched complex128 SVD via LAPACK zgesdd. XLA CPU custom-call entry point.
cdef void lapack_zgesdd(void* out_tuple, void** data) nogil:
  # Operands: full_matrices flag, compute_uv flag, batch count b, dimensions
  # m and n, precomputed workspace size lwork, and the input matrices.
  cdef int32_t job_opt_full_matrices = (<int32_t*>(data[0]))[0]
  cdef int32_t job_opt_compute_uv = (<int32_t*>(data[1]))[0]
  cdef int b = (<int32_t*>(data[2]))[0]
  cdef int m = (<int32_t*>(data[3]))[0]
  cdef int n = (<int32_t*>(data[4]))[0]
  cdef int lwork = (<int32_t*>(data[5]))[0]
  cdef double complex* a_in = <double complex*>(data[6])
  # Outputs: overwritten A, real singular values s, U, VT, per-batch info
  # codes, plus integer, real, and complex scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double* s = <double*>(out[1])
  cdef double complex* u = <double complex*>(out[2])
  cdef double complex* vt = <double complex*>(out[3])
  cdef int* info = <int*>(out[4])
  cdef int* iwork = <int*>(out[5])
  cdef double* rwork = <double*>(out[6])
  cdef double complex* work = <double complex*>(out[7])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
  cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
  cdef int lda = m
  cdef int ldu = m
  cdef int tdu = min(m, n) if job_opt_full_matrices == 0 else m
  cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
  for i in range(b):
    zgesdd(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork,
           rwork, iwork, info)
    # int64 strides avoid 32-bit overflow of the per-batch pointer advances,
    # matching the memcpy size computation above and the potrf kernels.
    a_out += <int64_t>(m) * <int64_t>(n)
    s += min(m, n)
    u += <int64_t>(m) * <int64_t>(tdu)
    vt += <int64_t>(ldvt) * <int64_t>(n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgesdd", <void*>(lapack_zgesdd))
def gesdd(c, a, full_matrices=True, compute_uv=True):
  """Emits an XLA custom call computing the SVD of `a` via LAPACK ?gesdd.

  Args:
    c: the XLA builder (possibly wrapped; unwrapped via _unpack_builder).
    a: an XLA operand of shape (..., m, n).
    full_matrices: if True, U and VT are full (m x m and n x n).
    compute_uv: if True, singular vectors are computed as well as values.

  Returns:
    A 4-tuple of XLA ops: (singular values, U, VT, per-batch int32 info).

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # Total number of matrices in the batch.
  b = 1
  for d in batch_dims:
    b *= d
  # Pick the kernel, the (real) dtype of the singular values, the workspace
  # size from a LAPACK workspace query, and the scratch-buffer shapes.
  if dtype == np.float32:
    fn = b"lapack_sgesdd"
    singular_vals_dtype = np.float32
    lwork = sgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = (
        Shape.array_shape(np.dtype(np.int32), (gesdd_iwork_size(m, n),), (0,)),
        Shape.array_shape(dtype, (lwork,), (0,)),
    )
  elif dtype == np.float64:
    fn = b"lapack_dgesdd"
    singular_vals_dtype = np.float64
    lwork = dgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = (
        Shape.array_shape(np.dtype(np.int32), (gesdd_iwork_size(m, n),), (0,)),
        Shape.array_shape(dtype, (lwork,), (0,)),
    )
  elif dtype == np.complex64:
    fn = b"lapack_cgesdd"
    singular_vals_dtype = np.float32
    lwork = cgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = (
        Shape.array_shape(np.dtype(np.int32), (gesdd_iwork_size(m, n),), (0,)),
        Shape.array_shape(np.dtype(np.float32),
                          (cgesdd_rwork_size(m, n, int(compute_uv)),), (0,)),
        Shape.array_shape(dtype, (lwork,), (0,)),
    )
  elif dtype == np.complex128:
    fn = b"lapack_zgesdd"
    singular_vals_dtype = np.float64
    lwork = zgesdd_work_size(m, n, compute_uv, full_matrices)
    workspace = (
        Shape.array_shape(np.dtype(np.int32), (gesdd_iwork_size(m, n),), (0,)),
        # zgesdd's RWORK requirement uses the same formula as cgesdd's.
        Shape.array_shape(np.dtype(np.float64),
                          (cgesdd_rwork_size(m, n, int(compute_uv)),), (0,)),
        Shape.array_shape(dtype, (lwork,), (0,)),
    )
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  # Column-major matrix layout with row-major batch dimensions.
  scalar_layout = tuple(range(num_bd - 1, -1, -1))
  vector_layout = (num_bd,) + scalar_layout
  matrix_layout = (num_bd, num_bd + 1) + scalar_layout
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(_constant_s32_scalar(c, int(full_matrices)),
                _constant_s32_scalar(c, int(compute_uv)),
                _constant_s32_scalar(c, b),
                _constant_s32_scalar(c, m), _constant_s32_scalar(c, n),
                _constant_s32_scalar(c, lwork), a),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
          Shape.array_shape(np.dtype(singular_vals_dtype),
                            batch_dims + (min(m, n),), vector_layout),
          Shape.array_shape(dtype,
                            batch_dims + (m, m if full_matrices else min(m, n)),
                            matrix_layout),
          Shape.array_shape(dtype,
                            batch_dims + (n if full_matrices else min(m, n), n),
                            matrix_layout),
          Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
      ) + workspace
      ),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
      ))
  # Element 0 is the overwritten input; skip it and return (s, u, vt, info).
  return (_ops.GetTupleElement(out, 1), _ops.GetTupleElement(out, 2),
          _ops.GetTupleElement(out, 3), _ops.GetTupleElement(out, 4))
# syevd: Symmetric eigendecomposition
# Workspace sizes, taken from the LAPACK documentation.
# lwork for ?syevd with jobz='V': 1 + 6n + 2n^2, computed in int64 and
# clamped so the value fits LAPACK's 32-bit integer type.
cdef int syevd_work_size(int64_t n) nogil:
  # Avoids int32 overflow.
  return min(_int32_max, 1 + 6 * n + 2 * n * n)
# liwork for ?syevd/?heevd with jobz='V': 3 + 5n, clamped to int32 range.
cdef int syevd_iwork_size(int64_t n) nogil:
  # Avoids int32 overflow.
  return min(_int32_max, 3 + 5 * n)
# Batched float32 symmetric eigendecomposition via LAPACK ssyevd.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_ssyevd(void* out_tuple, void** data) nogil:
  # Operands: lower flag, batch count b, matrix dimension n, input matrices.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])
  # Outputs: eigenvectors (in-place on A), eigenvalues, per-batch info codes,
  # and real/integer scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* w_out = <float*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef float* work = <float*>(out[3])
  cdef int* iwork = <int*>(out[4])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(float))
  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'
  cdef int lwork = syevd_work_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    ssyevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork,
           info_out)
    # int64 stride avoids 32-bit overflow of n*n, matching the memcpy size
    # computation above and the potrf kernels.
    a_out += <int64_t>(n) * <int64_t>(n)
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_ssyevd", <void*>(lapack_ssyevd))
# Batched float64 symmetric eigendecomposition via LAPACK dsyevd.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_dsyevd(void* out_tuple, void** data) nogil:
  # Operands: lower flag, batch count b, matrix dimension n, input matrices.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])
  # Outputs: eigenvectors (in-place on A), eigenvalues, per-batch info codes,
  # and real/integer scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* w_out = <double*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef double* work = <double*>(out[3])
  cdef int* iwork = <int*>(out[4])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(double))
  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'
  cdef int lwork = syevd_work_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    dsyevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork,
           info_out)
    # int64 stride avoids 32-bit overflow of n*n, matching the memcpy size
    # computation above and the potrf kernels.
    a_out += <int64_t>(n) * <int64_t>(n)
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_dsyevd", <void*>(lapack_dsyevd))
# Workspace sizes, taken from the LAPACK documentation.
# Complex lwork for ?heevd with jobz='V': 1 + 2n + n^2, clamped to int32.
cdef int heevd_work_size(int64_t n) nogil:
  # Avoid int32 overflow.
  return min(_int32_max, 1 + 2 * n + n * n)
# Real lrwork for ?heevd with jobz='V': 1 + 5n + 2n^2, clamped to int32.
cdef int heevd_rwork_size(int64_t n) nogil:
  # Avoid int32 overflow.
  return min(_int32_max, 1 + 5 * n + 2 * n * n)
# Batched complex64 Hermitian eigendecomposition via LAPACK cheevd.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_cheevd(void* out_tuple, void** data) nogil:
  # Operands: lower flag, batch count b, matrix dimension n, input matrices.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])
  # Outputs: eigenvectors (in-place on A), real eigenvalues, per-batch info
  # codes, and complex/real/integer scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float* w_out = <float*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef float complex* work = <float complex*>(out[3])
  cdef float* rwork = <float*>(out[4])
  cdef int* iwork = <int*>(out[5])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(float complex))
  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'
  cdef int lwork = heevd_work_size(n)
  cdef int lrwork = heevd_rwork_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    cheevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork,
           iwork, &liwork, info_out)
    # int64 stride avoids 32-bit overflow of n*n, matching the memcpy size
    # computation above and the potrf kernels.
    a_out += <int64_t>(n) * <int64_t>(n)
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_cheevd", <void*>(lapack_cheevd))
# Batched complex128 Hermitian eigendecomposition via LAPACK zheevd.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_zheevd(void* out_tuple, void** data) nogil:
  # Operands: lower flag, batch count b, matrix dimension n, input matrices.
  cdef int32_t lower = (<int32_t*>(data[0]))[0]
  cdef int b = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])
  # Outputs: eigenvectors (in-place on A), real eigenvalues, per-batch info
  # codes, and complex/real/integer scratch buffers.
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double* w_out = <double*>(out[1])
  cdef int* info_out = <int*>(out[2])
  cdef double complex* work = <double complex*>(out[3])
  cdef double* rwork = <double*>(out[4])
  cdef int* iwork = <int*>(out[5])
  # XLA may alias the input and output buffers; copy only when they differ.
  if a_out != a_in:
    memcpy(a_out, a_in,
           <int64_t>(b) * <int64_t>(n) * <int64_t>(n) * sizeof(double complex))
  cdef char jobz = 'V'
  cdef char uplo = 'L' if lower else 'U'
  cdef int lwork = heevd_work_size(n)
  cdef int lrwork = heevd_rwork_size(n)
  cdef int liwork = syevd_iwork_size(n)
  for i in range(b):
    zheevd(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork,
           iwork, &liwork, info_out)
    # int64 stride avoids 32-bit overflow of n*n, matching the memcpy size
    # computation above and the potrf kernels.
    a_out += <int64_t>(n) * <int64_t>(n)
    w_out += n
    info_out += 1
register_cpu_custom_call_target(b"lapack_zheevd", <void*>(lapack_zheevd))
def syevd(c, a, lower=False):
  """Emits an XLA custom call for a symmetric/Hermitian eigendecomposition.

  Dispatches to LAPACK ?syevd (real) or ?heevd (complex) by element type.

  Args:
    c: the XLA builder (possibly wrapped; unwrapped via _unpack_builder).
    a: an XLA operand of shape (..., n, n), assumed symmetric/Hermitian.
    lower: if True, the lower triangle of `a` is used, else the upper.

  Returns:
    A 3-tuple of XLA ops: (eigenvectors, eigenvalues, per-batch int32 info).

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # Total number of matrices in the batch.
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major matrix layout with row-major batch dimensions.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  # Select the kernel, the (real) eigenvalue dtype, and the scratch-buffer
  # shapes; the sizes come from the clamped workspace-size helpers above.
  if dtype == np.float32:
    fn = b"lapack_ssyevd"
    eigvals_type = np.float32
    workspace = (Shape.array_shape(dtype, (syevd_work_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.int32),
                                   (syevd_iwork_size(n),), (0,)))
  elif dtype == np.float64:
    fn = b"lapack_dsyevd"
    eigvals_type = np.float64
    workspace = (Shape.array_shape(dtype, (syevd_work_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.int32),
                                   (syevd_iwork_size(n),), (0,)))
  elif dtype == np.complex64:
    fn = b"lapack_cheevd"
    eigvals_type = np.float32
    workspace = (Shape.array_shape(dtype, (heevd_work_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.float32),
                                   (heevd_rwork_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.int32),
                                   (syevd_iwork_size(n),), (0,)))
  elif dtype == np.complex128:
    fn = b"lapack_zheevd"
    eigvals_type = np.float64
    workspace = (Shape.array_shape(dtype, (heevd_work_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.float64),
                                   (heevd_rwork_size(n),), (0,)),
                 Shape.array_shape(np.dtype(np.int32),
                                   (syevd_iwork_size(n),), (0,)))
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(_constant_s32_scalar(c, 1 if lower else 0),
                _constant_s32_scalar(c, b),
                _constant_s32_scalar(c, n),
                a),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, dims, layout),
          Shape.array_shape(np.dtype(eigvals_type), batch_dims + (n,),
                            tuple(range(num_bd, -1, -1))),
          Shape.array_shape(np.dtype(np.int32), batch_dims,
                            tuple(range(num_bd - 1, -1, -1))))
      + workspace
      ),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(dtype, dims, layout),
      ))
  # Drop the workspace elements; return (eigenvectors, eigenvalues, info).
  return (_ops.GetTupleElement(out, 0), _ops.GetTupleElement(out, 1),
          _ops.GetTupleElement(out, 2))
# geev: Nonsymmetric eigendecomposition
# LAPACK uses a packed representation to represent a mixture of real
# eigenvectors and complex conjugate pairs. This helper unpacks the
# representation into regular complex matrices.
# Expands sgeev's packed real eigenvector output into full complex vectors.
# A zero (or NaN) imaginary eigenvalue means column j is a purely real
# eigenvector; otherwise columns j and j+1 hold the real and imaginary parts
# of a complex-conjugate pair of eigenvectors.
cdef void _unpack_float_eigenvectors(
    int n, const float* im_eigenvalues, const float* packed,
    float complex* unpacked) nogil:
  cdef float re, im
  cdef int j, k
  j = 0
  while j < n:
    if im_eigenvalues[j] == 0. or isnan(im_eigenvalues[j]):
      # Real eigenvalue: copy the column with a zero imaginary part.
      for k in range(n):
        unpacked[j*n + k].real = packed[j*n + k]
        unpacked[j*n + k].imag = 0.
      j += 1
    else:
      # Conjugate pair: column j is re + i*im, column j+1 is re - i*im.
      for k in range(n):
        re = packed[j*n + k]
        im = packed[(j+1)*n + k]
        unpacked[j*n + k].real = unpacked[(j + 1)*n + k].real = re
        unpacked[j*n + k].imag = im
        unpacked[(j + 1)*n + k].imag = -im
      j += 2
# Batched float32 nonsymmetric eigendecomposition via LAPACK sgeev.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
  # Operands: batch count b, dimension n, jobvl/jobvr codes ('V' or 'N'),
  # and the input matrices.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef char jobvl = (<uint8_t*>(data[2]))[0]
  cdef char jobvr = (<uint8_t*>(data[3]))[0]
  cdef const float* a_in = <float*>(data[4])
  cdef void** out = <void**>(out_tuple)
  # out[0..2] are scratch: a working copy of A (sgeev destroys its input)
  # and the packed left/right eigenvector outputs; out[3..7] are results.
  cdef float* a_work = <float*>(out[0])
  cdef float* vl_work = <float*>(out[1])
  cdef float* vr_work = <float*>(out[2])
  cdef float* wr_out = <float*>(out[3])
  cdef float* wi_out = <float*>(out[4])
  cdef float complex* vl_out = <float complex*>(out[5])
  cdef float complex* vr_out = <float complex*>(out[6])
  cdef int* info_out = <int*>(out[7])
  # Workspace query: lwork == -1 makes sgeev report the optimal size.
  cdef float work_query
  cdef int lwork = -1
  sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
        vr_work, &n, &work_query, &lwork, info_out)
  lwork = <int>(work_query)
  cdef float* work = <float*> malloc(lwork * sizeof(float))
  for i in range(b):
    memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float))
    sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
          vr_work, &n, work, &lwork, info_out)
    # Unpack the real-format eigenvectors only on success.
    if info_out[0] == 0:
      _unpack_float_eigenvectors(n, wi_out, vl_work, vl_out)
      _unpack_float_eigenvectors(n, wi_out, vr_work, vr_out)
    # int64 strides avoid 32-bit overflow of n*n for large n, matching the
    # memcpy size computation above.
    a_in += <int64_t>(n) * <int64_t>(n)
    wr_out += n
    wi_out += n
    vl_out += <int64_t>(n) * <int64_t>(n)
    vr_out += <int64_t>(n) * <int64_t>(n)
    info_out += 1
  free(work)
register_cpu_custom_call_target(b"lapack_sgeev", <void*>(lapack_sgeev))
# Expands dgeev's packed real eigenvector output into full complex vectors.
# A zero (or NaN) imaginary eigenvalue means column j is a purely real
# eigenvector; otherwise columns j and j+1 hold the real and imaginary parts
# of a complex-conjugate pair of eigenvectors.
cdef void _unpack_double_eigenvectors(
    int n, const double* im_eigenvalues, const double* packed,
    double complex* unpacked) nogil:
  cdef double re, im
  cdef int j, k
  j = 0
  while j < n:
    if im_eigenvalues[j] == 0. or isnan(im_eigenvalues[j]):
      # Real eigenvalue: copy the column with a zero imaginary part.
      for k in range(n):
        unpacked[j*n + k].real = packed[j*n + k]
        unpacked[j*n + k].imag = 0.
      j += 1
    else:
      # Conjugate pair: column j is re + i*im, column j+1 is re - i*im.
      for k in range(n):
        re = packed[j*n + k]
        im = packed[(j+1)*n + k]
        unpacked[j*n + k].real = unpacked[(j + 1)*n + k].real = re
        unpacked[j*n + k].imag = im
        unpacked[(j + 1)*n + k].imag = -im
      j += 2
# Batched float64 nonsymmetric eigendecomposition via LAPACK dgeev.
# XLA CPU custom-call entry point; runs without the GIL.
cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
  # Operands: batch count b, dimension n, jobvl/jobvr codes ('V' or 'N'),
  # and the input matrices.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef char jobvl = (<uint8_t*>(data[2]))[0]
  cdef char jobvr = (<uint8_t*>(data[3]))[0]
  cdef const double* a_in = <double*>(data[4])
  cdef void** out = <void**>(out_tuple)
  # out[0..2] are scratch: a working copy of A (dgeev destroys its input)
  # and the packed left/right eigenvector outputs; out[3..7] are results.
  cdef double* a_work = <double*>(out[0])
  cdef double* vl_work = <double*>(out[1])
  cdef double* vr_work = <double*>(out[2])
  cdef double* wr_out = <double*>(out[3])
  cdef double* wi_out = <double*>(out[4])
  cdef double complex* vl_out = <double complex*>(out[5])
  cdef double complex* vr_out = <double complex*>(out[6])
  cdef int* info_out = <int*>(out[7])
  # Workspace query: lwork == -1 makes dgeev report the optimal size.
  cdef double work_query
  cdef int lwork = -1
  dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
        vr_work, &n, &work_query, &lwork, info_out)
  lwork = <int>(work_query)
  cdef double* work = <double*> malloc(lwork * sizeof(double))
  for i in range(b):
    memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double))
    dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
          vr_work, &n, work, &lwork, info_out)
    # Unpack the real-format eigenvectors only on success.
    if info_out[0] == 0:
      _unpack_double_eigenvectors(n, wi_out, vl_work, vl_out)
      _unpack_double_eigenvectors(n, wi_out, vr_work, vr_out)
    # int64 strides avoid 32-bit overflow of n*n for large n, matching the
    # memcpy size computation above.
    a_in += <int64_t>(n) * <int64_t>(n)
    wr_out += n
    wi_out += n
    vl_out += <int64_t>(n) * <int64_t>(n)
    vr_out += <int64_t>(n) * <int64_t>(n)
    info_out += 1
  free(work)
register_cpu_custom_call_target(b"lapack_dgeev", <void*>(lapack_dgeev))
# Batched complex64 nonsymmetric eigendecomposition via LAPACK cgeev.
# XLA CPU custom-call entry point; runs without the GIL. Complex ?geev
# produces complex eigenvectors directly, so no unpacking step is needed.
cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
  # Operands: batch count b, dimension n, jobvl/jobvr codes ('V' or 'N'),
  # and the input matrices.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef char jobvl = (<uint8_t*>(data[2]))[0]
  cdef char jobvr = (<uint8_t*>(data[3]))[0]
  cdef const float complex* a_in = <float complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  # out[0..1] are scratch (working copy of A, real workspace); out[2..5]
  # are the eigenvalues, eigenvectors, and per-batch info codes.
  cdef float complex* a_work = <float complex*>(out[0])
  cdef float* r_work = <float*>(out[1])
  cdef float complex* w_out = <float complex*>(out[2])
  cdef float complex* vl_out = <float complex*>(out[3])
  cdef float complex* vr_out = <float complex*>(out[4])
  cdef int* info_out = <int*>(out[5])
  # Workspace query: lwork == -1 makes cgeev report the optimal size.
  cdef float complex work_query
  cdef int lwork = -1
  cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
        vr_out, &n, &work_query, &lwork, r_work, info_out)
  lwork = <int>(work_query.real)
  cdef float complex* work = <float complex*>malloc(
      lwork * sizeof(float complex))
  for i in range(b):
    # cgeev destroys its input, so operate on a copy.
    memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float complex))
    cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
          work, &lwork, r_work, info_out)
    # int64 strides avoid 32-bit overflow of n*n for large n, matching the
    # memcpy size computation above.
    a_in += <int64_t>(n) * <int64_t>(n)
    w_out += n
    vl_out += <int64_t>(n) * <int64_t>(n)
    vr_out += <int64_t>(n) * <int64_t>(n)
    info_out += 1
  free(work)
register_cpu_custom_call_target(b"lapack_cgeev", <void*>(lapack_cgeev))
# Batched complex128 nonsymmetric eigendecomposition via LAPACK zgeev.
# XLA CPU custom-call entry point; runs without the GIL. Complex ?geev
# produces complex eigenvectors directly, so no unpacking step is needed.
cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
  # Operands: batch count b, dimension n, jobvl/jobvr codes ('V' or 'N'),
  # and the input matrices.
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int n = (<int32_t*>(data[1]))[0]
  cdef char jobvl = (<uint8_t*>(data[2]))[0]
  cdef char jobvr = (<uint8_t*>(data[3]))[0]
  cdef const double complex* a_in = <double complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  # out[0..1] are scratch (working copy of A, real workspace); out[2..5]
  # are the eigenvalues, eigenvectors, and per-batch info codes.
  cdef double complex* a_work = <double complex*>(out[0])
  cdef double* r_work = <double*>(out[1])
  cdef double complex* w_out = <double complex*>(out[2])
  cdef double complex* vl_out = <double complex*>(out[3])
  cdef double complex* vr_out = <double complex*>(out[4])
  cdef int* info_out = <int*>(out[5])
  # Workspace query: lwork == -1 makes zgeev report the optimal size.
  cdef double complex work_query
  cdef int lwork = -1
  zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
        vr_out, &n, &work_query, &lwork, r_work, info_out)
  lwork = <int>(work_query.real)
  cdef double complex* work = <double complex*>malloc(
      lwork * sizeof(double complex))
  for i in range(b):
    # zgeev destroys its input, so operate on a copy.
    memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double complex))
    zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
          work, &lwork, r_work, info_out)
    # int64 strides avoid 32-bit overflow of n*n for large n, matching the
    # memcpy size computation above.
    a_in += <int64_t>(n) * <int64_t>(n)
    w_out += n
    vl_out += <int64_t>(n) * <int64_t>(n)
    vr_out += <int64_t>(n) * <int64_t>(n)
    info_out += 1
  free(work)
register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
def geev(c, a, jobvl=True, jobvr=True):
  """Emits an XLA custom call for a nonsymmetric eigendecomposition (?geev).

  Args:
    c: the XLA builder (possibly wrapped; unwrapped via _unpack_builder).
    a: an XLA operand of shape (..., n, n).
    jobvl: if True, compute left eigenvectors.
    jobvr: if True, compute right eigenvectors.

  Returns:
    A 4-tuple of XLA ops: (eigenvalues (complex), left eigenvectors,
    right eigenvectors, per-batch int32 info). For real inputs the two real
    eigenvalue arrays produced by LAPACK are combined into one complex array.

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  # Total number of matrices in the batch.
  b = 1
  for d in batch_dims:
    b *= d
  # Column-major matrix layout with row-major batch dimensions.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  # LAPACK job codes: 'V' = compute eigenvectors, 'N' = don't.
  jobvl_c = ord('V' if jobvl else 'N')
  jobvr_c = ord('V' if jobvr else 'N')
  # Real kernels need three n x n scratch matrices (working A plus packed
  # left/right eigenvectors) and emit separate real/imaginary eigenvalue
  # arrays; complex kernels need a working A and a real workspace, and emit
  # complex eigenvalues directly.
  if dtype == np.float32:
    fn = b"lapack_sgeev"
    real = True
    eigvecs_type = np.complex64
    workspaces = (Shape.array_shape(np.dtype(np.float32), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float32), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float32), (n, n), (0, 1)))
    eigvals = (Shape.array_shape(np.dtype(np.float32), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))),
               Shape.array_shape(np.dtype(np.float32), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))))
  elif dtype == np.float64:
    fn = b"lapack_dgeev"
    real = True
    eigvecs_type = np.complex128
    workspaces = (Shape.array_shape(np.dtype(np.float64), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float64), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float64), (n, n), (0, 1)))
    eigvals = (Shape.array_shape(np.dtype(np.float64), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))),
               Shape.array_shape(np.dtype(np.float64), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))))
  elif dtype == np.complex64:
    fn = b"lapack_cgeev"
    real = False
    eigvecs_type = np.complex64
    workspaces = (Shape.array_shape(np.dtype(np.complex64), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float32), (2 * n,), (0,)))
    eigvals = (Shape.array_shape(np.dtype(np.complex64), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))),)
  elif dtype == np.complex128:
    fn = b"lapack_zgeev"
    real = False
    eigvecs_type = np.complex128
    workspaces = (Shape.array_shape(np.dtype(np.complex128), (n, n), (0, 1)),
                  Shape.array_shape(np.dtype(np.float64), (2 * n,), (0,)))
    eigvals = (Shape.array_shape(np.dtype(np.complex128), batch_dims + (n,),
                                 tuple(range(num_bd, -1, -1))),)
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(_constant_s32_scalar(c, b),
                _constant_s32_scalar(c, n),
                _ops.Constant(c, np.uint8(jobvl_c)),
                _ops.Constant(c, np.uint8(jobvr_c)),
                a),
      shape_with_layout=Shape.tuple_shape(workspaces + eigvals + (
          Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
          Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
          Shape.array_shape(np.dtype(np.int32), batch_dims,
                            tuple(range(num_bd - 1, -1, -1))))
      ),
      operand_shapes_with_layout=(
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.int32), (), ()),
          Shape.array_shape(np.dtype(np.uint8), (), ()),
          Shape.array_shape(np.dtype(np.uint8), (), ()),
          Shape.array_shape(dtype, dims, layout),
      ))
  # Tuple indices shift depending on how many workspace/eigval elements
  # precede the results (3+2 for real kernels, 2+1 for complex ones).
  if real:
    # Combine LAPACK's separate real/imaginary eigenvalue outputs.
    return (_ops.Complex(_ops.GetTupleElement(out, 3),
                         _ops.GetTupleElement(out, 4)),
            _ops.GetTupleElement(out, 5), _ops.GetTupleElement(out, 6),
            _ops.GetTupleElement(out, 7))
  else:
    return (_ops.GetTupleElement(out, 2), _ops.GetTupleElement(out, 3),
            _ops.GetTupleElement(out, 4), _ops.GetTupleElement(out, 5))