Parameterize geev/eig to allow for not computing the left/right eigenvectors (#3882)

* Add options to compute L/R eigenvectors in geev.

The new arguments default to True to ensure backward
compatibility between jaxlib and jax.

* Reformulate eig-related operations based on the new geev API.
* Store the converted jobvl/jobvr flags in new variables (jobvl_c/jobvr_c) in jaxlib, and fix lax_linalg.eig to account for that.
* Keep the jaxlib.lapack.eig function signature backward-compatible.

The rationale is to start by only updating lapack.pyx in a way that is
backward-compatible with JAX before updating the calls to lapack.geev in
a subsequent PR.
This commit is contained in:
Benjamin Chetioui 2020-08-05 12:58:30 +02:00 committed by GitHub
parent 7065d07166
commit 98a511a2c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,7 +25,7 @@ cdef extern from "<cmath>" namespace "std":
bint isnan(double x) nogil
from libc.stdlib cimport malloc, free
from libc.stdint cimport int32_t, int64_t
from libc.stdint cimport int32_t, int64_t, uint8_t
from libc.string cimport memcpy
from libcpp cimport bool as bool_t
from libcpp.string cimport string
@ -1525,7 +1525,10 @@ cdef void _unpack_float_eigenvectors(
cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
cdef int b = (<int32_t*>(data[0]))[0]
cdef int n = (<int32_t*>(data[1]))[0]
cdef const float* a_in = <float*>(data[2])
cdef char jobvl = (<uint8_t*>(data[2]))[0]
cdef char jobvr = (<uint8_t*>(data[3]))[0]
cdef const float* a_in = <float*>(data[4])
cdef void** out = <void**>(out_tuple)
cdef float* a_work = <float*>(out[0])
@ -1538,17 +1541,16 @@ cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
cdef float complex* vr_out = <float complex*>(out[6])
cdef int* info_out = <int*>(out[7])
cdef char jobvlr = 'V'
cdef float work_query
cdef int lwork = -1
sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
vr_work, &n, &work_query, &lwork, info_out)
lwork = <int>(work_query)
cdef float* work = <float*> malloc(lwork * sizeof(float))
for i in range(b):
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float))
sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
vr_work, &n, work, &lwork, info_out)
if info_out[0] == 0:
_unpack_float_eigenvectors(n, wi_out, vl_work, vl_out)
@ -1589,7 +1591,10 @@ cdef void _unpack_double_eigenvectors(
cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
cdef int b = (<int32_t*>(data[0]))[0]
cdef int n = (<int32_t*>(data[1]))[0]
cdef const double* a_in = <double*>(data[2])
cdef char jobvl = (<uint8_t*>(data[2]))[0]
cdef char jobvr = (<uint8_t*>(data[3]))[0]
cdef const double* a_in = <double*>(data[4])
cdef void** out = <void**>(out_tuple)
cdef double* a_work = <double*>(out[0])
@ -1602,17 +1607,16 @@ cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
cdef double complex* vr_out = <double complex*>(out[6])
cdef int* info_out = <int*>(out[7])
cdef char jobvlr = 'V'
cdef double work_query
cdef int lwork = -1
dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
vr_work, &n, &work_query, &lwork, info_out)
lwork = <int>(work_query)
cdef double* work = <double*> malloc(lwork * sizeof(double))
for i in range(b):
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double))
dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
vr_work, &n, work, &lwork, info_out)
if info_out[0] == 0:
_unpack_double_eigenvectors(n, wi_out, vl_work, vl_out)
@ -1632,7 +1636,10 @@ register_cpu_custom_call_target(b"lapack_dgeev", <void*>(lapack_dgeev))
cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
cdef int b = (<int32_t*>(data[0]))[0]
cdef int n = (<int32_t*>(data[1]))[0]
cdef const float complex* a_in = <float complex*>(data[2])
cdef char jobvl = (<uint8_t*>(data[2]))[0]
cdef char jobvr = (<uint8_t*>(data[3]))[0]
cdef const float complex* a_in = <float complex*>(data[4])
cdef void** out = <void**>(out_tuple)
cdef float complex* a_work = <float complex*>(out[0])
@ -1643,10 +1650,9 @@ cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
cdef float complex* vr_out = <float complex*>(out[4])
cdef int* info_out = <int*>(out[5])
cdef char jobvlr = 'V'
cdef float complex work_query
cdef int lwork = -1
cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
vr_out, &n, &work_query, &lwork, r_work, info_out)
lwork = <int>(work_query.real)
cdef float complex* work = <float complex*>malloc(
@ -1654,7 +1660,7 @@ cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
for i in range(b):
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float complex))
cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
work, &lwork, r_work, info_out)
a_in += n * n
@ -1670,7 +1676,10 @@ register_cpu_custom_call_target(b"lapack_cgeev", <void*>(lapack_cgeev))
cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
cdef int b = (<int32_t*>(data[0]))[0]
cdef int n = (<int32_t*>(data[1]))[0]
cdef const double complex* a_in = <double complex*>(data[2])
cdef char jobvl = (<uint8_t*>(data[2]))[0]
cdef char jobvr = (<uint8_t*>(data[3]))[0]
cdef const double complex* a_in = <double complex*>(data[4])
cdef void** out = <void**>(out_tuple)
cdef double complex* a_work = <double complex*>(out[0])
@ -1681,10 +1690,9 @@ cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
cdef double complex* vr_out = <double complex*>(out[4])
cdef int* info_out = <int*>(out[5])
cdef char jobvlr = 'V'
cdef double complex work_query
cdef int lwork = -1
zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
vr_out, &n, &work_query, &lwork, r_work, info_out)
lwork = <int>(work_query.real)
cdef double complex* work = <double complex*>malloc(
@ -1692,7 +1700,7 @@ cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
for i in range(b):
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double complex))
zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
work, &lwork, r_work, info_out)
a_in += n * n
@ -1706,7 +1714,7 @@ register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
def geev(c, a):
def geev(c, a, jobvl=True, jobvr=True):
c = _unpack_builder(c)
assert sizeof(int32_t) == sizeof(int)
@ -1723,6 +1731,9 @@ def geev(c, a):
b *= d
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
jobvl_c = ord('V' if jobvl else 'N')
jobvr_c = ord('V' if jobvr else 'N')
if dtype == np.float32:
fn = b"lapack_sgeev"
real = True
@ -1766,7 +1777,11 @@ def geev(c, a):
out = _ops.CustomCallWithLayout(
c, fn,
operands=(_constant_s32_scalar(c, b), _constant_s32_scalar(c, n), a),
operands=(_constant_s32_scalar(c, b),
_constant_s32_scalar(c, n),
_ops.Constant(c, np.uint8(jobvl_c)),
_ops.Constant(c, np.uint8(jobvr_c)),
a),
shape_with_layout=Shape.tuple_shape(workspaces + eigvals + (
Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
@ -1776,6 +1791,8 @@ def geev(c, a):
operand_shapes_with_layout=(
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(np.dtype(np.uint8), (), ()),
Shape.array_shape(dtype, dims, layout),
))
if real: