mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call.
May fix flaky failures in CI. Make the stream argument to Pool::Borrow() mandatory to minimize the chance of forgetting it.
PiperOrigin-RevId: 530425766
This commit is contained in:
parent
00b75aff82
commit
f168a1560c
@ -75,7 +75,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout,
|
||||
bool bidirectional) {
|
||||
auto h = DnnHandlePool::Borrow();
|
||||
auto h = DnnHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
|
@ -57,7 +57,7 @@ SolverType DtypeToSolverType(const py::dtype& np_type) {
|
||||
std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -96,7 +96,7 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
|
||||
std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -148,7 +148,7 @@ py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
|
||||
std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, int k) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -191,7 +191,7 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
|
||||
std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -230,7 +230,7 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
|
||||
bool lower, int batch, int n) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -298,7 +298,7 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
|
||||
int m, int n, bool compute_uv,
|
||||
bool full_matrices) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -343,7 +343,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
|
||||
int batch, int m, int n,
|
||||
bool compute_uv, int econ) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
@ -426,7 +426,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
|
||||
std::pair<int, py::bytes> BuildSytrdDescriptor(const py::dtype& dtype,
|
||||
bool lower, int b, int n) {
|
||||
SolverType type = DtypeToSolverType(dtype);
|
||||
auto h = SolverHandlePool::Borrow();
|
||||
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
int lwork;
|
||||
|
@ -108,7 +108,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
|
||||
std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
@ -185,7 +185,7 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
@ -262,7 +262,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
@ -310,7 +310,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& b_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int BCcols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
@ -361,7 +361,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
|
||||
std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
@ -398,7 +398,7 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
|
||||
std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor d =
|
||||
@ -435,7 +435,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
|
||||
const py::dtype& data_dtype, const py::dtype& x_dtype,
|
||||
const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
|
||||
int cols, int nnz, bool transpose) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
SparseMatDescriptor A =
|
||||
@ -489,7 +489,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
|
||||
// All three matrices A, B, and C must have the same batch count.
|
||||
// Use batch stride to trigger individual mode, e.g.,
|
||||
// `rhs_batch_stride = 0` for C_i = A_i B.
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
@ -554,7 +554,7 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
|
||||
|
||||
template <typename F>
|
||||
size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
|
||||
JAX_THROW_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
size_t size;
|
||||
|
@ -550,7 +550,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
template <typename T, typename F>
|
||||
static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
|
||||
const char* opaque, std::size_t opaque_len) {
|
||||
auto h = SparseHandlePool::Borrow();
|
||||
auto h = SparseHandlePool::Borrow(stream);
|
||||
JAX_RETURN_IF_ERROR(h.status());
|
||||
auto& handle = *h;
|
||||
|
||||
|
@ -77,7 +77,7 @@ class HandlePool {
|
||||
|
||||
// Borrows a handle from the pool. If 'stream' is non-null, sets the stream
|
||||
// associated with the handle.
|
||||
static absl::StatusOr<Handle> Borrow(StreamType stream = nullptr);
|
||||
static absl::StatusOr<Handle> Borrow(StreamType stream);
|
||||
|
||||
private:
|
||||
static HandlePool<HandleType, StreamType>* Instance();
|
||||
|
Loading…
x
Reference in New Issue
Block a user