mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Parameterize geev/eig to allow for not computing the left/right eigenvectors (#3882)
* Add options to compute the left/right eigenvectors in geev. The new arguments default to True to preserve backward compatibility between jaxlib and jax.
* Reformulate the eig-related operations on top of the new geev API.
* Change jobvl/jobvr to a new variable in jaxlib, and fix lax_linalg.eig to account for that.
* Keep the jaxlib.lapack.eig function signature backward-compatible. The rationale is to start by updating only lapack.pyx in a way that is backward-compatible with JAX, and to update the calls to lapack.geev in a subsequent PR.
This commit is contained in:
parent
7065d07166
commit
98a511a2c7
@ -25,7 +25,7 @@ cdef extern from "<cmath>" namespace "std":
|
||||
bint isnan(double x) nogil
|
||||
|
||||
from libc.stdlib cimport malloc, free
|
||||
from libc.stdint cimport int32_t, int64_t
|
||||
from libc.stdint cimport int32_t, int64_t, uint8_t
|
||||
from libc.string cimport memcpy
|
||||
from libcpp cimport bool as bool_t
|
||||
from libcpp.string cimport string
|
||||
@ -1525,7 +1525,10 @@ cdef void _unpack_float_eigenvectors(
|
||||
cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
|
||||
cdef int b = (<int32_t*>(data[0]))[0]
|
||||
cdef int n = (<int32_t*>(data[1]))[0]
|
||||
cdef const float* a_in = <float*>(data[2])
|
||||
cdef char jobvl = (<uint8_t*>(data[2]))[0]
|
||||
cdef char jobvr = (<uint8_t*>(data[3]))[0]
|
||||
|
||||
cdef const float* a_in = <float*>(data[4])
|
||||
|
||||
cdef void** out = <void**>(out_tuple)
|
||||
cdef float* a_work = <float*>(out[0])
|
||||
@ -1538,17 +1541,16 @@ cdef void lapack_sgeev(void* out_tuple, void** data) nogil:
|
||||
cdef float complex* vr_out = <float complex*>(out[6])
|
||||
cdef int* info_out = <int*>(out[7])
|
||||
|
||||
cdef char jobvlr = 'V'
|
||||
cdef float work_query
|
||||
cdef int lwork = -1
|
||||
sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
vr_work, &n, &work_query, &lwork, info_out)
|
||||
lwork = <int>(work_query)
|
||||
cdef float* work = <float*> malloc(lwork * sizeof(float))
|
||||
|
||||
for i in range(b):
|
||||
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float))
|
||||
sgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
sgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
vr_work, &n, work, &lwork, info_out)
|
||||
if info_out[0] == 0:
|
||||
_unpack_float_eigenvectors(n, wi_out, vl_work, vl_out)
|
||||
@ -1589,7 +1591,10 @@ cdef void _unpack_double_eigenvectors(
|
||||
cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
|
||||
cdef int b = (<int32_t*>(data[0]))[0]
|
||||
cdef int n = (<int32_t*>(data[1]))[0]
|
||||
cdef const double* a_in = <double*>(data[2])
|
||||
cdef char jobvl = (<uint8_t*>(data[2]))[0]
|
||||
cdef char jobvr = (<uint8_t*>(data[3]))[0]
|
||||
|
||||
cdef const double* a_in = <double*>(data[4])
|
||||
|
||||
cdef void** out = <void**>(out_tuple)
|
||||
cdef double* a_work = <double*>(out[0])
|
||||
@ -1602,17 +1607,16 @@ cdef void lapack_dgeev(void* out_tuple, void** data) nogil:
|
||||
cdef double complex* vr_out = <double complex*>(out[6])
|
||||
cdef int* info_out = <int*>(out[7])
|
||||
|
||||
cdef char jobvlr = 'V'
|
||||
cdef double work_query
|
||||
cdef int lwork = -1
|
||||
dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
vr_work, &n, &work_query, &lwork, info_out)
|
||||
lwork = <int>(work_query)
|
||||
cdef double* work = <double*> malloc(lwork * sizeof(double))
|
||||
|
||||
for i in range(b):
|
||||
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double))
|
||||
dgeev(&jobvlr, &jobvlr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
dgeev(&jobvl, &jobvr, &n, a_work, &n, wr_out, wi_out, vl_work, &n,
|
||||
vr_work, &n, work, &lwork, info_out)
|
||||
if info_out[0] == 0:
|
||||
_unpack_double_eigenvectors(n, wi_out, vl_work, vl_out)
|
||||
@ -1632,7 +1636,10 @@ register_cpu_custom_call_target(b"lapack_dgeev", <void*>(lapack_dgeev))
|
||||
cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
|
||||
cdef int b = (<int32_t*>(data[0]))[0]
|
||||
cdef int n = (<int32_t*>(data[1]))[0]
|
||||
cdef const float complex* a_in = <float complex*>(data[2])
|
||||
cdef char jobvl = (<uint8_t*>(data[2]))[0]
|
||||
cdef char jobvr = (<uint8_t*>(data[3]))[0]
|
||||
|
||||
cdef const float complex* a_in = <float complex*>(data[4])
|
||||
|
||||
cdef void** out = <void**>(out_tuple)
|
||||
cdef float complex* a_work = <float complex*>(out[0])
|
||||
@ -1643,10 +1650,9 @@ cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
|
||||
cdef float complex* vr_out = <float complex*>(out[4])
|
||||
cdef int* info_out = <int*>(out[5])
|
||||
|
||||
cdef char jobvlr = 'V'
|
||||
cdef float complex work_query
|
||||
cdef int lwork = -1
|
||||
cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
|
||||
cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
|
||||
vr_out, &n, &work_query, &lwork, r_work, info_out)
|
||||
lwork = <int>(work_query.real)
|
||||
cdef float complex* work = <float complex*>malloc(
|
||||
@ -1654,7 +1660,7 @@ cdef void lapack_cgeev(void* out_tuple, void** data) nogil:
|
||||
|
||||
for i in range(b):
|
||||
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(float complex))
|
||||
cgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
|
||||
cgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
|
||||
work, &lwork, r_work, info_out)
|
||||
|
||||
a_in += n * n
|
||||
@ -1670,7 +1676,10 @@ register_cpu_custom_call_target(b"lapack_cgeev", <void*>(lapack_cgeev))
|
||||
cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
|
||||
cdef int b = (<int32_t*>(data[0]))[0]
|
||||
cdef int n = (<int32_t*>(data[1]))[0]
|
||||
cdef const double complex* a_in = <double complex*>(data[2])
|
||||
cdef char jobvl = (<uint8_t*>(data[2]))[0]
|
||||
cdef char jobvr = (<uint8_t*>(data[3]))[0]
|
||||
|
||||
cdef const double complex* a_in = <double complex*>(data[4])
|
||||
|
||||
cdef void** out = <void**>(out_tuple)
|
||||
cdef double complex* a_work = <double complex*>(out[0])
|
||||
@ -1681,10 +1690,9 @@ cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
|
||||
cdef double complex* vr_out = <double complex*>(out[4])
|
||||
cdef int* info_out = <int*>(out[5])
|
||||
|
||||
cdef char jobvlr = 'V'
|
||||
cdef double complex work_query
|
||||
cdef int lwork = -1
|
||||
zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n,
|
||||
zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n,
|
||||
vr_out, &n, &work_query, &lwork, r_work, info_out)
|
||||
lwork = <int>(work_query.real)
|
||||
cdef double complex* work = <double complex*>malloc(
|
||||
@ -1692,7 +1700,7 @@ cdef void lapack_zgeev(void* out_tuple, void** data) nogil:
|
||||
|
||||
for i in range(b):
|
||||
memcpy(a_work, a_in, <int64_t>(n) * <int64_t>(n) * sizeof(double complex))
|
||||
zgeev(&jobvlr, &jobvlr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
|
||||
zgeev(&jobvl, &jobvr, &n, a_work, &n, w_out, vl_out, &n, vr_out, &n,
|
||||
work, &lwork, r_work, info_out)
|
||||
|
||||
a_in += n * n
|
||||
@ -1706,7 +1714,7 @@ register_cpu_custom_call_target(b"lapack_zgeev", <void*>(lapack_zgeev))
|
||||
|
||||
|
||||
|
||||
def geev(c, a):
|
||||
def geev(c, a, jobvl=True, jobvr=True):
|
||||
c = _unpack_builder(c)
|
||||
assert sizeof(int32_t) == sizeof(int)
|
||||
|
||||
@ -1723,6 +1731,9 @@ def geev(c, a):
|
||||
b *= d
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
|
||||
jobvl_c = ord('V' if jobvl else 'N')
|
||||
jobvr_c = ord('V' if jobvr else 'N')
|
||||
|
||||
if dtype == np.float32:
|
||||
fn = b"lapack_sgeev"
|
||||
real = True
|
||||
@ -1766,7 +1777,11 @@ def geev(c, a):
|
||||
|
||||
out = _ops.CustomCallWithLayout(
|
||||
c, fn,
|
||||
operands=(_constant_s32_scalar(c, b), _constant_s32_scalar(c, n), a),
|
||||
operands=(_constant_s32_scalar(c, b),
|
||||
_constant_s32_scalar(c, n),
|
||||
_ops.Constant(c, np.uint8(jobvl_c)),
|
||||
_ops.Constant(c, np.uint8(jobvr_c)),
|
||||
a),
|
||||
shape_with_layout=Shape.tuple_shape(workspaces + eigvals + (
|
||||
Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
|
||||
Shape.array_shape(np.dtype(eigvecs_type), dims, layout),
|
||||
@ -1776,6 +1791,8 @@ def geev(c, a):
|
||||
operand_shapes_with_layout=(
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.uint8), (), ()),
|
||||
Shape.array_shape(np.dtype(np.uint8), (), ()),
|
||||
Shape.array_shape(dtype, dims, layout),
|
||||
))
|
||||
if real:
|
||||
|
Loading…
x
Reference in New Issue
Block a user