From 3fe362fe497ff6040d206c5228b181ec2e977024 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin <git@compilade.net>
Date: Sat, 12 Apr 2025 00:00:51 -0400
Subject: [PATCH] gguf-py : use ThreadPoolExecutor when writing tensors

- gguf-py : handle (limited) retries for remote tensors
---
 gguf-py/gguf/gguf_writer.py | 140 +++++++++++++++++++++---------------
 gguf-py/gguf/utility.py     |  50 ++++++++++---
 2 files changed, 123 insertions(+), 67 deletions(-)
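
Note: the writer change relies on ThreadPoolExecutor's initializer to give
every worker thread its own set of file handles, so concurrent writers never
contend over a shared seek position. A minimal standalone sketch of that
pattern (the names here are illustrative, and unlike the patch's
_ThreadedOpenFiles this sketch never closes its handles):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    local_data = threading.local()

    def init_thread_local(local_data):
        # Runs once per worker thread: a private cache of open handles
        local_data.open_files = {}

    def write_at(path, offset, payload):
        # Reuse this thread's own handle; seek positions can't race
        files = local_data.open_files
        if path not in files:
            files[path] = open(path, "r+b")
        f = files[path]
        f.seek(offset)
        f.write(payload)

    # Assumes "out.bin" already exists ("r+b" doesn't create files)
    with ThreadPoolExecutor(
        max_workers=4,
        initializer=init_thread_local,
        initargs=(local_data,),
    ) as executor:
        for i in range(8):
            executor.submit(write_at, "out.bin", i * 4, b"abcd")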

diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index db8ad4f05..ea283c57f 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -10,10 +10,10 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
-from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
 
 import numpy as np
 
@@ -62,20 +62,49 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+# Helper to close files which were opened in a thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
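+        # Lazily open one read+write handle per file for this thread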
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting until all the pending futures are done
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
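+        # Cancel queued futures so an exception or Ctrl-C doesn't block
+        # on work that hasn't started; returning False still lets the
+        # original exception propagate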
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
 @dataclass
-class ThreadedTensorWriteInfo:
+class _ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
     bar: Any | None  # optional tqdm progress bar
 
-    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
         # This is called from a thread pool,
         # and each thread should have its own file handle per output file
         # so that they can have different seek locations.
-        if self.filename not in open_files:
-            open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
 
         f.seek(self.offset)
@@ -462,9 +491,6 @@ class GGUFWriter:
 
         if self.temp_file is None:
             bar = None
-            # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
-
             # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
 
@@ -476,60 +502,58 @@ class GGUFWriter:
 
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            # Fill the tensor queue with all the pending tensor writes
-            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
-                offset = offsets[i]
+            # Thread-local storage, so that the output files are opened only once per worker
+            local_data = threading.local()
 
-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    start_offset = offset
-                    nbytes = ti.tensor.nbytes
-                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
-                    padding = offset - (start_offset + nbytes)
-                    tensor_queue.put(
-                        ThreadedTensorWriteInfo(
-                            filename=filename,
-                            offset=start_offset,
-                            post_pad=padding,
-                            tensor=ti.tensor,
-                            bar=bar,
+            # Unit of work submitted to the thread pool
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
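+            # The initializer runs once in each worker thread, so every
+            # worker gets its own file handles (and thus its own seek positions)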
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Submit all the pending tensor writes to the thread pool
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
                         )
-                    )
-                    ti.tensor = None  # avoid keeping a reference to written tensors
+                        ti.tensor = None  # avoid keeping a reference to written tensors
 
-            # Write tensors in parallel
-            # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
-                # Opening the files only once per thread
-                open_files: dict[Path, BufferedWriter] = {}
-                try:
-                    while tensor := queue.get_nowait():
-                        tensor.write_chunk(open_files)
-                        del tensor
-                        queue.task_done()
-                except Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            #       so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                #        not being able to interrupt a future mid-execution
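+                # Stop waiting as soon as any worker raises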
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(not f.cancelled() and (exc := f.exception()) is not None
+                       for f in done):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
 
+            del local_data
         else:
             self.temp_file.seek(0)
 
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index e5251aef8..0734b9f25 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@ from typing import Literal
 
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
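+                # These are transient transport errors, and a ranged GET is
+                # idempotent, so retrying the request is safe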
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # linear backoff: 1, 3, 5, 7, 9, 11, 13 seconds
+                continue
+            # Success; leave the retry loop instead of downloading again
+            break
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data
 
 
@@ -169,7 +199,14 @@ class SafetensorRemote:
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                res[name] = RemoteTensor(
+                    name=name,
+                    dtype=dtype,
+                    shape=tuple(shape),
+                    offset_start=offset_start,
+                    size=size,
+                    url=url,
+                )
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
@@ -217,8 +254,6 @@ class SafetensorRemote:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ class SafetensorRemote:
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
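
Note: the retry logic in RemoteTensor.data() above is a bounded-retry loop
with linearly growing backoff. A minimal sketch of the same idea, with a
hypothetical fetch() callable standing in for
SafetensorRemote.get_data_by_range:

    import time

    def fetch_with_retries(fetch, max_retries: int = 8) -> bytes:
        # fetch is any zero-argument callable that may raise transient errors
        for i in range(max_retries):
            try:
                return fetch()
            except ConnectionError as e:
                if i == max_retries - 1:
                    raise RuntimeError("download failed") from e
                time.sleep(2 * i + 1)  # backoff: 1, 3, 5, ... seconds
        raise RuntimeError("unreachable")  # every iteration returns or raises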