mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
added batching to LAPACK triangular_solve (#1985)
* Added batching to the CPU triangular solver.
* Addressed review comments about int overflows, and returned triangular solve to using XLA rather than LAPACK.
* Added a TODO to benchmark LAPACK vs. XLA.
This commit is contained in:
parent
64bf55dc6f
commit
dcda87d0e7
@ -404,6 +404,7 @@ def _triangular_solve_cpu_translation_rule(
|
||||
c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
|
||||
shape = c.GetShape(a)
|
||||
dtype = shape.element_type().type
|
||||
|
||||
if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
|
||||
if conjugate_a and not transpose_a:
|
||||
a = c.Conj(a)
|
||||
@ -412,9 +413,8 @@ def _triangular_solve_cpu_translation_rule(
|
||||
c, c.Constant(onp.array(1, dtype=dtype)), a, b, left_side, lower,
|
||||
transpose_a, conjugate_a, unit_diagonal)
|
||||
else:
|
||||
# Fall back to the HLO implementation for batched triangular_solve or
|
||||
# unsupported types.
|
||||
# TODO(phawkins): support BLAS primitives in batched mode.
|
||||
# Fall back to the HLO implementation for unsupported types or batching.
|
||||
# TODO: Consider swapping XLA for LAPACK in batched case
|
||||
return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
|
||||
unit_diagonal)
|
||||
|
||||
|
@ -59,13 +59,14 @@ cdef void blas_strsm(void* out, void** data) nogil:
|
||||
cdef int32_t diag = (<int32_t*>(data[3]))[0]
|
||||
cdef int m = (<int32_t*>(data[4]))[0]
|
||||
cdef int n = (<int32_t*>(data[5]))[0]
|
||||
cdef float* alpha = <float*>(data[6])
|
||||
cdef float* a = <float*>(data[7])
|
||||
cdef float* b = <float*>(data[8])
|
||||
cdef int batch = (<int32_t*>(data[6]))[0]
|
||||
cdef float* alpha = <float*>(data[7])
|
||||
cdef float* a = <float*>(data[8])
|
||||
cdef float* b = <float*>(data[9])
|
||||
|
||||
cdef float* x = <float*>(out)
|
||||
if x != b:
|
||||
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(float))
|
||||
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
|
||||
|
||||
cdef char cside = 'L' if left_side else 'R'
|
||||
cdef char cuplo = 'L' if lower else 'U'
|
||||
@ -77,7 +78,14 @@ cdef void blas_strsm(void* out, void** data) nogil:
|
||||
cdef char cdiag = 'U' if diag else 'N'
|
||||
cdef int lda = m if left_side else n
|
||||
cdef int ldb = m
|
||||
strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
|
||||
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
|
||||
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
|
||||
|
||||
for _ in range(batch):
|
||||
strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
x += x_plus
|
||||
a += a_plus
|
||||
|
||||
register_cpu_custom_call_target(b"blas_strsm", <void*>(blas_strsm))
|
||||
|
||||
@ -88,13 +96,14 @@ cdef void blas_dtrsm(void* out, void** data) nogil:
|
||||
cdef int32_t diag = (<int32_t*>(data[3]))[0]
|
||||
cdef int m = (<int32_t*>(data[4]))[0]
|
||||
cdef int n = (<int32_t*>(data[5]))[0]
|
||||
cdef double* alpha = <double*>(data[6])
|
||||
cdef double* a = <double*>(data[7])
|
||||
cdef double* b = <double*>(data[8])
|
||||
cdef int batch = (<int32_t*>(data[6]))[0]
|
||||
cdef double* alpha = <double*>(data[7])
|
||||
cdef double* a = <double*>(data[8])
|
||||
cdef double* b = <double*>(data[9])
|
||||
|
||||
cdef double* x = <double*>(out)
|
||||
if x != b:
|
||||
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(double))
|
||||
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
|
||||
|
||||
cdef char cside = 'L' if left_side else 'R'
|
||||
cdef char cuplo = 'L' if lower else 'U'
|
||||
@ -106,7 +115,15 @@ cdef void blas_dtrsm(void* out, void** data) nogil:
|
||||
cdef char cdiag = 'U' if diag else 'N'
|
||||
cdef int lda = m if left_side else n
|
||||
cdef int ldb = m
|
||||
dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
|
||||
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
|
||||
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
|
||||
|
||||
for _ in range(batch):
|
||||
dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
x += x_plus
|
||||
a += a_plus
|
||||
|
||||
|
||||
register_cpu_custom_call_target(b"blas_dtrsm", <void*>(blas_dtrsm))
|
||||
|
||||
@ -118,13 +135,14 @@ cdef void blas_ctrsm(void* out, void** data) nogil:
|
||||
cdef int32_t diag = (<int32_t*>(data[3]))[0]
|
||||
cdef int m = (<int32_t*>(data[4]))[0]
|
||||
cdef int n = (<int32_t*>(data[5]))[0]
|
||||
cdef float complex* alpha = <float complex*>(data[6])
|
||||
cdef float complex* a = <float complex*>(data[7])
|
||||
cdef float complex* b = <float complex*>(data[8])
|
||||
cdef int batch = (<int32_t*>(data[6]))[0]
|
||||
cdef float complex* alpha = <float complex*>(data[7])
|
||||
cdef float complex* a = <float complex*>(data[8])
|
||||
cdef float complex* b = <float complex*>(data[9])
|
||||
|
||||
cdef float complex* x = <float complex*>(out)
|
||||
if x != b:
|
||||
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
|
||||
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
|
||||
|
||||
cdef char cside = 'L' if left_side else 'R'
|
||||
cdef char cuplo = 'L' if lower else 'U'
|
||||
@ -136,7 +154,15 @@ cdef void blas_ctrsm(void* out, void** data) nogil:
|
||||
cdef char cdiag = 'U' if diag else 'N'
|
||||
cdef int lda = m if left_side else n
|
||||
cdef int ldb = m
|
||||
ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
|
||||
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
|
||||
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
|
||||
|
||||
for _ in range(batch):
|
||||
ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
x += x_plus
|
||||
a += a_plus
|
||||
|
||||
|
||||
register_cpu_custom_call_target(b"blas_ctrsm", <void*>(blas_ctrsm))
|
||||
|
||||
@ -147,13 +173,14 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
|
||||
cdef int32_t diag = (<int32_t*>(data[3]))[0]
|
||||
cdef int m = (<int32_t*>(data[4]))[0]
|
||||
cdef int n = (<int32_t*>(data[5]))[0]
|
||||
cdef double complex* alpha = <double complex*>(data[6])
|
||||
cdef double complex* a = <double complex*>(data[7])
|
||||
cdef double complex* b = <double complex*>(data[8])
|
||||
cdef int batch = (<int32_t*>(data[6]))[0]
|
||||
cdef double complex* alpha = <double complex*>(data[7])
|
||||
cdef double complex* a = <double complex*>(data[8])
|
||||
cdef double complex* b = <double complex*>(data[9])
|
||||
|
||||
cdef double complex* x = <double complex*>(out)
|
||||
if x != b:
|
||||
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
|
||||
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
|
||||
|
||||
cdef char cside = 'L' if left_side else 'R'
|
||||
cdef char cuplo = 'L' if lower else 'U'
|
||||
@ -165,20 +192,36 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
|
||||
cdef char cdiag = 'U' if diag else 'N'
|
||||
cdef int lda = m if left_side else n
|
||||
cdef int ldb = m
|
||||
ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
|
||||
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
|
||||
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
|
||||
|
||||
for _ in range(batch):
|
||||
ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
|
||||
x += x_plus
|
||||
a += a_plus
|
||||
|
||||
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
|
||||
|
||||
|
||||
def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
conj_a=False, diag=False):
|
||||
a_shape = c.GetShape(a)
|
||||
b_shape = c.GetShape(b)
|
||||
dtype = b_shape.element_type()
|
||||
m, n = b_shape.dimensions()
|
||||
|
||||
dims = b_shape.dimensions()
|
||||
|
||||
m, n = dims[-2:]
|
||||
k = m if left_side else n
|
||||
|
||||
a_shape = c.GetShape(a)
|
||||
if (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype:
|
||||
batch_dims = tuple(dims[:-2])
|
||||
num_bd = len(batch_dims)
|
||||
num_b = 1
|
||||
for d in batch_dims:
|
||||
num_b *= d
|
||||
|
||||
if batch_dims + (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype:
|
||||
raise ValueError("Argument mismatch for trsm, got {} and {}".format(
|
||||
a_shape, b_shape))
|
||||
|
||||
@ -196,6 +239,7 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
if conj_a and not trans_a:
|
||||
raise NotImplementedError("Conjugation without transposition not supported")
|
||||
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
return c.CustomCall(
|
||||
fn,
|
||||
operands=(
|
||||
@ -205,8 +249,9 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
c.ConstantS32Scalar(int(diag)),
|
||||
c.ConstantS32Scalar(m),
|
||||
c.ConstantS32Scalar(n),
|
||||
c.ConstantS32Scalar(num_b),
|
||||
alpha, a, b),
|
||||
shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
|
||||
shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), layout),
|
||||
operand_shapes_with_layout=(
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
@ -214,9 +259,10 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(np.dtype(np.int32), (), ()),
|
||||
Shape.array_shape(dtype, (), ()),
|
||||
Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)),
|
||||
Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
|
||||
Shape.array_shape(dtype, a_shape.dimensions(), layout),
|
||||
Shape.array_shape(dtype, b_shape.dimensions(), layout),
|
||||
))
|
||||
jax_trsm = trsm
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user