mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-04-14 10:36:07 +00:00
gguf-py : use ThreadPoolExecutor when writing tensors
- gguf-py : handle (limited) retries for remote tensors
parent d7db1593ee
commit 3fe362fe49
gguf-py/gguf/gguf_writer.py
@@ -10,10 +10,10 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
-from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait

 import numpy as np
@@ -62,20 +62,49 @@ class WriterState(Enum):
     WEIGHTS = auto()


+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
 @dataclass
-class ThreadedTensorWriteInfo:
+class _ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
     bar: Any | None  # optional tqdm progress bar

-    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
         # This is called from a thread pool,
         # and each thread should have its own file handle per output file
         # so that they can have different seek locations.
-        if self.filename not in open_files:
-            open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]

         f.seek(self.offset)
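Note: the per-thread file-handle trick above can be reproduced in isolation. The sketch below is not part of this patch (helper names like worker_write are hypothetical); it shows how passing a threading.local through the executor's initializer gives every worker its own BufferedWriter, so concurrent seek + write calls at different offsets never interfere:

import threading
from concurrent.futures import ThreadPoolExecutor
from io import BufferedWriter
from pathlib import Path

local_data = threading.local()

def init_worker(local: threading.local):
    # runs once in each worker thread before it picks up any task
    local.handles = {}

def worker_write(path: Path, offset: int, payload: bytes):
    # one handle per (thread, file); reused across tasks in the same thread
    handles: dict[Path, BufferedWriter] = local_data.handles
    if path not in handles:
        handles[path] = open(path, "r+b")
    f = handles[path]
    f.seek(offset)  # safe: no other thread shares this handle's position
    f.write(payload)

if __name__ == "__main__":
    out = Path("demo.bin")
    out.write_bytes(b"\x00" * 4096)  # pre-size the file
    with ThreadPoolExecutor(max_workers=2, initializer=init_worker,
                            initargs=(local_data,)) as ex:
        for i in range(4):
            ex.submit(worker_write, out, i * 1024, bytes([i]) * 1024)

The handles in this sketch are only closed when the interpreter exits; the patch wraps them in _ThreadedOpenFiles with a __del__ precisely because ThreadPoolExecutor has no per-thread finalizer hook (see the linked CPython issue).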
@@ -462,9 +491,6 @@ class GGUFWriter:

         if self.temp_file is None:
             bar = None
-            # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
-
             # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
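For reference, the offsets handed to the write tasks follow the same alignment rule before and after this change: each tensor starts where the previous aligned region ended, and the trailing padding is whatever rounds the end up to self.data_alignment. A small worked sketch of that bookkeeping (the 32-byte alignment and tensor sizes are made-up example values; ggml_pad here mirrors GGUFWriter.ggml_pad):

def ggml_pad(x: int, n: int) -> int:
    # round x up to the next multiple of n
    return ((x + n - 1) // n) * n

alignment = 32  # example value; the writer uses self.data_alignment
offset = 0      # the real code starts from fout.tell() after the header
for nbytes in (100, 64, 7):  # made-up tensor byte sizes
    start_offset = offset
    offset = ggml_pad(start_offset + nbytes, alignment)
    post_pad = offset - (start_offset + nbytes)
    print(f"{nbytes}B tensor at {start_offset}, {post_pad}B of padding after it")
# -> 100B at 0 with 28B pad; 64B at 128 with 0B pad; 7B at 192 with 25B pad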
@@ -476,60 +502,58 @@ class GGUFWriter:

             bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

-            # Fill the tensor queue with all the pending tensor writes
-            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
-                offset = offsets[i]
+            # Allow opening the files only once per worker
+            local_data = threading.local()

-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    start_offset = offset
-                    nbytes = ti.tensor.nbytes
-                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
-                    padding = offset - (start_offset + nbytes)
-                    tensor_queue.put(
-                        ThreadedTensorWriteInfo(
-                            filename=filename,
-                            offset=start_offset,
-                            post_pad=padding,
-                            tensor=ti.tensor,
-                            bar=bar,
-                        )
-                    )
-                    ti.tensor = None  # avoid keeping a reference to written tensors
-
-            # Write tensors in parallel
-            # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
-                # Opening the files only once per thread
-                open_files: dict[Path, BufferedWriter] = {}
-                try:
-                    while tensor := queue.get_nowait():
-                        tensor.write_chunk(open_files)
-                        del tensor
-                        queue.task_done()
-                except Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            #       so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Fill the tensor queue with all the pending tensor writes
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
+                        )
+                        ti.tensor = None  # avoid keeping a reference to written tensors
+
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                #        not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)
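The control flow that replaces the queue/join dance can also be shown standalone. A minimal sketch of the submit-then-wait(FIRST_EXCEPTION) pattern follows (the task function and its failure are hypothetical): one failing future ends the wait early and its exception is re-raised as the cause of a RuntimeError, exactly as in the hunk above.

from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait

def task(n: int) -> int:
    # hypothetical unit of work; n == 3 simulates a failed tensor write
    if n == 3:
        raise ValueError("simulated write failure")
    return n * n

futures: list[Future] = []
with ThreadPoolExecutor(max_workers=2) as executor:
    for n in range(6):
        futures.append(executor.submit(task, n))

    # stop waiting as soon as any future raises
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    exc = None
    if any(f for f in done
           if not f.cancelled() and (exc := f.exception()) is not None):
        raise RuntimeError("a task failed") from exc
    elif len(not_done) != 0:
        raise RuntimeError("not all tasks completed")

A plain ThreadPoolExecutor.__exit__ would call shutdown(wait=True) here and block on the remaining futures; the patch's _InterruptibleThreadPoolExecutor overrides it with shutdown(wait=False, cancel_futures=True) so a KeyboardInterrupt (or the RuntimeError above) exits promptly.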
gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@ from typing import Literal

 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)


 def fill_templated_filename(filename: str, output_type: str | None) -> str:
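Small usage note: logger = logging.getLogger(__name__) only creates the channel; the retry warnings added below become visible once the consuming script configures logging, e.g.:

import logging

# in a consuming script (e.g. a conversion tool), enable output once:
logging.basicConfig(level=logging.INFO)

# the module-level logger then emits through the root handlers;
# "gguf.utility" is what __name__ resolves to for gguf-py/gguf/utility.py
logging.getLogger("gguf.utility").warning("Retry (1/8) downloading tensor %s", "t.weight")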
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st

 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
     url: str

     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+            else:
+                break
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")

         return data
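The same bounded-retry-with-linear-backoff shape works as a standalone sketch; flaky_fetch() below is a hypothetical stand-in for SafetensorRemote.get_data_by_range(), not part of the library:

import time
import random

def flaky_fetch() -> bytes:
    # pretend roughly half of all attempts fail with a transient error
    if random.random() < 0.5:
        raise ConnectionError("simulated transient failure")
    return b"payload"

def fetch_with_retries(max_retries: int = 8) -> bytes:
    for i in range(max_retries):
        try:
            return flaky_fetch()
        except ConnectionError as e:
            if i == max_retries - 1:
                # exhausted: surface the last transient error as a hard failure
                raise RuntimeError("download failed") from e
            delay = 2 * i + 1  # 1, 3, 5, ... seconds, same schedule as above
            print(f"retry {i + 1}/{max_retries} after {e}, sleeping {delay}s")
            time.sleep(delay)
    raise AssertionError("unreachable: the loop either returns or raises")

print(fetch_with_retries())

Only errors that plausibly go away on their own (connection resets, truncated chunked responses) are retried; anything else propagates immediately, which is why the patch catches a specific tuple of requests exceptions rather than Exception.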
@@ -169,7 +199,14 @@ class SafetensorRemote:
                     offset_start_relative, offset_end_relative = meta["data_offsets"]
                     size = offset_end_relative - offset_start_relative
                     offset_start = data_start_offset + offset_start_relative
-                    res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                    res[name] = RemoteTensor(
+                        name=name,
+                        dtype=dtype,
+                        shape=tuple(shape),
+                        offset_start=offset_start,
+                        size=size,
+                        url=url,
+                    )
                 except KeyError as e:
                     raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
@@ -217,8 +254,6 @@
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse

         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")