mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #6849 from tomhennigan:changelist/376000598
PiperOrigin-RevId: 381010658
This commit is contained in:
commit
28977761d5
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
@ -35,6 +37,7 @@ from jax.lib import lapack
|
||||
|
||||
from jax.lib import cuda_linalg
|
||||
from jax.lib import cusolver
|
||||
from jax.lib import cusparse
|
||||
from jax.lib import rocsolver
|
||||
|
||||
from jax.lib import xla_client
|
||||
@ -1350,3 +1353,61 @@ if cusolver is not None:
|
||||
if rocsolver is not None:
|
||||
xla.backend_specific_translations['gpu'][svd_p] = partial(
|
||||
_svd_cpu_gpu_translation_rule, rocsolver.gesvd)
|
||||
|
||||
|
||||
tridiagonal_solve_p = Primitive('tridiagonal_solve')
|
||||
tridiagonal_solve_p.multiple_results = False
|
||||
tridiagonal_solve_p.def_impl(
|
||||
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
|
||||
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
|
||||
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
|
||||
if cusparse is not None:
|
||||
xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2
|
||||
|
||||
|
||||
def tridiagonal_solve(dl, d, du, b):
|
||||
r"""Computes the solution of a tridiagonal linear system.
|
||||
|
||||
This function computes the solution of a tridiagonal linear system::
|
||||
|
||||
.. math::
|
||||
A . X = B
|
||||
|
||||
Args:
|
||||
dl: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
|
||||
Note that ``dl[0] = 0``.
|
||||
d: The middle diagnoal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
|
||||
du: The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
|
||||
Note that ``dl[m - 1] = 0``.
|
||||
b: Right hand side matrix.
|
||||
|
||||
Returns:
|
||||
Solution ``X`` of tridiagonal system.
|
||||
"""
|
||||
if dl.ndim != 1 or d.ndim != 1 or du.ndim != 1:
|
||||
raise ValueError('dl, d and du must be vectors')
|
||||
|
||||
if dl.shape != d.shape or d.shape != du.shape:
|
||||
raise ValueError(
|
||||
f'dl={dl.shape}, d={d.shape} and du={du.shape} must all be `[m]`')
|
||||
|
||||
if b.ndim != 2:
|
||||
raise ValueError(f'b={b.shape} must be a matrix')
|
||||
|
||||
m, = dl.shape
|
||||
if m < 3:
|
||||
raise ValueError(f'm ({m}) must be >= 3')
|
||||
|
||||
ldb, n = b.shape
|
||||
if ldb < max(1, m):
|
||||
raise ValueError(f'Leading dimension of b={ldb} must be ≥ max(1, {m})')
|
||||
|
||||
if dl.dtype != d.dtype or d.dtype != du.dtype or du.dtype != b.dtype:
|
||||
raise ValueError(f'dl={dl.dtype}, d={d.dtype}, du={du.dtype} and '
|
||||
f'b={b.dtype} must be the same dtype,')
|
||||
|
||||
t = dl.dtype
|
||||
if t not in (np.float32, np.float64):
|
||||
raise ValueError(f'Only f32/f64 are supported, got {t}')
|
||||
|
||||
return tridiagonal_solve_p.bind(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)
|
||||
|
@ -979,6 +979,7 @@ tf_not_yet_impl = [
|
||||
"lu_pivots_to_permutation",
|
||||
"rng_bit_generator",
|
||||
"xla_pmap",
|
||||
"tridiagonal_solve",
|
||||
]
|
||||
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
|
@ -339,7 +339,7 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
|
||||
ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
|
||||
*xla_args, **params)
|
||||
else:
|
||||
raise NotImplementedError(f"XLA translation rule for {prim} not found")
|
||||
raise NotImplementedError(f"XLA translation rule for {prim!r} on platform {platform!r} not found")
|
||||
assert isinstance(ans, xe.XlaOp)
|
||||
c.clear_op_metadata()
|
||||
try:
|
||||
|
@ -29,4 +29,6 @@ from jax._src.lax.linalg import (
|
||||
svd_p,
|
||||
triangular_solve,
|
||||
triangular_solve_p,
|
||||
tridiagonal_solve,
|
||||
tridiagonal_solve_p,
|
||||
)
|
||||
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||
#include "include/pybind11/stl.h"
|
||||
|
||||
// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
|
||||
#define JAX_ENABLE_CUSPARSE (CUSPARSE_VERSION >= 11300)
|
||||
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
@ -44,32 +44,18 @@ namespace {
|
||||
namespace py = pybind11;
|
||||
|
||||
void ThrowIfErrorStatus(cusparseStatus_t status) {
|
||||
switch (status) {
|
||||
case CUSPARSE_STATUS_SUCCESS:
|
||||
return;
|
||||
case CUSPARSE_STATUS_NOT_INITIALIZED:
|
||||
throw std::runtime_error("cuSparse has not been initialized");
|
||||
case CUSPARSE_STATUS_ALLOC_FAILED:
|
||||
throw std::runtime_error("cuSparse allocation failure");
|
||||
case CUSPARSE_STATUS_INVALID_VALUE:
|
||||
throw std::runtime_error("cuSparse invalid value error");
|
||||
case CUSPARSE_STATUS_ARCH_MISMATCH:
|
||||
throw std::runtime_error("cuSparse architecture mismatch");
|
||||
case CUSPARSE_STATUS_MAPPING_ERROR:
|
||||
throw std::runtime_error("cuSparse mapping error");
|
||||
case CUSPARSE_STATUS_EXECUTION_FAILED:
|
||||
throw std::runtime_error("cuSparse execution failed");
|
||||
case CUSPARSE_STATUS_INTERNAL_ERROR:
|
||||
throw std::runtime_error("cuSparse internal error");
|
||||
case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
|
||||
throw std::runtime_error("cuSparse matrix type not supported error");
|
||||
case CUSPARSE_STATUS_ZERO_PIVOT:
|
||||
throw std::runtime_error("cuSparse zero pivot error");
|
||||
default:
|
||||
throw std::runtime_error("Unknown cuSparse error");
|
||||
if (status != CUSPARSE_STATUS_SUCCESS) {
|
||||
throw std::runtime_error(cusparseGetErrorString(status));
|
||||
}
|
||||
}
|
||||
|
||||
void ThrowIfErrorStatus(cudaError_t error) {
|
||||
if (error != cudaSuccess) {
|
||||
throw std::runtime_error(cudaGetErrorString(error));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
union CudaConst {
|
||||
int8_t i8[2];
|
||||
int16_t i16[2];
|
||||
@ -93,7 +79,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
CudaConst c;
|
||||
std::memset(&c, 0, sizeof(c));
|
||||
switch (type) {
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
|
||||
case CUDA_R_4I:
|
||||
case CUDA_C_4I:
|
||||
@ -102,7 +88,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_8I:
|
||||
c.i8[0] = 1;
|
||||
break;
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
case CUDA_R_4U:
|
||||
case CUDA_C_4U:
|
||||
#endif
|
||||
@ -110,7 +96,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_8U:
|
||||
c.u8[0] = 1;
|
||||
break;
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
case CUDA_R_16I:
|
||||
case CUDA_C_16I:
|
||||
c.i16[0] = 1;
|
||||
@ -128,7 +114,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_32U:
|
||||
c.u32[0] = 1;
|
||||
break;
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
case CUDA_R_64I:
|
||||
case CUDA_C_64I:
|
||||
c.i64[0] = 1;
|
||||
@ -143,7 +129,7 @@ CudaConst CudaOne(cudaDataType type) {
|
||||
case CUDA_C_16F:
|
||||
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
|
||||
break;
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
case CUDA_R_16BF:
|
||||
case CUDA_C_16BF:
|
||||
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
|
||||
@ -204,7 +190,7 @@ cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
|
||||
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
|
||||
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
|
||||
{{'u', 4}, CUDA_R_32U},
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
{{'V', 2}, CUDA_R_16BF},
|
||||
#endif
|
||||
});
|
||||
@ -255,7 +241,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
return DenseVecDescriptor{value_type, size};
|
||||
}
|
||||
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
// CsrToDense: Convert CSR matrix to dense matrix
|
||||
|
||||
// Returns the descriptor for a Sparse matrix.
|
||||
@ -858,12 +844,83 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
|
||||
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
|
||||
}
|
||||
#endif // if JAX_CUSPARSE_11030
|
||||
|
||||
#endif
|
||||
struct Gtsv2Descriptor {
|
||||
int m, n, ldb;
|
||||
};
|
||||
|
||||
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
|
||||
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
|
||||
}
|
||||
|
||||
template <typename T, typename F1, typename F2>
|
||||
void gtsv2(F1 computeGtsv2BufSize, F2 computeGtsv2, cudaStream_t stream,
|
||||
void** buffers, const char* opaque, std::size_t opaque_len) {
|
||||
auto handle = SparseHandlePool::Borrow();
|
||||
|
||||
const Gtsv2Descriptor& descriptor =
|
||||
*UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
|
||||
int m = descriptor.m;
|
||||
int n = descriptor.n;
|
||||
int ldb = descriptor.ldb;
|
||||
|
||||
const T* dl = (const T*)(buffers[0]);
|
||||
const T* d = (const T*)(buffers[1]);
|
||||
const T* du = (const T*)(buffers[2]);
|
||||
const T* B = (T*)(buffers[3]);
|
||||
T* X = (T*)(buffers[4]);
|
||||
|
||||
// The solution X is written in place to B. We need to therefore copy the
|
||||
// contents of B into the output buffer X and pass that into the kernel as B.
|
||||
// Once copy insertion is supported for custom call aliasing, we could alias B
|
||||
// with X and avoid the copy, the code below is written defensively assuming B
|
||||
// and X might alias, but today we know they will not.
|
||||
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
|
||||
if (X != B) {
|
||||
size_t B_bytes = ldb * n * sizeof(T);
|
||||
ThrowIfErrorStatus(
|
||||
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
size_t bufferSize;
|
||||
ThrowIfErrorStatus(
|
||||
computeGtsv2BufSize(handle.get(), m, n, dl, d, du, X, ldb, &bufferSize));
|
||||
|
||||
void* buffer;
|
||||
#if CUDA_VERSION >= 11020
|
||||
ThrowIfErrorStatus(cudaMallocAsync(&buffer, bufferSize, stream));
|
||||
#else
|
||||
ThrowIfErrorStatus(cudaMalloc(&buffer, bufferSize));
|
||||
#endif // CUDA_VERSION >= 11020
|
||||
|
||||
auto computeStatus =
|
||||
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer);
|
||||
|
||||
#if CUDA_VERSION >= 11020
|
||||
ThrowIfErrorStatus(cudaFreeAsync(buffer, stream));
|
||||
#else
|
||||
ThrowIfErrorStatus(cudaFree(buffer));
|
||||
#endif // CUDA_VERSION >= 11020
|
||||
|
||||
ThrowIfErrorStatus(computeStatus);
|
||||
}
|
||||
|
||||
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len) {
|
||||
gtsv2<float>(cusparseSgtsv2_bufferSizeExt, cusparseSgtsv2, stream, buffers,
|
||||
opaque, opaque_len);
|
||||
}
|
||||
|
||||
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
std::size_t opaque_len) {
|
||||
gtsv2<double>(cusparseDgtsv2_bufferSizeExt, cusparseDgtsv2, stream, buffers,
|
||||
opaque, opaque_len);
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
py::dict dict;
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
|
||||
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
|
||||
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
|
||||
@ -873,13 +930,16 @@ py::dict Registrations() {
|
||||
dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
|
||||
dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
|
||||
#endif
|
||||
dict["cusparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
|
||||
dict["cusparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
|
||||
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(cusparse_kernels, m) {
|
||||
m.attr("cusparse_supported") = py::bool_(JAX_ENABLE_CUSPARSE);
|
||||
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11030);
|
||||
m.def("registrations", &Registrations);
|
||||
#if JAX_ENABLE_CUSPARSE
|
||||
#if JAX_CUSPARSE_11030
|
||||
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
|
||||
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
|
||||
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
|
||||
@ -889,6 +949,7 @@ PYBIND11_MODULE(cusparse_kernels, m) {
|
||||
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
|
||||
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
|
||||
#endif
|
||||
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -287,3 +287,16 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
|
||||
opaque=opaque,
|
||||
)
|
||||
return _ops.GetTupleElement(out, 0)
|
||||
|
||||
|
||||
def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
|
||||
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
|
||||
dl_shape, d_shape, du_shape, B_shape = map(c.get_shape, (dl, d, du, B))
|
||||
return xla_client.ops.CustomCallWithLayout(
|
||||
c,
|
||||
b"cusparse_gtsv2_" + (b"f32" if (t == np.float32) else b"f64"),
|
||||
operands=(dl, d, du, B),
|
||||
operand_shapes_with_layout=(dl_shape, d_shape, du_shape, B_shape),
|
||||
shape_with_layout=B_shape,
|
||||
opaque=cusparse_kernels.build_gtsv2_descriptor(m, n, ldb),
|
||||
has_side_effect=False)
|
||||
|
@ -1414,7 +1414,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
|
||||
rtol=tol)
|
||||
|
||||
class EighTridiagonalTest(jtu.JaxTestCase):
|
||||
class LaxLinalgTest(jtu.JaxTestCase):
|
||||
|
||||
def run_test(self, alpha, beta):
|
||||
n = alpha.shape[-1]
|
||||
@ -1477,6 +1477,21 @@ class EighTridiagonalTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
eigvals_all[first:(last + 1)], eigvals_index, atol=atol)
|
||||
|
||||
@parameterized.parameters(np.float32, np.float64)
|
||||
def test_tridiagonal_solve(self, dtype):
|
||||
if jtu.device_under_test() != "gpu":
|
||||
self.skipTest("Only supported on GPU")
|
||||
|
||||
dl = np.array([0.0, 1.0, 2.0], dtype=dtype)
|
||||
d = np.ones(3, dtype=dtype)
|
||||
du = np.array([1.0, 2.0, 0.0], dtype=dtype)
|
||||
m = 3
|
||||
B = np.ones([m, 1], dtype=dtype)
|
||||
X = lax.linalg.tridiagonal_solve(dl, d, du, B)
|
||||
A = np.eye(3, dtype=dtype)
|
||||
A[[1, 2], [0, 1]] = dl[1:]
|
||||
A[[0, 1], [1, 2]] = du[:-1]
|
||||
np.testing.assert_allclose(A @ X, B)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user