1#
# GGUF file reading/modification support. For API usage information,
# please see the files in scripts/ for some fairly simple examples.
  4#
  5from __future__ import annotations
  6
  7import logging
  8import os
  9import sys
 10from collections import OrderedDict
 11from typing import Any, Literal, NamedTuple, TypeVar, Union
 12
 13import numpy as np
 14import numpy.typing as npt
 15
 16from .quants import quant_shape_to_byte_shape
 17
if __name__ == "__main__":
    from pathlib import Path

    # Allow running file in package as a script: put the directory two levels
    # above this file at the front of sys.path so that the absolute
    # `gguf.constants` import below can be resolved.
    # NOTE(review): the relative `from .quants import ...` import earlier in this
    # file executes before this block, and relative imports fail in a module run
    # as __main__ without a parent package — confirm direct execution still works.
    sys.path.insert(0, str(Path(__file__).parent.parent))
 23
 24from gguf.constants import (
 25    GGML_QUANT_SIZES,
 26    GGUF_DEFAULT_ALIGNMENT,
 27    GGUF_MAGIC,
 28    GGUF_VERSION,
 29    GGMLQuantizationType,
 30    GGUFValueType,
 31    GGUFEndian,
 32)
 33
logger = logging.getLogger(__name__)

# GGUF format versions GGUFReader accepts; the version read from the file
# header is checked against this list and anything else is rejected.
READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
 37
 38
 39class ReaderField(NamedTuple):
 40    # Offset to start of this field.
 41    offset: int
 42
 43    # Name of the field (not necessarily from file data).
 44    name: str
 45
 46    # Data parts. Some types have multiple components, such as strings
 47    # that consist of a length followed by the string data.
 48    parts: list[npt.NDArray[Any]] = []
 49
 50    # Indexes into parts that we can call the actual data. For example
 51    # an array of strings will be populated with indexes to the actual
 52    # string data.
 53    data: list[int] = [-1]
 54
 55    types: list[GGUFValueType] = []
 56
 57    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
 58        if self.types:
 59            to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731
 60            main_type = self.types[0]
 61
 62            if main_type == GGUFValueType.ARRAY:
 63                sub_type = self.types[-1]
 64
 65                if sub_type == GGUFValueType.STRING:
 66                    indices = self.data[index_or_slice]
 67
 68                    if isinstance(index_or_slice, int):
 69                        return to_string(self.parts[indices]) # type: ignore
 70                    else:
 71                        return [to_string(self.parts[idx]) for idx in indices] # type: ignore
 72                else:
 73                    # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
 74
 75                    # Check if it's unsafe to perform slice optimization on data
 76                    # if any(True for idx in self.data if len(self.parts[idx]) != 1):
 77                    #     optim_slice = slice(None)
 78                    # else:
 79                    #     optim_slice = index_or_slice
 80                    #     index_or_slice = slice(None)
 81
 82                    # if isinstance(optim_slice, int):
 83                    #     return self.parts[self.data[optim_slice]].tolist()[0]
 84                    # else:
 85                    #     return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
 86
 87                    if isinstance(index_or_slice, int):
 88                        return self.parts[self.data[index_or_slice]].tolist()[0]
 89                    else:
 90                        return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]
 91
 92            if main_type == GGUFValueType.STRING:
 93                return to_string(self.parts[-1])
 94            else:
 95                return self.parts[-1].tolist()[0]
 96
 97        return None
 98
 99
class ReaderTensor(NamedTuple):
    # Tensor name, decoded as UTF-8 from the tensor info record.
    name: str
    # Storage type decoded from the tensor info record.
    tensor_type: GGMLQuantizationType
    # Dimensions exactly as stored in the file (read as uint64). Note that
    # `data` is reshaped with these dimensions in reversed order.
    shape: npt.NDArray[np.uint64]
    # Product of all dimensions.
    n_elements: int
    # Size of the tensor data in bytes, derived from the type's block size
    # and per-block byte size.
    n_bytes: int
    # Absolute byte offset of the tensor data within the file
    # (start of the data section plus the tensor's stored offset).
    data_offset: int
    # Array view over the reader's memory-mapped file. For F16/F32/F64/I8/I16/
    # I32/I64 it has the matching numpy element type and reversed `shape`;
    # for every other type it is raw uint8 bytes with a byte-level shape.
    data: npt.NDArray[Any]
    # The tensor info record this tensor was built from. Its `types` list is
    # empty, so `field.contents()` returns None.
    field: ReaderField
109
110
class GGUFReader:
    """Reader for GGUF files backed by a numpy memory map of the whole file.

    Construction parses the header, the key/value metadata and the tensor info
    records. Results are exposed as ``fields`` (an OrderedDict of ReaderField
    keyed by name) and ``tensors`` (a list of ReaderTensor whose ``data``
    arrays are views into the mapped file rather than copies).
    """

    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    # Alignment of the tensor data section; replaced by `general.alignment` when present.
    alignment: int = GGUF_DEFAULT_ALIGNMENT
    # Absolute byte offset where the (aligned) tensor data section starts.
    data_offset: int

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8:   np.uint8,
        GGUFValueType.INT8:    np.int8,
        GGUFValueType.UINT16:  np.uint16,
        GGUFValueType.INT16:   np.int16,
        GGUFValueType.UINT32:  np.uint32,
        GGUFValueType.INT32:   np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64:  np.uint64,
        GGUFValueType.INT64:   np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL:    np.bool_,
    }

    # Tensor types whose data is exposed with a matching numpy element type;
    # every other tensor type is exposed as raw uint8 bytes.
    # Note: Internal helper, API may change.
    _tensor_type_to_np: dict[GGMLQuantizationType, type[np.generic]] = {
        GGMLQuantizationType.F16: np.float16,
        GGMLQuantizationType.F32: np.float32,
        GGMLQuantizationType.F64: np.float64,
        GGMLQuantizationType.I8:  np.int8,
        GGMLQuantizationType.I16: np.int16,
        GGMLQuantizationType.I32: np.int32,
        GGMLQuantizationType.I64: np.int64,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        """Memory-map ``path`` and parse the GGUF header, metadata and tensor info.

        Args:
            path: File to open.
            mode: Passed through to ``numpy.memmap`` ('r', 'r+' or 'c').

        Raises:
            ValueError: On a bad magic number, an unsupported version, an
                invalid ``general.alignment``, duplicated tensor names, unknown
                value types, or data that runs past the end of the file.
        """
        self.data = np.memmap(path, mode = mode)
        offs = 0

        # Check for GGUF magic (always compared as little-endian).
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4

        # Check GGUF version
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            # Every later read then uses swapped byte order.
            self.byte_order = 'S'
            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        if sys.byteorder == "little":
            # Host is little endian
            host_endian = GGUFEndian.LITTLE
            swapped_endian = GGUFEndian.BIG
        else:
            # Sorry PDP or other weird systems that don't use BE or LE.
            host_endian = GGUFEndian.BIG
            swapped_endian = GGUFEndian.LITTLE
        # Attribute name keeps its historical spelling for compatibility.
        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))

        # Check tensor count and kv count
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        # Plain ints keep later offset/loop arithmetic out of numpy scalar types.
        tensor_count, kv_count = (int(count) for count in temp_counts)
        offs = self._build_fields(offs, kv_count)

        # Build Tensor Info Fields
        offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            # Store as a Python int so the padding arithmetic below (and the
            # resulting data_offset) stays in arbitrary-precision ints rather
            # than numpy uint32 scalars, which can overflow under NumPy 2 rules.
            self.alignment = int(new_align.parts[-1][0])
            if self.alignment == 0:
                raise ValueError('Bad value for general.alignment field: must be non-zero')
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self.data_offset = offs
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        """Return the metadata field named ``key``, or None if absent."""
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        """Return the tensor at position ``idx`` (raises IndexError if out of range)."""
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        """Return a view of ``count`` items of ``dtype`` starting at byte ``offset``.

        The byte order is ``self.byte_order`` unless ``override_order`` is given.

        Raises:
            ValueError: If a non-empty read extends past the end of the file.
        """
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        # Slicing a memmap silently clamps at EOF; fail loudly instead of
        # returning a short array that breaks parsing further down.
        if count > 0 and end_offs > len(self.data):
            raise ValueError(
                f'Unexpected end of file: need {end_offs - offset} bytes at offset {offset}, '
                f'but the file is only {len(self.data)} bytes',
            )
        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        """Register ``field``; return its total byte size (0 when ``skip_sum``).

        A duplicate name is logged and stored under ``<name>_<offset>``.
        """
        if field.name in self.fields:
            # TODO: add option to generate error on duplicate keys
            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')

            logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
            self.fields[f'{field.name}_{field.offset}'] = field
        else:
            self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        """Read a uint64 length prefix followed by that many string bytes."""
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        """Decode one value of type ``raw_type`` starting at ``orig_offs``.

        Returns:
            (byte size consumed, parts, indexes of data-bearing parts, types).

        Raises:
            ValueError: For an unknown or unhandled value type.
        """
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            # FIXME: Handle multi-dimensional arrays properly instead of flattening
            for elem_idx in range(int(alen[0])):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                # Element types are recorded once, from the first element.
                if elem_idx == 0:
                    types += curr_types
                # Element part indexes are relative to the element; shift them
                # to their position within this array's parts list.
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (part_idx + idxs_offs for part_idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
        """Read one tensor info record (name, dims, type, data offset) at ``orig_offs``."""
        offs = orig_offs

        # Get Tensor Name
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)

        # Get Tensor Dimensions Count
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)

        # Get Tensor Dimension Array
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)

        # Get Tensor Encoding Scheme Type
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)

        # Get Tensor Offset
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)

        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        """Read ``count`` key/value metadata entries starting at ``offs``.

        Returns the offset just past the last entry.
        """
        for _ in range(int(count)):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            # Value data indexes are shifted past the key/type parts above.
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        """Read ``count`` tensor info records; return (next offset, records)."""
        tensor_fields = []
        for _ in range(int(count)):
            field = self._get_tensor_info_field(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        """Create ReaderTensor entries for ``fields`` and store them in ``self.tensors``.

        ``start_offs`` is the start of the tensor data section; each record's
        stored offset is relative to it.

        Raises:
            ValueError: If two tensors share a name, or the type is unknown.
        """
        tensors = []
        tensor_names: set[str] = set()  # keep track of name to prevent duplicated tensors
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            # check if there's any tensor having same name already in the list
            tensor_name = str(bytes(name_data), encoding = 'utf-8')
            if tensor_name in tensor_names:
                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
            tensor_names.add(tensor_name)
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = int(np.prod(dims))
            # Stored dims are reversed relative to numpy's shape order.
            np_dims = tuple(reversed(dims.tolist()))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            # Exact integer arithmetic; avoids mixed numpy scalar promotion.
            data_offs = int(start_offs) + int(offset_tensor[0])
            item_type: npt.DTypeLike
            plain_type = self._tensor_type_to_np.get(ggml_type)
            if plain_type is not None:
                item_count = n_elems
                item_type = plain_type
            else:
                # Other types are exposed as raw bytes with a byte-level shape.
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
            tensors.append(ReaderTensor(
                name = tensor_name,
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
                field = field,
            ))
        self.tensors = tensors