1from __future__ import annotations
2
3from dataclasses import dataclass
4from pathlib import Path
5from typing import Literal
6
7import os
8import json
9import numpy as np
10
11
def fill_templated_filename(filename: str, output_type: str | None) -> str:
    """Fill output-type placeholders in a templated file name.

    Supports '{ftype}'/'{outtype}' (lowercase), '{FTYPE}'/'{OUTTYPE}'
    (uppercase) and a bare positional '{}' (lowercase),
    e.g. 'some-model-name.{ftype}.gguf'.
    """
    if output_type is None:
        lower = upper = ""
    else:
        lower = output_type.lower()
        upper = output_type.upper()
    return filename.format(lower, outtype=lower, ftype=lower, OUTTYPE=upper, FTYPE=upper)
19
20
def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
    """Render a parameter count as a short scaled string, e.g. '7.0B' or '125M'.

    The count is scaled by the largest unit it exceeds (T, B, M, otherwise K)
    and padded with decimal places so at least `min_digits` digits are shown.
    """
    scales = (
        (1e12, 1e-12, "T"),  # Trillions Of Parameters
        (1e9, 1e-9, "B"),    # Billions Of Parameters
        (1e6, 1e-6, "M"),    # Millions Of Parameters
    )
    for threshold, factor, unit in scales:
        if model_params_count > threshold:
            scaled_model_params = model_params_count * factor
            scale_suffix = unit
            break
    else:
        # Thousands Of Parameters
        scaled_model_params = model_params_count * 1e-3
        scale_suffix = "K"

    # Number of decimal places needed so that `min_digits` digits are visible
    # (leading zeros of the integer part do not count as digits).
    decimals = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)

    return f"{scaled_model_params:.{decimals}f}{scale_suffix}"
42
43
def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
    """Build a human-readable model-size label, e.g. '7.0B' or '8x1.1B' (MoE)."""
    if expert_count <= 0:
        # Dense model: label with the total parameter count.
        return model_weight_count_rounded_notation(abs(total_params), min_digits=2)

    # Mixture-of-experts: label as <expert_count>x<params-per-expert>.
    per_expert_params = abs(shared_params) + abs(expert_params)
    return f"{expert_count}x{model_weight_count_rounded_notation(per_expert_params, min_digits=2)}"
53
54
def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
    """Compose a GGUF file name from its optional components.

    Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
    """
    def dashed(text: str) -> str:
        # Collapse surrounding whitespace and replace inner spaces with dashes.
        return text.strip().replace(' ', '-')

    # Base name wins over model name; fall back to a generic default.
    for candidate in (base_name, model_name):
        if candidate is not None:
            name = dashed(candidate).replace('/', '-')
            break
    else:
        name = "ggml-model"

    parts = [name]
    if size_label is not None:
        parts.append(size_label)
    if finetune_string is not None:
        parts.append(dashed(finetune_string))
    if version_string is not None:
        parts.append(dashed(version_string))
    if output_type is not None:
        parts.append(dashed(output_type).upper())
    if model_type is not None:
        parts.append(dashed(model_type))

    return "-".join(parts)
76
77
@dataclass
class RemoteTensor:
    """Metadata for one tensor in a remote safetensors file; bytes are fetched on demand."""
    dtype: str              # safetensors dtype string (e.g. "F32", "BF16")
    shape: tuple[int, ...]  # tensor dimensions
    offset_start: int       # absolute byte offset of the tensor data within the remote file
    size: int               # size of the tensor data in bytes
    url: str                # URL of the safetensors file containing this tensor

    def data(self) -> bytearray:
        """Fetch this tensor's raw bytes with a single HTTP range request."""
        # TODO: handle request errors (maybe with limited retries?)
        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
        return data
91
92
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(tensors)

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            # meta is a RemoteTensor; read the tensor data
            data = SafetensorRemote.get_data_by_range(meta.url, meta.offset_start, meta.size)
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a Hugging Face model repository.

        Handles both single-file models (model.safetensors) and sharded models
        (model.safetensors.index.json plus shard files).

        Returns a dictionary mapping tensor names to RemoteTensor metadata.
        Raises ValueError when the repository exposes no safetensor files.
        """
        # case 1: model has only one single model.safetensors file
        single_file_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
        if cls.check_file_exist(single_file_url):
            return cls.get_list_tensors(single_file_url)

        # case 2: model has multiple files, enumerated by an index file
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        if cls.check_file_exist(index_url):
            # read and parse the index file
            index_data = cls.get_data_by_range(index_url, 0)
            index_json = json.loads(index_data.decode('utf-8'))
            weight_map = index_json.get("weight_map")
            if weight_map is None:
                # raise (not assert): asserts are stripped under `python -O`
                raise ValueError("weight_map not found in index file")
            # deduplicate, then sort to make sure we load shard files in order
            all_files = sorted(set(weight_map.values()))
            # merge the tensor lists of all shards
            tensors: dict[str, RemoteTensor] = {}
            for file in all_files:
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                tensors.update(cls.get_list_tensors(url))
            return tensors

        raise ValueError(
            f"No safetensor file has been found for model {model_id}. "
            "If the repo has safetensor files, make sure the model is public or you have a "
            "valid Hugging Face token set in the environment variable HF_TOKEN."
        )

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a single remote safetensor file.

        Returns a dictionary mapping tensor names to RemoteTensor metadata.
        Raises ValueError on malformed per-tensor metadata.
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, RemoteTensor] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                # not a tensor entry; skip
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                # data_offsets are relative to the start of the data section
                offset_start_relative, offset_end_relative = meta["data_offsets"]
                size = offset_end_relative - offset_start_relative
                offset_start = data_start_offset + offset_start_relative
                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
            except KeyError as e:
                # chain the cause so the missing key is preserved in the traceback
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

        # order by name (same as default safetensors behavior)
        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
        res = dict(sorted(res.items(), key=lambda t: t[0]))

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)
        """
        # Request first 5MB of the file (hopefully enough for metadata)
        read_size = 5 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)

        # Parse header
        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')

        # The tensor data section starts right after the JSON header
        data_start_offset = 8 + metadata_length

        # Check if we have enough data to read the metadata
        if len(raw_data) < 8 + metadata_length:
            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")

        # Extract metadata bytes and parse as JSON
        metadata_str = raw_data[8:8 + metadata_length].decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
            return metadata, data_start_offset
        except json.JSONDecodeError as e:
            # chain the cause so the original parse error is preserved
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") from e

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read the entire file
        (note: `start` is only honored when `size` is given).
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        if size == 0:
            # An HTTP byte range cannot express an empty span; nothing to fetch.
            return b""

        headers = cls._get_request_headers()
        if size > -1:
            # HTTP Range end is inclusive: request exactly `size` bytes.
            headers["Range"] = f"bytes={start}-{start + size - 1}"
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Trim defensively in case the server ignored the Range header.
        return response.content[slice(size if size > -1 else None)]

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        Raises ValueError for malformed URLs; network errors return False.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            headers = cls._get_request_headers()
            # Only probe for existence; never download content.
            headers["Range"] = "bytes=0-0"
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False

    @classmethod
    def _get_request_headers(cls) -> dict[str, str]:
        """Prepare common headers for requests (adds HF_TOKEN auth when set)."""
        headers = {"User-Agent": "convert_hf_to_gguf"}
        if os.environ.get("HF_TOKEN"):
            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
        return headers
271
272
@dataclass
class LocalTensorRange:
    """Byte range of one tensor's data within a local safetensors file."""
    filename: Path  # file containing the tensor data
    offset: int     # absolute byte offset of the tensor data within the file
    size: int       # size of the tensor data in bytes
278
279
@dataclass
class LocalTensor:
    """A tensor in a local safetensors file, readable lazily via memory-mapping."""
    dtype: str              # safetensors dtype string (e.g. "F32")
    shape: tuple[int, ...]  # tensor dimensions
    data_range: LocalTensorRange  # where the raw bytes live on disk

    def mmap_bytes(self) -> np.ndarray:
        """Memory-map this tensor's raw bytes as a 1-D uint8 array (numpy's default memmap dtype).

        mode='c' is copy-on-write: the array is writeable but changes are
        never flushed back to the file.
        """
        return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
288
289
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Usable as a context manager; entering yields the name -> LocalTensor mapping.
    """

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path):
        """Parse the safetensors header of `filename` and build the tensor table.

        Raises ValueError if the file is truncated or the metadata is not valid JSON.
        """
        with open(filename, "rb") as f:
            # First 8 bytes: metadata length as u64 little-endian
            metadata_length = int.from_bytes(f.read(8), byteorder='little')
            file_size = os.stat(filename).st_size
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")

            metadata_str = f.read(metadata_length).decode('utf-8')
            try:
                metadata = json.loads(metadata_str)
            except json.JSONDecodeError as e:
                # chain the cause so the original parse error is preserved
                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e

            # tensor data begins immediately after the JSON header
            data_start_offset = f.tell()

            tensors: dict[str, LocalTensor] = {}
            for name, meta in metadata.items():
                if name == "__metadata__":
                    # ignore metadata, it's not a tensor
                    continue

                tensors[name] = LocalTensor(
                    dtype=meta["dtype"],
                    shape=tuple(meta["shape"]),
                    data_range=LocalTensorRange(
                        filename,
                        # data_offsets are relative to the start of the data section
                        data_start_offset + meta["data_offsets"][0],
                        meta["data_offsets"][1] - meta["data_offsets"][0],
                    ),
                )

        # order by name (same as default safetensors behavior)
        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
        self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))

    def __enter__(self, *args, **kwargs):
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        del args, kwargs  # unused