mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[sparse] specify operand layouts in cusparse.py
Why? This can fix issues when inputs have non-standard layouts PiperOrigin-RevId: 411110145
This commit is contained in:
parent
f08a5a07a8
commit
a93c99d7be
@ -34,13 +34,28 @@ is_supported : bool = _cusparse and _cusparse.cusparse_supported
|
||||
_ops = xla_client.ops
|
||||
_Shape = xla_client.Shape
|
||||
|
||||
def csr_todense(c, data, indices, indptr, *, shape):
|
||||
"""CSR to dense matrix."""
|
||||
def _validate_csr(c, data, indices, indptr, shape):
|
||||
data_dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(indices).element_type())
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
assert c.get_shape(indices).dimensions() == (nnz,)
|
||||
assert c.get_shape(indptr).element_type() == index_dtype
|
||||
assert c.get_shape(indptr).dimensions() == (shape[0] + 1,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def _validate_coo(c, data, row, col, shape):
|
||||
data_dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
assert c.get_shape(row).dimensions() == (nnz,)
|
||||
assert c.get_shape(col).element_type() == index_dtype
|
||||
assert c.get_shape(col).dimensions() == (nnz,)
|
||||
return data_dtype, index_dtype, nnz
|
||||
|
||||
def csr_todense(c, data, indices, indptr, *, shape):
|
||||
"""CSR to dense matrix."""
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
nnz = c.get_shape(data).dimensions()[0]
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
@ -50,10 +65,9 @@ def csr_todense(c, data, indices, indptr, *, shape):
|
||||
b"cusparse_csr_todense",
|
||||
operands=(data, indices, indptr),
|
||||
operand_shapes_with_layout=(
|
||||
# All are 1D, so no layout necessary
|
||||
c.get_shape(data),
|
||||
c.get_shape(indices),
|
||||
c.get_shape(indptr),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
_Shape.array_shape(data_dtype, shape, (1, 0)),
|
||||
@ -98,18 +112,16 @@ def csr_fromdense(c, mat, *, nnz, index_dtype):
|
||||
|
||||
def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_dtype=None):
|
||||
"""CSR matrix/vector multiply."""
|
||||
dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(indices).element_type())
|
||||
assert c.get_shape(indptr).element_type() == index_dtype
|
||||
x_dtype = np.dtype(c.get_shape(x).element_type())
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
x_dtype = np.dtype(c.get_shape(x).element_type())
|
||||
x_shape = c.get_shape(x).dimensions()
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = dtype
|
||||
compute_dtype = data_dtype
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_matvec_descriptor(
|
||||
dtype, x_dtype, compute_dtype, index_dtype,
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
@ -118,11 +130,10 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
|
||||
b"cusparse_csr_matvec",
|
||||
operands=(data, indices, indptr, x),
|
||||
operand_shapes_with_layout=(
|
||||
# All are 1D, so no layout necessary
|
||||
c.get_shape(data),
|
||||
c.get_shape(indices),
|
||||
c.get_shape(indptr),
|
||||
c.get_shape(x),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
|
||||
_Shape.array_shape(x_dtype, x_shape, (0,))
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
|
||||
@ -136,20 +147,17 @@ def csr_matvec(c, data, indices, indptr, x, *, shape, transpose=False, compute_d
|
||||
|
||||
def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_dtype=None):
|
||||
"""CSR from dense matrix."""
|
||||
dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(indices).element_type())
|
||||
assert c.get_shape(indptr).element_type() == index_dtype
|
||||
data_dtype, index_dtype, nnz = _validate_csr(c, data, indices, indptr, shape)
|
||||
rows, cols = shape
|
||||
B_dtype = np.dtype(c.get_shape(B).element_type())
|
||||
B_shape = c.get_shape(B).dimensions()
|
||||
rows, cols = shape
|
||||
_, Ccols = B_shape
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = dtype
|
||||
compute_dtype = data_dtype
|
||||
|
||||
buffer_size, opaque = _cusparse.build_csr_matmat_descriptor(
|
||||
dtype, B_dtype, compute_dtype, index_dtype,
|
||||
data_dtype, B_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
@ -158,9 +166,9 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
|
||||
b"cusparse_csr_matmat",
|
||||
operands=(data, indices, indptr, B),
|
||||
operand_shapes_with_layout=(
|
||||
c.get_shape(data),
|
||||
c.get_shape(indices),
|
||||
c.get_shape(indptr),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (rows + 1,), (0,)),
|
||||
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
@ -175,11 +183,8 @@ def csr_matmat(c, data, indices, indptr, B, *, shape, transpose=False, compute_d
|
||||
|
||||
def coo_todense(c, data, row, col, *, shape):
|
||||
"""COO to dense matrix."""
|
||||
data_dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
assert c.get_shape(row).element_type() == index_dtype
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
rows, cols = shape
|
||||
nnz = c.get_shape(data).dimensions()[0]
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_todense_descriptor(
|
||||
data_dtype, index_dtype, rows, cols, nnz)
|
||||
@ -189,10 +194,9 @@ def coo_todense(c, data, row, col, *, shape):
|
||||
b"cusparse_coo_todense",
|
||||
operands=(data, row, col),
|
||||
operand_shapes_with_layout=(
|
||||
# All are 1D, so no layout necessary
|
||||
c.get_shape(data),
|
||||
c.get_shape(row),
|
||||
c.get_shape(col),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
_Shape.array_shape(data_dtype, shape, (1, 0)),
|
||||
@ -236,18 +240,16 @@ def coo_fromdense(c, mat, *, nnz, index_dtype):
|
||||
|
||||
def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=None):
|
||||
"""COO matrix/vector multiply."""
|
||||
dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
assert c.get_shape(row).element_type() == index_dtype
|
||||
x_dtype = np.dtype(c.get_shape(x).element_type())
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
rows, cols = shape
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
x_dtype = np.dtype(c.get_shape(x).element_type())
|
||||
x_shape = c.get_shape(x).dimensions()
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = dtype
|
||||
compute_dtype = data_dtype
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_matvec_descriptor(
|
||||
dtype, x_dtype, compute_dtype, index_dtype,
|
||||
data_dtype, x_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
@ -256,11 +258,10 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
|
||||
b"cusparse_coo_matvec",
|
||||
operands=(data, row, col, x),
|
||||
operand_shapes_with_layout=(
|
||||
# All are 1D, so no layout necessary
|
||||
c.get_shape(data),
|
||||
c.get_shape(row),
|
||||
c.get_shape(col),
|
||||
c.get_shape(x),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(x_dtype, x_shape, (0,)),
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
_Shape.array_shape(compute_dtype, (out_size,), (0,)),
|
||||
@ -274,20 +275,17 @@ def coo_matvec(c, data, row, col, x, *, shape, transpose=False, compute_dtype=No
|
||||
|
||||
def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=None):
|
||||
"""COO from dense matrix."""
|
||||
dtype = np.dtype(c.get_shape(data).element_type())
|
||||
index_dtype = np.dtype(c.get_shape(row).element_type())
|
||||
assert c.get_shape(row).element_type() == index_dtype
|
||||
data_dtype, index_dtype, nnz = _validate_coo(c, data, row, col, shape)
|
||||
rows, cols = shape
|
||||
B_dtype = np.dtype(c.get_shape(B).element_type())
|
||||
B_shape = c.get_shape(B).dimensions()
|
||||
rows, cols = shape
|
||||
_, Ccols = B_shape
|
||||
nnz, = c.get_shape(data).dimensions()
|
||||
|
||||
if compute_dtype is None:
|
||||
compute_dtype = dtype
|
||||
compute_dtype = data_dtype
|
||||
|
||||
buffer_size, opaque = _cusparse.build_coo_matmat_descriptor(
|
||||
dtype, B_dtype, compute_dtype, index_dtype,
|
||||
data_dtype, B_dtype, compute_dtype, index_dtype,
|
||||
rows, cols, Ccols, nnz, transpose)
|
||||
out_size = cols if transpose else rows
|
||||
|
||||
@ -296,9 +294,9 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
|
||||
b"cusparse_coo_matmat",
|
||||
operands=(data, row, col, B),
|
||||
operand_shapes_with_layout=(
|
||||
c.get_shape(data),
|
||||
c.get_shape(row),
|
||||
c.get_shape(col),
|
||||
_Shape.array_shape(data_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(index_dtype, (nnz,), (0,)),
|
||||
_Shape.array_shape(B_dtype, B_shape, (1, 0)),
|
||||
),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
|
Loading…
x
Reference in New Issue
Block a user