From 3fe362fe497ff6040d206c5228b181ec2e977024 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin <git@compilade.net>
Date: Sat, 12 Apr 2025 00:00:51 -0400
Subject: [PATCH] gguf-py : use ThreadPoolExecutor when writing tensors

- gguf-py : handle (limited) retries for remote tensors

---
 gguf-py/gguf/gguf_writer.py | 140 +++++++++++++++++++++---------------
 gguf-py/gguf/utility.py     |  50 ++++++++++---
 2 files changed, 123 insertions(+), 67 deletions(-)

diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index db8ad4f05..ea283c57f 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -10,10 +10,10 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
-from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
 
 import numpy as np
 
@@ -62,20 +62,49 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
 @dataclass
-class ThreadedTensorWriteInfo:
+class _ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
     bar: Any | None  # optional tqdm progress bar
 
-    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
         # This is called from a thread pool,
         # and each thread should have its own file handle per output file
         # so that they can have different seek locations.
-        if self.filename not in open_files:
-            open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
         f.seek(self.offset)
@@ -462,9 +491,6 @@ class GGUFWriter:
         if self.temp_file is None:
             bar = None
 
-            # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
-
             # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
 
@@ -476,60 +502,58 @@ class GGUFWriter:
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            # Fill the tensor queue with all the pending tensor writes
-            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
-                offset = offsets[i]
+            # Allow opening the files only once per worker
+            local_data = threading.local()
 
-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    start_offset = offset
-                    nbytes = ti.tensor.nbytes
-                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
-                    padding = offset - (start_offset + nbytes)
-                    tensor_queue.put(
-                        ThreadedTensorWriteInfo(
-                            filename=filename,
-                            offset=start_offset,
-                            post_pad=padding,
-                            tensor=ti.tensor,
-                            bar=bar,
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Fill the tensor queue with all the pending tensor writes
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
                         )
-                    )
-                    ti.tensor = None  # avoid keeping a reference to written tensors
+                        ti.tensor = None  # avoid keeping a reference to written tensors
 
-            # Write tensors in parallel
-            # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
-                # Opening the files only once per thread
-                open_files: dict[Path, BufferedWriter] = {}
-                try:
-                    while tensor := queue.get_nowait():
-                        tensor.write_chunk(open_files)
-                        del tensor
-                        queue.task_done()
-                except Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            # so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                # not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+            del local_data
 
         else:
             self.temp_file.seek(0)
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index e5251aef8..0734b9f25 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@ from typing import Literal
 
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+            break  # success: stop retrying, otherwise the tensor would be re-downloaded
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data
 
 
@@ -169,7 +199,14 @@
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                res[name] = RemoteTensor(
+                    name=name,
+                    dtype=dtype,
+                    shape=tuple(shape),
+                    offset_start=offset_start,
+                    size=size,
+                    url=url,
+                )
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
@@ -217,8 +254,6 @@
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@
         Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
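
Note, not part of the patch: reduced to its essentials, the writer-side change is the usual thread-local file handle pattern. Each pool worker opens its own handle on the shared output file in the pool initializer, so concurrent seek() + write() calls on disjoint ranges cannot race on a shared seek position. A minimal standalone sketch under those assumptions (the file name "out.bin", the 1 KiB chunk size, and the helper names are illustrative, they do not come from the patch):

    import threading
    from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait

    local_data = threading.local()

    def init_worker(path: str) -> None:
        # Runs once in every pool thread: each worker gets a private handle,
        # so each one can keep its own seek position on the same file.
        local_data.f = open(path, "r+b")

    def write_chunk(offset: int, chunk: bytes) -> None:
        local_data.f.seek(offset)
        local_data.f.write(chunk)

    # Pre-size the output file, then fill disjoint ranges in parallel.
    with open("out.bin", "wb") as f:
        f.truncate(4 * 1024)

    jobs = [(i * 1024, bytes([i + 1]) * 1024) for i in range(4)]
    with ThreadPoolExecutor(max_workers=2, initializer=init_worker, initargs=("out.bin",)) as ex:
        futures = [ex.submit(write_chunk, off, chunk) for off, chunk in jobs]
        done, _not_done = wait(futures, return_when=FIRST_EXCEPTION)
        for fut in done:
            if (exc := fut.exception()) is not None:
                raise RuntimeError("Error writing chunks") from exc

The sketch never closes the per-thread handles, which is exactly the gap the patch fills with _ThreadedOpenFiles.__del__: ThreadPoolExecutor accepts an initializer but offers no per-thread finalizer hook (see python/cpython#89502). The _InterruptibleThreadPoolExecutor override serves the same interactivity goal as the FIXME comments: shutdown(wait=False, cancel_futures=True) drops queued writes on exit instead of draining them, so a Ctrl-C does not have to wait for every pending tensor.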
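
Note, not part of the patch: the RemoteTensor.data() change is a bounded retry loop with linearly increasing sleeps, restricted to transport-level failures. A minimal sketch of the same policy, assuming a plain GET request (get_with_retries and the timeout value are illustrative, not from the patch):

    import time
    import logging

    import requests

    logger = logging.getLogger(__name__)

    MAX_RETRIES = 8

    def get_with_retries(url: str) -> bytes:
        for i in range(MAX_RETRIES):
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                return response.content
            except (
                requests.exceptions.ChunkedEncodingError,
                requests.exceptions.ContentDecodingError,
                requests.exceptions.ConnectionError,
            ) as e:
                if i == MAX_RETRIES - 1:
                    raise RuntimeError(f"Failed to download {url}") from e
                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) for {url} because of {e}")
                time.sleep(2 * i + 1)  # linearly increasing backoff: 1, 3, 5, ... seconds
        raise RuntimeError(f"Failed to download {url}")  # not reached while MAX_RETRIES >= 1

HTTP-level errors (what raise_for_status throws) are deliberately left out of the retried set here, matching the patch, which only retries ChunkedEncodingError, ContentDecodingError and ConnectionError.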