diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc
index bf76752c6..d2e52e80d 100644
--- a/jaxlib/gpu/rnn_kernels.cc
+++ b/jaxlib/gpu/rnn_kernels.cc
@@ -75,7 +75,7 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                        int num_layers, int batch_size,
                                        int max_seq_length, float dropout,
                                        bool bidirectional) {
-  auto h = DnnHandlePool::Borrow();
+  auto h = DnnHandlePool::Borrow(/*stream=*/nullptr);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
 
diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc
index ffa38c803..d986ebd51 100644
--- a/jaxlib/gpu/solver.cc
+++ b/jaxlib/gpu/solver.cc
@@ -57,7 +57,7 @@ SolverType DtypeToSolverType(const py::dtype& np_type) {
 std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
                                                int m, int n) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -96,7 +96,7 @@ std::pair<int, py::bytes> BuildGetrfDescriptor(const py::dtype& dtype, int b,
 std::pair<int, py::bytes> BuildGeqrfDescriptor(const py::dtype& dtype, int b,
                                                int m, int n) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -148,7 +148,7 @@ py::bytes BuildCsrlsvqrDescriptor(const py::dtype& dtype, int n, int nnzA,
 std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
                                                int m, int n, int k) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -191,7 +191,7 @@ std::pair<int, py::bytes> BuildOrgqrDescriptor(const py::dtype& dtype, int b,
 std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
                                                bool lower, int b, int n) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -230,7 +230,7 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
 std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
                                                bool lower, int batch, int n) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -298,7 +298,7 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
                                                int m, int n, bool compute_uv,
                                                bool full_matrices) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -343,7 +343,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
                                                 int batch, int m, int n,
                                                 bool compute_uv, int econ) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
@@ -426,7 +426,7 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
 std::pair<int, py::bytes> BuildSytrdDescriptor(const py::dtype& dtype,
                                                bool lower, int b, int n) {
   SolverType type = DtypeToSolverType(dtype);
-  auto h = SolverHandlePool::Borrow();
+  auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   int lwork;
diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc
index ed0e079df..67166b58b 100644
--- a/jaxlib/gpu/sparse.cc
+++ b/jaxlib/gpu/sparse.cc
@@ -108,7 +108,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
 std::pair<size_t, py::bytes> BuildCsrToDenseDescriptor(
     const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor d =
@@ -185,7 +185,7 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque,
 std::pair<size_t, py::bytes> BuildCsrFromDenseDescriptor(
     const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor d =
@@ -262,7 +262,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
     const py::dtype& data_dtype, const py::dtype& x_dtype,
     const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor A =
@@ -310,7 +310,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
     const py::dtype& data_dtype, const py::dtype& b_dtype,
     const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
     int cols, int BCcols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor A =
@@ -361,7 +361,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
 std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
     const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor d =
@@ -398,7 +398,7 @@ std::pair<size_t, py::bytes> BuildCooToDenseDescriptor(
 std::pair<size_t, py::bytes> BuildCooFromDenseDescriptor(
     const py::dtype& data_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor d =
@@ -435,7 +435,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
     const py::dtype& data_dtype, const py::dtype& x_dtype,
     const py::dtype& compute_dtype, const py::dtype& index_dtype, int rows,
     int cols, int nnz, bool transpose) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   SparseMatDescriptor A =
@@ -489,7 +489,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
   // All three matrices A, B, and C must have the same batch count.
   // Use batch stride to trigger individual mode, e.g.,
   // `rhs_batch_stride = 0` for C_i = A_i B.
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
 
@@ -554,7 +554,7 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
 
 template <typename F>
 size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(/*stream=*/nullptr);
   JAX_THROW_IF_ERROR(h.status());
   auto& handle = *h;
   size_t size;
diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc
index b0024989a..1d7ece842 100644
--- a/jaxlib/gpu/sparse_kernels.cc
+++ b/jaxlib/gpu/sparse_kernels.cc
@@ -550,7 +550,7 @@ void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque,
 template <typename T, typename F>
 static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers,
                           const char* opaque, std::size_t opaque_len) {
-  auto h = SparseHandlePool::Borrow();
+  auto h = SparseHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
 
diff --git a/jaxlib/handle_pool.h b/jaxlib/handle_pool.h
index a80b9fe42..9201d8d57 100644
--- a/jaxlib/handle_pool.h
+++ b/jaxlib/handle_pool.h
@@ -77,7 +77,7 @@ class HandlePool {
 
   // Borrows a handle from the pool. If 'stream' is non-null, sets the stream
   // associated with the handle.
-  static absl::StatusOr<Handle> Borrow(StreamType stream = nullptr);
+  static absl::StatusOr<Handle> Borrow(StreamType stream);
 
  private:
   static HandlePool<HandleType, StreamType>* Instance();