added batching to LAPACK triangular_solve (#1985)

* Added batching to CPU triangular_solve

* Addressed review comments about int overflows and reverted batched triangular solve to use XLA rather than LAPACK

* Added a TODO to benchmark LAPACK vs. XLA
This commit is contained in:
AmKhan 2020-01-14 16:18:47 +00:00 committed by Peter Hawkins
parent 64bf55dc6f
commit dcda87d0e7
2 changed files with 75 additions and 29 deletions

View File

@ -404,6 +404,7 @@ def _triangular_solve_cpu_translation_rule(
c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
shape = c.GetShape(a)
dtype = shape.element_type().type
if len(shape.dimensions()) == 2 and onp.dtype(dtype) in _cpu_lapack_types:
if conjugate_a and not transpose_a:
a = c.Conj(a)
@ -412,9 +413,8 @@ def _triangular_solve_cpu_translation_rule(
c, c.Constant(onp.array(1, dtype=dtype)), a, b, left_side, lower,
transpose_a, conjugate_a, unit_diagonal)
else:
# Fall back to the HLO implementation for batched triangular_solve or
# unsupported types.
# TODO(phawkins): support BLAS primitives in batched mode.
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a,
unit_diagonal)

View File

@ -59,13 +59,14 @@ cdef void blas_strsm(void* out, void** data) nogil:
cdef int32_t diag = (<int32_t*>(data[3]))[0]
cdef int m = (<int32_t*>(data[4]))[0]
cdef int n = (<int32_t*>(data[5]))[0]
cdef float* alpha = <float*>(data[6])
cdef float* a = <float*>(data[7])
cdef float* b = <float*>(data[8])
cdef int batch = (<int32_t*>(data[6]))[0]
cdef float* alpha = <float*>(data[7])
cdef float* a = <float*>(data[8])
cdef float* b = <float*>(data[9])
cdef float* x = <float*>(out)
if x != b:
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(float))
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float))
cdef char cside = 'L' if left_side else 'R'
cdef char cuplo = 'L' if lower else 'U'
@ -77,7 +78,14 @@ cdef void blas_strsm(void* out, void** data) nogil:
cdef char cdiag = 'U' if diag else 'N'
cdef int lda = m if left_side else n
cdef int ldb = m
strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
for _ in range(batch):
strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
x += x_plus
a += a_plus
register_cpu_custom_call_target(b"blas_strsm", <void*>(blas_strsm))
@ -88,13 +96,14 @@ cdef void blas_dtrsm(void* out, void** data) nogil:
cdef int32_t diag = (<int32_t*>(data[3]))[0]
cdef int m = (<int32_t*>(data[4]))[0]
cdef int n = (<int32_t*>(data[5]))[0]
cdef double* alpha = <double*>(data[6])
cdef double* a = <double*>(data[7])
cdef double* b = <double*>(data[8])
cdef int batch = (<int32_t*>(data[6]))[0]
cdef double* alpha = <double*>(data[7])
cdef double* a = <double*>(data[8])
cdef double* b = <double*>(data[9])
cdef double* x = <double*>(out)
if x != b:
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(double))
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double))
cdef char cside = 'L' if left_side else 'R'
cdef char cuplo = 'L' if lower else 'U'
@ -106,7 +115,15 @@ cdef void blas_dtrsm(void* out, void** data) nogil:
cdef char cdiag = 'U' if diag else 'N'
cdef int lda = m if left_side else n
cdef int ldb = m
dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
for _ in range(batch):
dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
x += x_plus
a += a_plus
register_cpu_custom_call_target(b"blas_dtrsm", <void*>(blas_dtrsm))
@ -118,13 +135,14 @@ cdef void blas_ctrsm(void* out, void** data) nogil:
cdef int32_t diag = (<int32_t*>(data[3]))[0]
cdef int m = (<int32_t*>(data[4]))[0]
cdef int n = (<int32_t*>(data[5]))[0]
cdef float complex* alpha = <float complex*>(data[6])
cdef float complex* a = <float complex*>(data[7])
cdef float complex* b = <float complex*>(data[8])
cdef int batch = (<int32_t*>(data[6]))[0]
cdef float complex* alpha = <float complex*>(data[7])
cdef float complex* a = <float complex*>(data[8])
cdef float complex* b = <float complex*>(data[9])
cdef float complex* x = <float complex*>(out)
if x != b:
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(float complex))
cdef char cside = 'L' if left_side else 'R'
cdef char cuplo = 'L' if lower else 'U'
@ -136,7 +154,15 @@ cdef void blas_ctrsm(void* out, void** data) nogil:
cdef char cdiag = 'U' if diag else 'N'
cdef int lda = m if left_side else n
cdef int ldb = m
ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
for _ in range(batch):
ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
x += x_plus
a += a_plus
register_cpu_custom_call_target(b"blas_ctrsm", <void*>(blas_ctrsm))
@ -147,13 +173,14 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
cdef int32_t diag = (<int32_t*>(data[3]))[0]
cdef int m = (<int32_t*>(data[4]))[0]
cdef int n = (<int32_t*>(data[5]))[0]
cdef double complex* alpha = <double complex*>(data[6])
cdef double complex* a = <double complex*>(data[7])
cdef double complex* b = <double complex*>(data[8])
cdef int batch = (<int32_t*>(data[6]))[0]
cdef double complex* alpha = <double complex*>(data[7])
cdef double complex* a = <double complex*>(data[8])
cdef double complex* b = <double complex*>(data[9])
cdef double complex* x = <double complex*>(out)
if x != b:
memcpy(x, b, <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
memcpy(x, b, <int64_t>(batch) * <int64_t>(m) * <int64_t>(n) * sizeof(double complex))
cdef char cside = 'L' if left_side else 'R'
cdef char cuplo = 'L' if lower else 'U'
@ -165,20 +192,36 @@ cdef void blas_ztrsm(void* out, void** data) nogil:
cdef char cdiag = 'U' if diag else 'N'
cdef int lda = m if left_side else n
cdef int ldb = m
ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
cdef int64_t x_plus = <int64_t>(m) * <int64_t>(n)
cdef int64_t a_plus = <int64_t>(lda) * <int64_t>(lda)
for _ in range(batch):
ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
x += x_plus
a += a_plus
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
a_shape = c.GetShape(a)
b_shape = c.GetShape(b)
dtype = b_shape.element_type()
m, n = b_shape.dimensions()
dims = b_shape.dimensions()
m, n = dims[-2:]
k = m if left_side else n
a_shape = c.GetShape(a)
if (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype:
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
num_b = 1
for d in batch_dims:
num_b *= d
if batch_dims + (k, k) != a_shape.dimensions() or a_shape.element_type() != dtype:
raise ValueError("Argument mismatch for trsm, got {} and {}".format(
a_shape, b_shape))
@ -196,6 +239,7 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
if conj_a and not trans_a:
raise NotImplementedError("Conjugation without transposition not supported")
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
return c.CustomCall(
fn,
operands=(
@ -205,8 +249,9 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
c.ConstantS32Scalar(int(diag)),
c.ConstantS32Scalar(m),
c.ConstantS32Scalar(n),
c.ConstantS32Scalar(num_b),
alpha, a, b),
shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), layout),
operand_shapes_with_layout=(
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
@ -214,9 +259,10 @@ def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(np.dtype(np.int32), (), ()),
Shape.array_shape(dtype, (), ()),
Shape.array_shape(dtype, a_shape.dimensions(), (0, 1)),
Shape.array_shape(dtype, b_shape.dimensions(), (0, 1)),
Shape.array_shape(dtype, a_shape.dimensions(), layout),
Shape.array_shape(dtype, b_shape.dimensions(), layout),
))
jax_trsm = trsm