[sparse] specify operand layouts in cusparse.py

Why? Specifying the operand layouts explicitly can fix issues when inputs have non-standard layouts.

PiperOrigin-RevId: 411110145
This commit is contained in:
Jake VanderPlas 2021-11-19 11:46:59 -08:00 committed by jax authors
parent f08a5a07a8
commit a93c99d7be

View File

@ -34,13 +34,28 @@ is_supported : bool = _cusparse and _cusparse.cusparse_supported
_ops = xla_client.ops
_Shape = xla_client.Shape
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
def _validate_csr(c, data, indices, indptr, shape):
  """Check the operands of a CSR matrix and summarize them.

  Args:
    c: object whose ``get_shape(op)`` returns a shape exposing
      ``element_type()`` and ``dimensions()``.
    data: operand holding the stored values; must have exactly one
      dimension (the single-element unpacking below fails otherwise).
    indices: operand that must have the same dimensions as ``data``.
    indptr: operand that must share the element type of ``indices`` and
      have ``shape[0] + 1`` entries.
    shape: sequence whose first entry is the number of matrix rows.

  Returns:
    A ``(data_dtype, index_dtype, nnz)`` tuple, where the dtypes are numpy
    dtypes and ``nnz`` is the length of ``data``.

  Raises:
    AssertionError: if any of the consistency checks above fails.
  """
  values_shape = c.get_shape(data)
  values_dtype = np.dtype(values_shape.element_type())

  indices_shape = c.get_shape(indices)
  idx_dtype = np.dtype(indices_shape.element_type())

  # Unpacking into a one-element target also enforces that data is 1D.
  (num_stored,) = values_shape.dimensions()
  assert indices_shape.dimensions() == (num_stored,)

  indptr_shape = c.get_shape(indptr)
  assert indptr_shape.element_type() == idx_dtype
  # One row-pointer entry per row plus a trailing end marker.
  assert indptr_shape.dimensions() == (shape[0] + 1,)

  return values_dtype, idx_dtype, num_stored
def _validate_coo(c, data, row, col, shape):
  """Check the operands of a COO matrix and summarize them.

  Args:
    c: object whose ``get_shape(op)`` returns a shape exposing
      ``element_type()`` and ``dimensions()``.
    data: operand holding the stored values; must have exactly one
      dimension (the single-element unpacking below fails otherwise).
    row: operand that must have the same dimensions as ``data``.
    col: operand that must match ``row`` in element type and ``data`` in
      dimensions.
    shape: accepted for signature parity with the CSR validator; none of
      the checks here consult it.

  Returns:
    A ``(data_dtype, index_dtype, nnz)`` tuple, where the dtypes are numpy
    dtypes and ``nnz`` is the length of ``data``.

  Raises:
    AssertionError: if any of the consistency checks above fails.
  """
  values_shape = c.get_shape(data)
  values_dtype = np.dtype(values_shape.element_type())

  row_shape = c.get_shape(row)
  idx_dtype = np.dtype(row_shape.element_type())

  # Unpacking into a one-element target also enforces that data is 1D.
  (num_stored,) = values_shape.dimensions()
  assert row_shape.dimensions() == (num_stored,)

  col_shape = c.get_shape(col)
  assert col_shape.element_type() == idx_dtype
  assert col_shape.dimensions() == (num_stored,)

  return values_dtype, idx_dtype, num_stored
def csr_todense(c, data, indices, indptr, *, shape):
"""CSR to dense matrix."""
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]
buffer_size, opaque = _cusparse.build_csr_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
@ -50,10 +65,9 @@ def csr_todense(c, data, indices, indptr, *, shape):
b"cusparse_csr_todense",
operands=(data, indices, indptr),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
@ -98,18 +112,16 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
assert c.get_shape(indptr).element_type() == index_dtype
x_dtype = np.dtype(c.get_shape(x).element_type())
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()
if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype
buffer_size, opaque = _cusparse.build_csr_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
@ -118,11 +130,10 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
b"cusparse_csr_matvec",
operands=(data, indices, indptr, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
c.get_shape(x),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,))
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
@ -136,20 +147,17 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
"""CSR matrix/matrix multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(indices).element_type())
assert c.get_shape(indptr).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
rows, cols = shape
_, Ccols = B_shape
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype
buffer_size, opaque = _cusparse.build_csr_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
@ -158,9 +166,9 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
b"cusparse_csr_matmat",
operands=(data, indices, indptr, B),
operand_shapes_with_layout=(
c.get_shape(data),
c.get_shape(indices),
c.get_shape(indptr),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((
@ -175,11 +183,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
def coo_todense(c, data, row, col, *, shape):
"""COO to dense matrix."""
data_dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
nnz = c.get_shape(data).dimensions()[0]
buffer_size, opaque = _cusparse.build_coo_todense_descriptor(
data_dtype, index_dtype, rows, cols, nnz)
@ -189,10 +194,9 @@ def coo_todense(c, data, row, col, *, shape):
b"cusparse_coo_todense",
operands=(data, row, col),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(data_dtype, shape, (1, 0)),
@ -236,18 +240,16 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
"""COO matrix/vector multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
x_dtype = np.dtype(c.get_shape(x).element_type())
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
nnz, = c.get_shape(data).dimensions()
x_dtype = np.dtype(c.get_shape(x).element_type())
x_shape = c.get_shape(x).dimensions()
if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype
buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
dtype, x_dtype, compute_dtype, index_dtype,
data_dtype, x_dtype, compute_dtype, index_dtype,
rows, cols, nnz, transpose)
out_size = cols if transpose else rows
@ -256,11 +258,10 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
b"cusparse_coo_matvec",
operands=(data, row, col, x),
operand_shapes_with_layout=(
# All are 1D, so no layout necessary
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
c.get_shape(x),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(x_dtype, x_shape, (0,)),
),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
@ -274,20 +275,17 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
"""COO matrix/matrix multiply."""
dtype = np.dtype(c.get_shape(data).element_type())
index_dtype = np.dtype(c.get_shape(row).element_type())
assert c.get_shape(row).element_type() == index_dtype
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
rows, cols = shape
B_dtype = np.dtype(c.get_shape(B).element_type())
B_shape = c.get_shape(B).dimensions()
rows, cols = shape
_, Ccols = B_shape
nnz, = c.get_shape(data).dimensions()
if compute_dtype is None:
compute_dtype = dtype
compute_dtype = data_dtype
buffer_size, opaque = _cusparse.build_coo_matmat_descriptor(
dtype, B_dtype, compute_dtype, index_dtype,
data_dtype, B_dtype, compute_dtype, index_dtype,
rows, cols, Ccols, nnz, transpose)
out_size = cols if transpose else rows
@ -296,9 +294,9 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
b"cusparse_coo_matmat",
operands=(data, row, col, B),
operand_shapes_with_layout=(
c.get_shape(data),
c.get_shape(row),
c.get_shape(col),
_Shape.array_shape(data_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(index_dtype, (nnz,), (0,)),
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
),
shape_with_layout=_Shape.tuple_shape((