From f168a1560ca2f6e94cfbddc2300c6c616c20f4ae Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 8 May 2023 15:36:17 -0700 Subject: [PATCH] [GPU] Add missing stream synchronization to tridiagonal_solve gtsv2 call. May fix flaky failures in CI. Make stream argument to Pool::Borrow() mandatory to minimize chance of forgetting it. PiperOrigin-RevId: 530425766 --- jaxlib/gpu/rnn_kernels.cc | 2 +- jaxlib/gpu/solver.cc | 16 ++++++++-------- jaxlib/gpu/sparse.cc | 18 +++++++++--------- jaxlib/gpu/sparse_kernels.cc | 2 +- jaxlib/handle_pool.h | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index bf76752c6..d2e52e80d 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -75,7 +75,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size, int num_layers, int batch_size, int max_seq_length, float dropout, bool bidirectional) { - auto h = DnnHandlePool::Borrow(); + auto h = DnnHandlePool::Borrow(/*stream=*/nullptr); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index ffa38c803..d986ebd51 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -57,7 +57,7 @@ SolverType DtypeToSolverType(const py::dtype& np_type) { std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, int m, int n) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -96,7 +96,7 @@ std::pair BuildGetrfDescriptor(const py::dtype& dtype, int b, std::pair BuildGeqrfDescriptor(const py::dtype& dtype, int b, int m, int n) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -148,7 +148,7 @@ py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA, std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, int m, int n, int k) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -191,7 +191,7 @@ std::pair BuildOrgqrDescriptor(const py::dtype& dtype, int b, std::pair BuildSyevdDescriptor(const py::dtype& dtype, bool lower, int b, int n) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -230,7 +230,7 @@ std::pair BuildSyevdDescriptor(const py::dtype& dtype, std::pair BuildSyevjDescriptor(const py::dtype& dtype, bool lower, int batch, int n) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -298,7 +298,7 @@ std::pair BuildGesvdDescriptor(const py::dtype& dtype, int b, int m, int n, bool compute_uv, bool full_matrices) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -343,7 +343,7 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, int batch, int m, int n, bool compute_uv, int econ) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; @@ -426,7 +426,7 @@ std::pair BuildGesvdjDescriptor(const py::dtype& dtype, std::pair BuildSytrdDescriptor(const py::dtype& dtype, bool lower, int b, int n) { SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(); + auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; int lwork; diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index ed0e079df..67166b58b 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -108,7 +108,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype, std::pair BuildCsrToDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor d = @@ -185,7 +185,7 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, std::pair BuildCsrFromDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor d = @@ -262,7 +262,7 @@ std::pair BuildCsrMatvecDescriptor( const py::dtype& data_dtype, const py::dtype& x_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor A = @@ -310,7 +310,7 @@ std::pair BuildCsrMatmatDescriptor( const py::dtype& data_dtype, const py::dtype& b_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int BCcols, int nnz, bool transpose) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor A = @@ -361,7 +361,7 @@ std::pair BuildCsrMatmatDescriptor( std::pair BuildCooToDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor d = @@ -398,7 +398,7 @@ std::pair BuildCooToDenseDescriptor( std::pair BuildCooFromDenseDescriptor( const py::dtype& data_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor d = @@ -435,7 +435,7 @@ std::pair BuildCooMatvecDescriptor( const py::dtype& data_dtype, const py::dtype& x_dtype, const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; SparseMatDescriptor A = @@ -489,7 +489,7 @@ std::pair BuildCooMatmatDescriptor( // All three matrices A, B, and C must have the same batch count. // Use batch stride to trigger individual mode, e.g., // `rhs_batch_stride = 0` for C_i = A_i B. - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -554,7 +554,7 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) { template size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; size_t size; diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index b0024989a..1d7ece842 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -550,7 +550,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, template static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len) { - auto h = SparseHandlePool::Borrow(); + auto h = SparseHandlePool::Borrow(stream); JAX_RETURN_IF_ERROR(h.status()); auto& handle = *h; diff --git a/jaxlib/handle_pool.h b/jaxlib/handle_pool.h index a80b9fe42..9201d8d57 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/handle_pool.h @@ -77,7 +77,7 @@ class HandlePool { // Borrows a handle from the pool. If 'stream' is non-null, sets the stream // associated with the handle. - static absl::StatusOr Borrow(StreamType stream = nullptr); + static absl::StatusOr Borrow(StreamType stream); private: static HandlePool* Instance();