#
# GGUF file reading/modification support. For API usage information,
# please see the examples in the scripts/ directory.
#
5from __future__ import annotations
6
7import logging
8import os
9import sys
10from collections import OrderedDict
11from typing import Any, Literal, NamedTuple, TypeVar, Union
12
13import numpy as np
14import numpy.typing as npt
15
16from .quants import quant_shape_to_byte_shape
17
if __name__ == "__main__":
    from pathlib import Path

    # Allow running this file directly as a script from within the package:
    # put the package's parent directory on sys.path so the absolute
    # `gguf.constants` import below resolves.
    sys.path.insert(0, str(Path(__file__).parent.parent))
23
24from gguf.constants import (
25 GGML_QUANT_SIZES,
26 GGUF_DEFAULT_ALIGNMENT,
27 GGUF_MAGIC,
28 GGUF_VERSION,
29 GGMLQuantizationType,
30 GGUFValueType,
31 GGUFEndian,
32)
33
logger = logging.getLogger(__name__)

# GGUF format versions this reader accepts: v2 plus the current GGUF_VERSION.
READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
37
38
class ReaderField(NamedTuple):
    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    types: list[GGUFValueType] = []

    def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
        """Decode this field's value(s) into plain Python objects.

        For array fields, `index_or_slice` selects which element(s) to
        return; the default slice returns all of them. Returns None when
        the field carries no type information.
        """
        if not self.types:
            return None

        def decode_str(part: npt.NDArray[Any]) -> str:
            return str(part.tobytes(), encoding='utf-8')

        main_type = self.types[0]

        if main_type == GGUFValueType.ARRAY:
            sub_type = self.types[-1]

            if sub_type == GGUFValueType.STRING:
                chosen = self.data[index_or_slice]
                if isinstance(index_or_slice, int):
                    return decode_str(self.parts[chosen])  # type: ignore
                return [decode_str(self.parts[idx]) for idx in chosen]  # type: ignore

            # FIXME: When/if _get_field_parts() supports multi-dimensional
            # arrays, this must do so too.
            if isinstance(index_or_slice, int):
                return self.parts[self.data[index_or_slice]].tolist()[0]
            flat: list[Any] = []
            for idx in self.data[index_or_slice]:
                flat.extend(self.parts[idx].tolist())
            return flat

        if main_type == GGUFValueType.STRING:
            return decode_str(self.parts[-1])
        return self.parts[-1].tolist()[0]
98
99
class ReaderTensor(NamedTuple):
    # Tensor name, decoded from file data as UTF-8.
    name: str
    # Quantization/encoding type of the tensor data.
    tensor_type: GGMLQuantizationType
    # Dimensions exactly as stored in the file (GGML order, not numpy order).
    # NOTE: these are read from the file as uint64 (see _get_tensor_info_field).
    shape: npt.NDArray[np.uint64]
    # Total number of elements (product of shape).
    n_elements: int
    # Size of the (possibly quantized) tensor data in bytes.
    n_bytes: int
    # Absolute offset of the tensor data within the file.
    data_offset: int
    # View of the raw tensor data in the memory-mapped file.
    data: npt.NDArray[Any]
    # The tensor-info metadata field this tensor was built from.
    field: ReaderField
109
110
class GGUFReader:
    """Memory-mapped GGUF file reader.

    Parses the header, key/value metadata and tensor info on construction.
    Results are exposed via the ``fields`` OrderedDict and the ``tensors``
    list. Opening with mode 'r+' allows in-place modification of the
    mapped data.
    """

    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    # Alignment of the tensor data section; may be overridden by a
    # `general.alignment` KV field in the file.
    alignment: int = GGUF_DEFAULT_ALIGNMENT
    # Absolute file offset where tensor data begins (set in __init__).
    data_offset: int

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8: np.uint8,
        GGUFValueType.INT8: np.int8,
        GGUFValueType.UINT16: np.uint16,
        GGUFValueType.INT16: np.int16,
        GGUFValueType.UINT32: np.uint32,
        GGUFValueType.INT32: np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64: np.uint64,
        GGUFValueType.INT64: np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL: np.bool_,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        """Map `path` into memory and parse its GGUF structure.

        Raises ValueError for a bad magic number, an unsupported format
        version, or a malformed `general.alignment` field.
        """
        self.data = np.memmap(path, mode = mode)
        offs = 0

        # Check for GGUF magic
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4

        # Check GGUF version
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            # (Real version numbers are small, so the low 16 bits being zero
            # implies the bytes of the version field are swapped.)
            self.byte_order = 'S'
            temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        if sys.byteorder == "little":
            # Host is little endian
            host_endian = GGUFEndian.LITTLE
            swapped_endian = GGUFEndian.BIG
        else:
            # Sorry PDP or other weird systems that don't use BE or LE.
            host_endian = GGUFEndian.BIG
            swapped_endian = GGUFEndian.LITTLE
        # Endianness of the file itself, derived from host order + swap flag.
        self.endianess = swapped_endian if self.byte_order == "S" else host_endian
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))

        # Check tensor count and kv count
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)

        # Build Tensor Info Fields
        offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            self.alignment = new_align.parts[-1][0]
        # Tensor data starts at the next alignment boundary after the
        # tensor-info section; skip any padding.
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self.data_offset = offs
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        """Return a zero-copy view of `count` items of `dtype` at byte `offset`.

        The view's byte order follows `self.byte_order` unless
        `override_order` is given (e.g. '<' to force little-endian for the
        magic check).
        """
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
        return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        """Register `field` in self.fields and return its size in bytes.

        Duplicate keys are kept by appending the field's offset to the key.
        Returns 0 when `skip_sum` is set (caller advances the offset itself).
        """
        if field.name in self.fields:
            # TODO: add option to generate error on duplicate keys
            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')

            logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
            self.fields[field.name + '_{}'.format(field.offset)] = field
        else:
            self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        """Read a GGUF string at `offset`: (uint64 length, raw UTF-8 bytes)."""
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        """Parse one KV value of `raw_type` starting at `orig_offs`.

        Returns (size in bytes, parts, indexes into parts holding the data,
        value types). For arrays this recurses once per element.
        """
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            # FIXME: Handle multi-dimensional arrays properly instead of flattening
            for idx in range(alen[0]):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                # Element types are all the same; record them only once.
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (idx + idxs_offs for idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
        """Parse one tensor-info record (name, dims, dtype, offset) at `orig_offs`."""
        offs = orig_offs

        # Get Tensor Name
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)

        # Get Tensor Dimensions Count
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)

        # Get Tensor Dimension Array
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)

        # Get Tensor Encoding Scheme Type
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)

        # Get Tensor Offset
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)

        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        """Parse `count` KV metadata fields starting at `offs`; return the end offset."""
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        """Parse `count` tensor-info records; return (end offset, fields)."""
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor_info_field(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        """Materialize ReaderTensor views into the memmap from parsed tensor-info fields.

        `start_offs` is the aligned file offset where tensor data begins;
        each tensor's stored offset is relative to it.
        """
        tensors = []
        tensor_names = set() # keep track of name to prevent duplicated tensors
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            # check if there's any tensor having same name already in the list
            tensor_name = str(bytes(name_data), encoding = 'utf-8')
            if tensor_name in tensor_names:
                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
            tensor_names.add(tensor_name)
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = int(np.prod(dims))
            # numpy shape is the reverse of the GGUF/GGML dimension order.
            np_dims = tuple(reversed(dims.tolist()))
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                # Quantized types are exposed as raw bytes; the logical shape
                # is converted to the equivalent byte shape.
                item_count = n_bytes
                item_type = np.uint8
                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
            tensors.append(ReaderTensor(
                name = tensor_name,
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count).reshape(np_dims),
                field = field,
            ))
        self.tensors = tensors