1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4from pathlib import Path
  5from typing import Literal
  6
  7import os
  8import json
  9import numpy as np
 10
 11
 12def fill_templated_filename(filename: str, output_type: str | None) -> str:
 13    # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
 14    ftype_lowercase: str = output_type.lower() if output_type is not None else ""
 15    ftype_uppercase: str = output_type.upper() if output_type is not None else ""
 16    return filename.format(ftype_lowercase,
 17                           outtype=ftype_lowercase, ftype=ftype_lowercase,
 18                           OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
 19
 20
 21def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
 22    if model_params_count > 1e12 :
 23        # Trillions Of Parameters
 24        scaled_model_params = model_params_count * 1e-12
 25        scale_suffix = "T"
 26    elif model_params_count > 1e9 :
 27        # Billions Of Parameters
 28        scaled_model_params = model_params_count * 1e-9
 29        scale_suffix = "B"
 30    elif model_params_count > 1e6 :
 31        # Millions Of Parameters
 32        scaled_model_params = model_params_count * 1e-6
 33        scale_suffix = "M"
 34    else:
 35        # Thousands Of Parameters
 36        scaled_model_params = model_params_count * 1e-3
 37        scale_suffix = "K"
 38
 39    fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
 40
 41    return f"{scaled_model_params:.{fix}f}{scale_suffix}"
 42
 43
 44def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
 45
 46    if expert_count > 0:
 47        pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
 48        size_class = f"{expert_count}x{pretty_size}"
 49    else:
 50        size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
 51
 52    return size_class
 53
 54
 55def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
 56    # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
 57
 58    if base_name is not None:
 59        name = base_name.strip().replace(' ', '-').replace('/', '-')
 60    elif model_name is not None:
 61        name = model_name.strip().replace(' ', '-').replace('/', '-')
 62    else:
 63        name = "ggml-model"
 64
 65    parameters = f"-{size_label}" if size_label is not None else ""
 66
 67    finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
 68
 69    version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
 70
 71    encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
 72
 73    kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
 74
 75    return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
 76
 77
 78@dataclass
 79class RemoteTensor:
 80    dtype: str
 81    shape: tuple[int, ...]
 82    offset_start: int
 83    size: int
 84    url: str
 85
 86    def data(self) -> bytearray:
 87        # TODO: handle request errors (maybe with limited retries?)
 88        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
 89        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
 90        return data
 91
 92
 93class SafetensorRemote:
 94    """
 95    Uility class to handle remote safetensor files.
 96    This class is designed to work with Hugging Face model repositories.
 97
 98    Example (one model has single safetensor file, the other has multiple):
 99        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
100            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
101            print(tensors)
102
103    Example reading tensor data:
104        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
105        for name, meta in tensors.items():
106            dtype, shape, offset_start, size, remote_safetensor_url = meta
107            # read the tensor data
108            data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
109            print(data)
110    """
111
112    BASE_DOMAIN = "https://huggingface.co"
113
114    @classmethod
115    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
116        """
117        Get list of tensors from a Hugging Face model repository.
118
119        Returns a dictionary of tensor names and their metadata.
120        Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
121        """
122        # case 1: model has only one single model.safetensor file
123        is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
124        if is_single_file:
125            url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
126            return cls.get_list_tensors(url)
127
128        # case 2: model has multiple files
129        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
130        is_multiple_files = cls.check_file_exist(index_url)
131        if is_multiple_files:
132            # read the index file
133            index_data = cls.get_data_by_range(index_url, 0)
134            index_str = index_data.decode('utf-8')
135            index_json = json.loads(index_str)
136            assert index_json.get("weight_map") is not None, "weight_map not found in index file"
137            weight_map = index_json["weight_map"]
138            # get the list of files
139            all_files = list(set(weight_map.values()))
140            all_files.sort() # make sure we load shard files in order
141            # get the list of tensors
142            tensors: dict[str, RemoteTensor] = {}
143            for file in all_files:
144                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
145                for key, val in cls.get_list_tensors(url).items():
146                    tensors[key] = val
147            return tensors
148
149        raise ValueError(
150            f"No safetensor file has been found for model {model_id}."
151            "If the repo has safetensor files, make sure the model is public or you have a "
152            "valid Hugging Face token set in the environment variable HF_TOKEN."
153        )
154
155    @classmethod
156    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
157        """
158        Get list of tensors from a remote safetensor file.
159
160        Returns a dictionary of tensor names and their metadata.
161        Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
162        """
163        metadata, data_start_offset = cls.get_metadata(url)
164        res: dict[str, RemoteTensor] = {}
165
166        for name, meta in metadata.items():
167            if name == "__metadata__":
168                continue
169            if not isinstance(meta, dict):
170                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
171            try:
172                dtype = meta["dtype"]
173                shape = meta["shape"]
174                offset_start_relative, offset_end_relative = meta["data_offsets"]
175                size = offset_end_relative - offset_start_relative
176                offset_start = data_start_offset + offset_start_relative
177                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
178            except KeyError as e:
179                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
180
181        # order by name (same as default safetensors behavior)
182        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
183        res = dict(sorted(res.items(), key=lambda t: t[0]))
184
185        return res
186
187    @classmethod
188    def get_metadata(cls, url: str) -> tuple[dict, int]:
189        """
190        Get JSON metadata from a remote safetensor file.
191
192        Returns tuple of (metadata, data_start_offset)
193        """
194        # Request first 5MB of the file (hopefully enough for metadata)
195        read_size = 5 * 1024 * 1024
196        raw_data = cls.get_data_by_range(url, 0, read_size)
197
198        # Parse header
199        # First 8 bytes contain the metadata length as u64 little-endian
200        if len(raw_data) < 8:
201            raise ValueError("Not enough data to read metadata size")
202        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
203
204        # Calculate the data start offset
205        data_start_offset = 8 + metadata_length
206
207        # Check if we have enough data to read the metadata
208        if len(raw_data) < 8 + metadata_length:
209            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
210
211        # Extract metadata bytes and parse as JSON
212        metadata_bytes = raw_data[8:8 + metadata_length]
213        metadata_str = metadata_bytes.decode('utf-8')
214        try:
215            metadata = json.loads(metadata_str)
216            return metadata, data_start_offset
217        except json.JSONDecodeError as e:
218            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
219
220    @classmethod
221    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
222        """
223        Get raw byte data from a remote file by range.
224        If size is not specified, it will read the entire file.
225        """
226        import requests
227        from urllib.parse import urlparse
228
229        parsed_url = urlparse(url)
230        if not parsed_url.scheme or not parsed_url.netloc:
231            raise ValueError(f"Invalid URL: {url}")
232
233        headers = cls._get_request_headers()
234        if size > -1:
235            headers["Range"] = f"bytes={start}-{start + size}"
236        response = requests.get(url, allow_redirects=True, headers=headers)
237        response.raise_for_status()
238
239        # Get raw byte data
240        return response.content[slice(size if size > -1 else None)]
241
242    @classmethod
243    def check_file_exist(cls, url: str) -> bool:
244        """
245        Check if a file exists at the given URL.
246        Returns True if the file exists, False otherwise.
247        """
248        import requests
249        from urllib.parse import urlparse
250
251        parsed_url = urlparse(url)
252        if not parsed_url.scheme or not parsed_url.netloc:
253            raise ValueError(f"Invalid URL: {url}")
254
255        try:
256            headers = cls._get_request_headers()
257            headers["Range"] = "bytes=0-0"
258            response = requests.head(url, allow_redirects=True, headers=headers)
259            # Success (2xx) or redirect (3xx)
260            return 200 <= response.status_code < 400
261        except requests.RequestException:
262            return False
263
264    @classmethod
265    def _get_request_headers(cls) -> dict[str, str]:
266        """Prepare common headers for requests."""
267        headers = {"User-Agent": "convert_hf_to_gguf"}
268        if os.environ.get("HF_TOKEN"):
269            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
270        return headers
271
272
@dataclass
class LocalTensorRange:
    """A contiguous byte range in a local file: `size` bytes starting at `offset`."""

    filename: Path  # file that holds the bytes
    offset: int  # absolute byte offset from the start of the file
    size: int  # length of the range in bytes
278
279
@dataclass
class LocalTensor:
    """A tensor described by a safetensors header whose raw bytes live in a local file range."""

    dtype: str  # dtype string as read from the safetensors header
    shape: tuple[int, ...]
    data_range: LocalTensorRange  # file, offset and byte length of the tensor data

    def mmap_bytes(self) -> np.ndarray:
        """
        Map this tensor's raw bytes as a 1-D uint8 array (np.memmap's default dtype).

        Mode 'c' is copy-on-write: writes to the array are not saved to the file.
        Zero-sized tensors return an empty array without mapping. mmap cannot map
        zero bytes of an empty file, or zero bytes at an aligned offset at EOF.
        """
        if self.data_range.size == 0:
            return np.empty(0, dtype=np.uint8)
        return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
288
289
class SafetensorsLocal:
    """
    Read a safetensors file from the local filesystem.

    Custom parsing gives a bit more control over the memory usage.
    The official safetensors library doesn't expose file ranges.

    Only the header is read at construction time. Each tensor is described by a
    LocalTensor that records the file range of its data. Used as a context
    manager, it yields the `tensors` dict (name -> LocalTensor, sorted by name).
    """

    tensors: dict[str, LocalTensor]

    def __init__(self, filename: Path):
        with open(filename, "rb") as f:
            # First 8 bytes contain the metadata length as u64 little-endian
            metadata_length = int.from_bytes(f.read(8), byteorder='little')
            file_size = os.stat(filename).st_size
            # also rejects files shorter than the 8-byte length prefix itself
            if file_size < 8 + metadata_length:
                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")

            metadata_str = f.read(metadata_length).decode('utf-8')
            try:
                metadata = json.loads(metadata_str)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") from e

            # tensor data offsets in the header are relative to the end of the header
            data_start_offset = f.tell()

        if not isinstance(metadata, dict):
            raise ValueError(f"Invalid safetensors metadata: expected a JSON object, got {type(metadata).__name__}")

        tensors: dict[str, LocalTensor] = {}
        for name, meta in metadata.items():
            if name == "__metadata__":
                # ignore metadata, it's not a tensor
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = tuple(meta["shape"])
                offsets = meta["data_offsets"]
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") from e

            tensors[name] = LocalTensor(
                dtype=dtype,
                shape=shape,
                data_range=LocalTensorRange(
                    filename,
                    data_start_offset + offsets[0],
                    offsets[1] - offsets[0],
                ),
            )

        # order by name (same as default safetensors behavior)
        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
        self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))

    def __enter__(self, *args, **kwargs):
        del args, kwargs  # unused
        return self.tensors

    def __exit__(self, *args, **kwargs):
        del args, kwargs  # unused