gguf-py : use ThreadPoolExecutor when writing tensors

- gguf-py : handle (limited) retries for remote tensors
Francis Couture-Harpin 2025-04-12 00:00:51 -04:00
parent d7db1593ee
commit 3fe362fe49
2 changed files with 123 additions and 67 deletions
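The first half of the change swaps GGUFWriter's hand-rolled Queue + threading.Thread scheme for concurrent.futures. As a rough sketch of the pattern only (write_all, chunks and write_one are placeholder names, not code from this diff), the idea is to submit each pending write as its own task and surface the first failure once the pool settles:

# Minimal sketch of the ThreadPoolExecutor pattern (hypothetical names).
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait

def write_all(chunks, write_one, max_workers=4):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(write_one, chunk) for chunk in chunks]
        done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
        for f in done:
            if not f.cancelled() and f.exception() is not None:
                raise RuntimeError("a chunk failed to write") from f.exception()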


@@ -10,10 +10,10 @@ from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from queue import Empty, Queue
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
import numpy as np
@@ -62,20 +62,49 @@ class WriterState(Enum):
WEIGHTS = auto()
# To close files which were opened in thread-local context
# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
# ref: https://github.com/python/cpython/issues/89502
class _ThreadedOpenFiles:
files: dict[Path, BufferedWriter]
def __init__(self):
self.files = {}
def __del__(self):
for file in self.files.values():
file.close()
def __getitem__(self, key: Path, /) -> BufferedWriter:
if key not in self.files:
self.files[key] = open(key, "r+b")
return self.files[key]
@classmethod
def init_thread_local(cls, local_data):
local_data.open_files = _ThreadedOpenFiles()
# Exit quickly instead of waiting
class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
del exc_type, exc_val, exc_tb
self.shutdown(wait=False, cancel_futures=True)
return False
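For context on the two helpers above: ThreadPoolExecutor's initializer and initargs run once in every worker thread, so attaching a fresh _ThreadedOpenFiles to a shared threading.local gives each worker its own set of file handles, and the __del__ hook is what eventually closes them, since the executor has no per-thread finalizer (see the linked CPython issue). A standalone sketch of that initializer pattern, with a hypothetical make_state helper:

# Sketch: per-worker state via ThreadPoolExecutor's initializer (hypothetical names).
import threading
from concurrent.futures import ThreadPoolExecutor

local = threading.local()

def make_state(local_data):
    # runs once in each worker thread; every worker gets its own attribute
    local_data.state = {"worker": threading.current_thread().name}

with ThreadPoolExecutor(max_workers=2, initializer=make_state, initargs=(local,)) as pool:
    names = list(pool.map(lambda _: local.state["worker"], range(4)))
print(sorted(set(names)))  # at most two distinct worker names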
@dataclass
class ThreadedTensorWriteInfo:
class _ThreadedTensorWriteInfo:
filename: Path
offset: int
post_pad: int
tensor: np.ndarray
bar: Any | None # optional tqdm progress bar
def write_chunk(self, open_files: dict[Path, BufferedWriter]):
def write_chunk(self, open_files: _ThreadedOpenFiles):
# This is called from a thread pool,
# and each thread should have its own file handle per output file
# so that they can have different seek locations.
if self.filename not in open_files:
open_files[self.filename] = open(self.filename, "r+b")
f = open_files[self.filename]
f.seek(self.offset)
@@ -462,9 +491,6 @@ class GGUFWriter:
if self.temp_file is None:
bar = None
# Distribute writing the tensors between multiple threads
tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
# Initial file offsets before writing the tensor data
offsets: list[int] = [fout.tell() for fout in self.fout]
@@ -476,60 +502,58 @@
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
# Fill the tensor queue with all the pending tensor writes
for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
offset = offsets[i]
# Allow opening the files only once per worker
local_data = threading.local()
# relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in tensors.values():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
start_offset = offset
nbytes = ti.tensor.nbytes
offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
padding = offset - (start_offset + nbytes)
tensor_queue.put(
ThreadedTensorWriteInfo(
filename=filename,
offset=start_offset,
post_pad=padding,
tensor=ti.tensor,
bar=bar,
# Unit of work
def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
tensor.write_chunk(local_data.open_files)
with _InterruptibleThreadPoolExecutor(
max_workers=self.thread_count,
initializer=_ThreadedOpenFiles.init_thread_local,
initargs=(local_data,),
) as executor:
futures: list[Future] = []
# Fill the tensor queue with all the pending tensor writes
for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
offset = offsets[i]
# relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in tensors.values():
assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes
start_offset = offset
nbytes = ti.tensor.nbytes
offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
padding = offset - (start_offset + nbytes)
futures.append(
executor.submit(
thread_write_tensor,
_ThreadedTensorWriteInfo(
filename=filename,
offset=start_offset,
post_pad=padding,
tensor=ti.tensor,
bar=bar,
),
)
)
)
ti.tensor = None # avoid keeping a reference to written tensors
ti.tensor = None # avoid keeping a reference to written tensors
# Write tensors in parallel
# TODO: total tensor size limit for the running threads
def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
# Opening the files only once per thread
open_files: dict[Path, BufferedWriter] = {}
try:
while tensor := queue.get_nowait():
tensor.write_chunk(open_files)
del tensor
queue.task_done()
except Empty:
pass
for f in open_files.values():
f.close()
threads = [
threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
for _ in range(self.thread_count)
]
for t in threads:
t.start()
# NOTE: thread joining has weird interactions with KeyboardInterrupt,
# so waiting for the queue to be "done" first.
tensor_queue.join()
for t in threads:
t.join()
# FIXME: there's still some weird behavior with KeyboardInterrupt
# not being able to interrupt a future mid-execution
done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
exc = None
if any(f for f in done
if not f.cancelled() and (exc := f.exception()) is not None):
raise RuntimeError("Error writing tensors") from exc
elif len(not_done) != 0:
raise RuntimeError("Not all tensors were written")
del local_data
else:
self.temp_file.seek(0)
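Worth noting about the submission loop above: every tensor's absolute start offset and trailing padding are computed up front with self.ggml_pad, before any worker runs, so the tasks are independent and each per-thread file handle only has to seek() to its own positions. A small worked sketch of that offset/padding computation (hypothetical sizes, an alignment of 32 bytes):

# Sketch: padded offsets as the writer loop computes them (hypothetical sizes).
def layout(sizes: list[int], alignment: int = 32) -> list[tuple[int, int]]:
    offset = 0
    out = []
    for nbytes in sizes:
        start = offset
        # round the next start up to the alignment boundary, as GGUFWriter.ggml_pad does
        offset = (start + nbytes + alignment - 1) // alignment * alignment
        out.append((start, offset - (start + nbytes)))  # (start_offset, post_pad)
    return out

print(layout([100, 40, 7]))  # [(0, 28), (128, 24), (192, 25)]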


@@ -5,6 +5,14 @@ from typing import Literal
import os
import json
import time
import logging
import requests
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
@dataclass
class RemoteTensor:
name: str
dtype: str
shape: tuple[int, ...]
offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
url: str
def data(self) -> bytearray:
# TODO: handle request errors (maybe with limited retries?)
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
data = None
MAX_RETRIES = 8
for i in range(MAX_RETRIES):
try:
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
data = bytearray(
SafetensorRemote.get_data_by_range(
url=self.url, start=self.offset_start, size=self.size
)
)
except (
requests.exceptions.ChunkedEncodingError,
requests.exceptions.ContentDecodingError,
requests.exceptions.ConnectionError,
) as e:
if i == MAX_RETRIES - 1:
raise RuntimeError(f"Failed to download tensor {self.name}") from e
logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
time.sleep(2 * i + 1) # 1 3 5 7 9 11 13
continue
if data is None:
raise RuntimeError(f"Failed to download tensor {self.name}")
return data
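The backoff above sleeps 2 * i + 1 seconds after attempt i, so the delays grow as 1, 3, 5, ... and a tensor that never downloads spends at most 1 + 3 + 5 + 7 + 9 + 11 + 13 = 49 seconds sleeping across the 8 attempts before the RuntimeError propagates. The schedule on its own (a sketch, not part of the diff):

# Sketch: the retry delays used above; the final attempt raises instead of sleeping.
MAX_RETRIES = 8
delays = [2 * i + 1 for i in range(MAX_RETRIES - 1)]
print(delays, sum(delays))  # [1, 3, 5, 7, 9, 11, 13] 49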
@@ -169,7 +199,14 @@ class SafetensorRemote:
offset_start_relative, offset_end_relative = meta["data_offsets"]
size = offset_end_relative - offset_start_relative
offset_start = data_start_offset + offset_start_relative
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
res[name] = RemoteTensor(
name=name,
dtype=dtype,
shape=tuple(shape),
offset_start=offset_start,
size=size,
url=url,
)
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
@@ -217,8 +254,6 @@ class SafetensorRemote:
Get raw byte data from a remote file by range.
If size is not specified, it will read the entire file.
"""
import requests
from urllib.parse import urlparse
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ class SafetensorRemote:
Check if a file exists at the given URL.
Returns True if the file exists, False otherwise.
"""
import requests
from urllib.parse import urlparse
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL: {url}")
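Since SafetensorRemote.get_data_by_range exposes plain byte ranges, it can also fetch just the fixed-size prefix of a .safetensors file: the first 8 bytes are the little-endian uint64 length of the JSON header. A hypothetical usage sketch (placeholder URL, assuming the gguf package layout):

# Sketch: reading only the 8-byte safetensors header-length field (placeholder URL).
import struct

from gguf.utility import SafetensorRemote

raw = SafetensorRemote.get_data_by_range(
    url="https://huggingface.co/some-org/some-model/resolve/main/model.safetensors",
    start=0,
    size=8,
)
(header_len,) = struct.unpack("<Q", raw[:8])
print(header_len)  # byte length of the JSON metadata block that follows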