convert : ability to lazy-load safetensors remotely without downloading to disk (#12820)

* gguf util : add SafetensorRemote

* fix style

* convert: add --remote option

* convert : allow using lazy remote tensors

It's a bit slow for now since everything is blocking and single-threaded.

* correct metadata.name

* small style fix

* support HF_TOKEN

* convert : use writeable buffer for remote lazy tensors

* convert : fix flake8 lint regarding lamdba assigment

* multithreaded download

* multithread: print debug

* fix style

* Revert "multithreaded download"

This reverts commit 42fc895ace385edc972ad819c76c704aeea61791.

* bring back _get_request_headers

---------

Co-authored-by: Francis Couture-Harpin <git@compilade.net>
This commit is contained in:
Xuan-Son Nguyen 2025-04-10 17:24:44 +02:00 committed by GitHub
parent fe5b78c896
commit 64eda5deb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 244 additions and 7 deletions

View File

@ -65,6 +65,7 @@ class Model:
model_name: str | None
metadata_override: Path | None
dir_model_card: Path
remote_hf_model_id: str | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -73,7 +74,7 @@ class Model:
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -83,11 +84,24 @@ class Model:
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.lazy = not eager or (remote_hf_model_id is not None)
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
self.get_tensors = get_remote_tensors
else:
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@ -393,6 +407,10 @@ class Model:
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
# If we are using HF model id, set the metadata name to the model id
if self.remote_hf_model_id:
self.metadata.name = self.remote_hf_model_id
# Fallback to model directory name if metadata name is still missing
if self.metadata.name is None:
self.metadata.name = self.dir_model.name
@ -5403,6 +5421,14 @@ class LazyTorchTensor(gguf.LazyBase):
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype]
shape = remote_tensor.shape
meta = cls.meta_with_dtype_and_shape(dtype, shape)
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
return cast(torch.Tensor, lazy)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
@ -5480,6 +5506,10 @@ def parse_args() -> argparse.Namespace:
"--print-supported-models", action="store_true",
help="Print the supported models"
)
parser.add_argument(
"--remote", action="store_true",
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
@ -5520,6 +5550,14 @@ def main() -> None:
dir_model = args.model
if args.remote:
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
repo_id=str(dir_model),
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
@ -5541,6 +5579,9 @@ def main() -> None:
if args.outfile is not None:
fname_out = args.outfile
elif args.remote:
# if remote, use the model ID as the output file name
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
else:
fname_out = dir_model
@ -5564,7 +5605,8 @@ def main() -> None:
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None)
if args.vocab_only:
logger.info("Exporting model vocab...")

View File

@ -1,7 +1,11 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
import os
import json
def fill_templated_filename(filename: str, output_type: str | None) -> str:
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
@dataclass
class RemoteTensor:
dtype: str
shape: tuple[int, ...]
offset_start: int
size: int
url: str
def data(self) -> bytearray:
# TODO: handle request errors (maybe with limited retries?)
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
return data
class SafetensorRemote:
"""
Uility class to handle remote safetensor files.
This class is designed to work with Hugging Face model repositories.
Example (one model has single safetensor file, the other has multiple):
for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
print(tensors)
Example reading tensor data:
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
for name, meta in tensors.items():
dtype, shape, offset_start, size, remote_safetensor_url = meta
# read the tensor data
data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
print(data)
"""
BASE_DOMAIN = "https://huggingface.co"
ALIGNMENT = 8 # bytes
@classmethod
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a Hugging Face model repository.
Returns a dictionary of tensor names and their metadata.
Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
"""
# case 1: model has only one single model.safetensor file
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
if is_single_file:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
return cls.get_list_tensors(url)
# case 2: model has multiple files
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
is_multiple_files = cls.check_file_exist(index_url)
if is_multiple_files:
# read the index file
index_data = cls.get_data_by_range(index_url, 0)
index_str = index_data.decode('utf-8')
index_json = json.loads(index_str)
assert index_json.get("weight_map") is not None, "weight_map not found in index file"
weight_map = index_json["weight_map"]
# get the list of files
all_files = list(set(weight_map.values()))
all_files.sort() # make sure we load shard files in order
# get the list of tensors
tensors: dict[str, RemoteTensor] = {}
for file in all_files:
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
for key, val in cls.get_list_tensors(url).items():
tensors[key] = val
return tensors
raise ValueError(f"Model {model_id} does not have any safetensor files")
@classmethod
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
"""
Get list of tensors from a remote safetensor file.
Returns a dictionary of tensor names and their metadata.
Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
"""
metadata, data_start_offset = cls.get_metadata(url)
res: dict[str, RemoteTensor] = {}
for name, meta in metadata.items():
if name == "__metadata__":
continue
if not isinstance(meta, dict):
raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
try:
dtype = meta["dtype"]
shape = meta["shape"]
offset_start_relative, offset_end_relative = meta["data_offsets"]
size = offset_end_relative - offset_start_relative
offset_start = data_start_offset + offset_start_relative
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
return res
@classmethod
def get_metadata(cls, url: str) -> tuple[dict, int]:
"""
Get JSON metadata from a remote safetensor file.
Returns tuple of (metadata, data_start_offset)
"""
# Request first 5MB of the file (hopefully enough for metadata)
read_size = 5 * 1024 * 1024
raw_data = cls.get_data_by_range(url, 0, read_size)
# Parse header
# First 8 bytes contain the metadata length as u64 little-endian
if len(raw_data) < 8:
raise ValueError("Not enough data to read metadata size")
metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
# Calculate the data start offset
data_start_offset = 8 + metadata_length
alignment = SafetensorRemote.ALIGNMENT
if data_start_offset % alignment != 0:
data_start_offset += alignment - (data_start_offset % alignment)
# Check if we have enough data to read the metadata
if len(raw_data) < 8 + metadata_length:
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
# Extract metadata bytes and parse as JSON
metadata_bytes = raw_data[8:8 + metadata_length]
metadata_str = metadata_bytes.decode('utf-8')
try:
metadata = json.loads(metadata_str)
return metadata, data_start_offset
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
@classmethod
def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
"""
Get raw byte data from a remote file by range.
If size is not specified, it will read the entire file.
"""
import requests
from urllib.parse import urlparse
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL: {url}")
headers = cls._get_request_headers()
if size > -1:
headers["Range"] = f"bytes={start}-{start + size}"
response = requests.get(url, allow_redirects=True, headers=headers)
response.raise_for_status()
# Get raw byte data
return response.content[:size]
@classmethod
def check_file_exist(cls, url: str) -> bool:
"""
Check if a file exists at the given URL.
Returns True if the file exists, False otherwise.
"""
import requests
from urllib.parse import urlparse
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL: {url}")
try:
headers = cls._get_request_headers()
headers["Range"] = "bytes=0-0"
response = requests.head(url, allow_redirects=True, headers=headers)
# Success (2xx) or redirect (3xx)
return 200 <= response.status_code < 400
except requests.RequestException:
return False
@classmethod
def _get_request_headers(cls) -> dict[str, str]:
"""Prepare common headers for requests."""
headers = {"User-Agent": "convert_hf_to_gguf"}
if os.environ.get("HF_TOKEN"):
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
return headers