from __future__ import annotations

import logging
import os
import shutil
import struct
import sys
import tempfile
from dataclasses import dataclass
from enum import Enum, auto
from math import prod
from pathlib import Path
from io import BufferedWriter
from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits

import numpy as np

from .constants import (
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFEndian,
    GGUFValueType,
    Keys,
    RopeScalingType,
    PoolingType,
    TokenType,
    ExpertGatingFuncType,
)

from .quants import quant_shape_from_byte_shape

logger = logging.getLogger(__name__)


SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

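# e.g. a 3-shard split of "model.gguf" is named model-00001-of-00003.gguf,
# model-00002-of-00003.gguf and model-00003-of-00003.gguf (see format_shard_names).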

@dataclass
class TensorInfo:
    shape: Sequence[int]
    dtype: GGMLQuantizationType
    nbytes: int
    tensor: np.ndarray[Any, Any] | None = None


@dataclass
class GGUFValue:
    value: Any
    type: GGUFValueType
    sub_type: GGUFValueType | None = None


class WriterState(Enum):
    NO_FILE = auto()
    EMPTY = auto()
    HEADER = auto()
    KV_DATA = auto()
    TI_DATA = auto()
    WEIGHTS = auto()


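# The writer advances through WriterState in order: NO_FILE -> EMPTY (files opened)
# -> HEADER -> KV_DATA -> TI_DATA -> WEIGHTS. Callers are expected to invoke
# write_header_to_file(), write_kv_data_to_file() and write_tensors_to_file() in
# that order; each step validates the current state before writing.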
class GGUFWriter:
    fout: list[BufferedWriter] | None
    path: Path | None
    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
    tensors: list[dict[str, TensorInfo]]
    kv_data: list[dict[str, GGUFValue]]
    state: WriterState
    _simple_value_packing = {
        GGUFValueType.UINT8:   "B",
        GGUFValueType.INT8:    "b",
        GGUFValueType.UINT16:  "H",
        GGUFValueType.INT16:   "h",
        GGUFValueType.UINT32:  "I",
        GGUFValueType.INT32:   "i",
        GGUFValueType.FLOAT32: "f",
        GGUFValueType.UINT64:  "Q",
        GGUFValueType.INT64:   "q",
        GGUFValueType.FLOAT64: "d",
        GGUFValueType.BOOL:    "?",
    }

    def __init__(
        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
    ):
        self.fout = None
        self.path = Path(path) if path else None
        self.arch = arch
        self.endianess = endianess
        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
        self.use_temp_file = use_temp_file
        self.temp_file = None
        self.tensors = [{}]
        self.kv_data = [{}]
        self.split_max_tensors = split_max_tensors
        self.split_max_size = split_max_size
        self.dry_run = dry_run
        self.small_first_shard = small_first_shard
        logger.info("gguf: This GGUF file is for {0} Endian only".format(
            "Big" if self.endianess == GGUFEndian.BIG else "Little",
        ))
        self.state = WriterState.NO_FILE

        if self.small_first_shard:
            self.tensors.append({})

        self.add_architecture()

    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
        total_params = 0
        shared_params = 0
        expert_params = 0

        expert_sum = 0
        n_expert_tensors = 0

        last_lora_a: tuple[str, TensorInfo] | None = None

        for tensors in self.tensors:
            for name, info in tensors.items():

                shape = info.shape

                if name.endswith(".lora_a"):
                    last_lora_a = (name, info)
                    continue
                elif name.endswith(".lora_b"):
                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
                        # Bail when the LoRA pair can't be found trivially
                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
                        return 0, 0, 0, 0
                    else:
                        shape = (*shape[:-1], last_lora_a[1].shape[-1])

                size = prod(shape)

                if "_exps." in name:
                    expert_count = shape[-2 if ".bias" in name else -3]
                    expert_params += (size // expert_count)
                    expert_sum += expert_count
                    n_expert_tensors += 1
                else:
                    shared_params += size

                total_params += size

        # Hopefully this should work even for variable-expert-count models
        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0

        # Negate the total to signal it's likely not exact
        if last_lora_a is not None:
            total_params = -total_params

        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
        return total_params, shared_params, expert_params, expert_count

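    # get_total_parameter_count() returns (total, shared, expert, expert_count),
    # where "expert" counts the parameters of a single expert. For example, a
    # hypothetical 8-expert model with 7B total parameters, 1B of them outside the
    # expert tensors, would yield (7_000_000_000, 1_000_000_000, 750_000_000, 8).
    # A negative total signals the count is approximate (LoRA tensors present).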
    def format_shard_names(self, path: Path) -> list[Path]:
        if len(self.tensors) == 1:
            return [path]
        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]

    def open_output_file(self, path: Path | None = None) -> None:
        if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
            # allow calling this multiple times as long as the path is the same
            return

        if self.state is not WriterState.NO_FILE:
            raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

        if path is not None:
            self.path = path

        if self.path is not None:
            filenames = self.print_plan()
            self.fout = [open(filename, "wb") for filename in filenames]
            self.state = WriterState.EMPTY

    def print_plan(self) -> list[Path]:
        logger.info("Writing the following files:")
        assert self.path is not None
        filenames = self.format_shard_names(self.path)
        assert len(filenames) == len(self.tensors)
        for name, tensors in zip(filenames, self.tensors):
            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")

        if self.dry_run:
            logger.info("Dry run, not writing files")
            for name in filenames:
                print(name)  # noqa: NP100
            exit()

        return filenames

    def add_shard_kv_data(self) -> None:
        if len(self.tensors) == 1:
            return

        total_tensors = sum(len(t) for t in self.tensors)
        assert self.fout is not None
        total_splits = len(self.fout)
        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
        for i, kv_data in enumerate(self.kv_data):
            kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
            kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
            kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)

    def write_header_to_file(self, path: Path | None = None) -> None:
        if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
            logger.warning("Model fails split requirements, not splitting")

        self.open_output_file(path)

        if self.state is not WriterState.EMPTY:
            raise ValueError(f'Expected output file to be empty, got {self.state}')

        assert self.fout is not None
        assert len(self.fout) == len(self.tensors)
        assert len(self.kv_data) == 1

        self.add_shard_kv_data()

        for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
            fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
            fout.write(self._pack("I", GGUF_VERSION))
            fout.write(self._pack("Q", len(tensors)))
            fout.write(self._pack("Q", len(kv_data)))
            fout.flush()
        self.state = WriterState.HEADER

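    # Note on the header written above: each shard starts with the 4-byte GGUF
    # magic (always packed little-endian), then the format version, the shard's
    # tensor count and its KV-pair count, the latter three in the file's endianness.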
    def write_kv_data_to_file(self) -> None:
        if self.state is not WriterState.HEADER:
            raise ValueError(f'Expected output file to contain the header, got {self.state}')
        assert self.fout is not None

        for fout, kv_data in zip(self.fout, self.kv_data):
            kv_bytes = bytearray()

            for key, val in kv_data.items():
                kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
                kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)

            fout.write(kv_bytes)

        self.flush()
        self.state = WriterState.KV_DATA

    def write_ti_data_to_file(self) -> None:
        if self.state is not WriterState.KV_DATA:
            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
        assert self.fout is not None

        for fout, tensors in zip(self.fout, self.tensors):
            ti_data = bytearray()
            offset_tensor = 0

            for name, ti in tensors.items():
                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
                n_dims = len(ti.shape)
                ti_data += self._pack("I", n_dims)
                for j in range(n_dims):
                    ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
                ti_data += self._pack("I", ti.dtype)
                ti_data += self._pack("Q", offset_tensor)
                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)

            fout.write(ti_data)
            fout.flush()
        self.state = WriterState.TI_DATA

    def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
        if any(key in kv_data for kv_data in self.kv_data):
            logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')

        self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)

    def add_uint8(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.UINT8)

    def add_int8(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.INT8)

    def add_uint16(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.UINT16)

    def add_int16(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.INT16)

    def add_uint32(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.UINT32)

    def add_int32(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.INT32)

    def add_float32(self, key: str, val: float) -> None:
        self.add_key_value(key, val, GGUFValueType.FLOAT32)

    def add_uint64(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.UINT64)

    def add_int64(self, key: str, val: int) -> None:
        self.add_key_value(key, val, GGUFValueType.INT64)

    def add_float64(self, key: str, val: float) -> None:
        self.add_key_value(key, val, GGUFValueType.FLOAT64)

    def add_bool(self, key: str, val: bool) -> None:
        self.add_key_value(key, val, GGUFValueType.BOOL)

    def add_string(self, key: str, val: str) -> None:
        if not val:
            return
        self.add_key_value(key, val, GGUFValueType.STRING)

    def add_array(self, key: str, val: Sequence[Any]) -> None:
        if len(val) == 0:
            return
        self.add_key_value(key, val, GGUFValueType.ARRAY)

    @staticmethod
    def ggml_pad(x: int, n: int) -> int:
        return ((x + n - 1) // n) * n

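    # ggml_pad() rounds up to the next multiple of the alignment, e.g.
    # ggml_pad(1, 32) == 32, ggml_pad(32, 32) == 32, ggml_pad(33, 32) == 64.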
    def add_tensor_info(
        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
        tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
    ) -> None:
        if self.state is not WriterState.NO_FILE:
            raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

        if any(name in tensors for tensors in self.tensors):
            raise ValueError(f'Duplicated tensor name {name!r}')

        if raw_dtype is None:
            if tensor_dtype == np.float16:
                dtype = GGMLQuantizationType.F16
            elif tensor_dtype == np.float32:
                dtype = GGMLQuantizationType.F32
            elif tensor_dtype == np.float64:
                dtype = GGMLQuantizationType.F64
            elif tensor_dtype == np.int8:
                dtype = GGMLQuantizationType.I8
            elif tensor_dtype == np.int16:
                dtype = GGMLQuantizationType.I16
            elif tensor_dtype == np.int32:
                dtype = GGMLQuantizationType.I32
            elif tensor_dtype == np.int64:
                dtype = GGMLQuantizationType.I64
            else:
                raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
        else:
            dtype = raw_dtype
            if tensor_dtype == np.uint8:
                tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)

        # make sure there is at least one tensor before splitting
        if len(self.tensors[-1]) > 0:
            if (  # split when over tensor limit
                self.split_max_tensors != 0
                and len(self.tensors[-1]) >= self.split_max_tensors
            ) or (  # split when over size limit
                self.split_max_size != 0
                and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
            ):
                self.tensors.append({})

        self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)

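    # Splitting note: with e.g. split_max_tensors=128 a new shard is started once
    # the current one already holds 128 tensors; with e.g. split_max_size=2**31 a
    # new shard is started when the incoming tensor would push the current shard
    # past 2 GiB. The guard above ensures a shard is never left empty.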
    def add_tensor(
        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
        raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
    ) -> None:
        # if tensor endianness is not passed, assume it's native to system
        if tensor_endianess is None:
            tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE

        if tensor_endianess != self.endianess:
            # Don't byteswap inplace since lazy copies cannot handle it
            tensor = tensor.byteswap(inplace=False)
        if self.use_temp_file and self.temp_file is None:
            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
            fp.seek(0)
            self.temp_file = fp

        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)

        if self.temp_file is None:
            self.tensors[-1][name].tensor = tensor
            return

        tensor.tofile(self.temp_file)
        self.write_padding(self.temp_file, tensor.nbytes)

    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
        if pad != 0:
            fp.write(bytes([0] * pad))

    def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
        if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
            raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
        assert self.fout is not None

        # if tensor endianness is not passed, assume it's native to system
        if tensor_endianess is None:
            tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE

        if tensor_endianess != self.endianess:
            # Don't byteswap inplace since lazy copies cannot handle it
            tensor = tensor.byteswap(inplace=False)

        file_id = -1
        for i, tensors in enumerate(self.tensors):
            if len(tensors) > 0:
                file_id = i
                break

        fout = self.fout[file_id]

        # pop the first tensor info
        # TODO: cleaner way to get the first key
        first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
        ti = self.tensors[file_id].pop(first_tensor_name)
        assert ti.nbytes == tensor.nbytes

        self.write_padding(fout, fout.tell())
        tensor.tofile(fout)
        self.write_padding(fout, tensor.nbytes)

        self.state = WriterState.WEIGHTS

    def write_tensors_to_file(self, *, progress: bool = False) -> None:
        self.write_ti_data_to_file()

        assert self.fout is not None

        for fout in self.fout:
            self.write_padding(fout, fout.tell())

        if self.temp_file is None:
            shard_bar = None
            bar = None

            if progress:
                from tqdm import tqdm

                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

                if len(self.fout) > 1:
                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
                if shard_bar is not None:
                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
                    total = sum(ti.nbytes for ti in tensors.values())
                    shard_bar.reset(total=(total if total > 0 else None))

                # relying on the fact that Python dicts preserve insertion order (since 3.7)
                for ti in tensors.values():
                    assert ti.tensor is not None  # can only iterate once over the tensors
                    assert ti.tensor.nbytes == ti.nbytes
                    ti.tensor.tofile(fout)
                    if shard_bar is not None:
                        shard_bar.update(ti.nbytes)
                    if bar is not None:
                        bar.update(ti.nbytes)
                    self.write_padding(fout, ti.nbytes)
                    ti.tensor = None
        else:
            self.temp_file.seek(0)

            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
            self.flush()
            self.temp_file.close()

        self.state = WriterState.WEIGHTS

    def flush(self) -> None:
        assert self.fout is not None
        for fout in self.fout:
            fout.flush()

    def close(self) -> None:
        if self.fout is not None:
            for fout in self.fout:
                fout.close()
            self.fout = None

    def add_type(self, type_name: str) -> None:
        self.add_string(Keys.General.TYPE, type_name)

    def add_architecture(self) -> None:
        self.add_string(Keys.General.ARCHITECTURE, self.arch)

    def add_quantization_version(self, quantization_version: int) -> None:
        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)

    def add_custom_alignment(self, alignment: int) -> None:
        self.data_alignment = alignment
        self.add_uint32(Keys.General.ALIGNMENT, alignment)

    def add_file_type(self, ftype: int) -> None:
        self.add_uint32(Keys.General.FILE_TYPE, ftype)

    def add_sampling_sequence(self, sequence: str) -> None:
        self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)

    def add_sampling_top_k(self, top_k: int) -> None:
        self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)

    def add_sampling_top_p(self, top_p: float) -> None:
        self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)

    def add_sampling_min_p(self, min_p: float) -> None:
        self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)

    def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
        self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)

    def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
        self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)

    def add_sampling_temp(self, temp: float) -> None:
        self.add_float32(Keys.General.SAMPLING_TEMP, temp)

    def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
        self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)

    def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
        self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)

    def add_sampling_mirostat(self, mirostat: int) -> None:
        self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)

    def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
        self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)

    def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
        self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)

    def add_name(self, name: str) -> None:
        self.add_string(Keys.General.NAME, name)

    def add_author(self, author: str) -> None:
        self.add_string(Keys.General.AUTHOR, author)

    def add_version(self, version: str) -> None:
        self.add_string(Keys.General.VERSION, version)

    def add_organization(self, organization: str) -> None:
        self.add_string(Keys.General.ORGANIZATION, organization)

    def add_finetune(self, finetune: str) -> None:
        self.add_string(Keys.General.FINETUNE, finetune)

    def add_basename(self, basename: str) -> None:
        self.add_string(Keys.General.BASENAME, basename)

    def add_description(self, description: str) -> None:
        self.add_string(Keys.General.DESCRIPTION, description)

    def add_quantized_by(self, quantized: str) -> None:
        self.add_string(Keys.General.QUANTIZED_BY, quantized)

    def add_size_label(self, size_label: str) -> None:
        self.add_string(Keys.General.SIZE_LABEL, size_label)

    def add_license(self, license: str) -> None:
        self.add_string(Keys.General.LICENSE, license)

    def add_license_name(self, license: str) -> None:
        self.add_string(Keys.General.LICENSE_NAME, license)

    def add_license_link(self, license: str) -> None:
        self.add_string(Keys.General.LICENSE_LINK, license)

    def add_url(self, url: str) -> None:
        self.add_string(Keys.General.URL, url)

    def add_doi(self, doi: str) -> None:
        self.add_string(Keys.General.DOI, doi)

    def add_uuid(self, uuid: str) -> None:
        self.add_string(Keys.General.UUID, uuid)

    def add_repo_url(self, repo_url: str) -> None:
        self.add_string(Keys.General.REPO_URL, repo_url)

    def add_source_url(self, url: str) -> None:
        self.add_string(Keys.General.SOURCE_URL, url)

    def add_source_doi(self, doi: str) -> None:
        self.add_string(Keys.General.SOURCE_DOI, doi)

    def add_source_uuid(self, uuid: str) -> None:
        self.add_string(Keys.General.SOURCE_UUID, uuid)

    def add_source_repo_url(self, repo_url: str) -> None:
        self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)

    def add_base_model_count(self, source_count: int) -> None:
        self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)

    def add_base_model_name(self, source_id: int, name: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)

    def add_base_model_author(self, source_id: int, author: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)

    def add_base_model_version(self, source_id: int, version: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)

    def add_base_model_organization(self, source_id: int, organization: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)

    def add_base_model_description(self, source_id: int, description: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)

    def add_base_model_url(self, source_id: int, url: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)

    def add_base_model_doi(self, source_id: int, doi: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)

    def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)

    def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)

    def add_dataset_count(self, source_count: int) -> None:
        self.add_uint32(Keys.General.DATASET_COUNT, source_count)

    def add_dataset_name(self, source_id: int, name: str) -> None:
        self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)

    def add_dataset_author(self, source_id: int, author: str) -> None:
        self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)

    def add_dataset_version(self, source_id: int, version: str) -> None:
        self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)

    def add_dataset_organization(self, source_id: int, organization: str) -> None:
        self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)

    def add_dataset_description(self, source_id: int, description: str) -> None:
        self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)

    def add_dataset_url(self, source_id: int, url: str) -> None:
        self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)

    def add_dataset_doi(self, source_id: int, doi: str) -> None:
        self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)

    def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
        self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)

    def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
        self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)

    def add_tags(self, tags: Sequence[str]) -> None:
        self.add_array(Keys.General.TAGS, tags)

    def add_languages(self, languages: Sequence[str]) -> None:
        self.add_array(Keys.General.LANGUAGES, languages)

    def add_tensor_data_layout(self, layout: str) -> None:
        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

    def add_vocab_size(self, size: int) -> None:
        self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

    def add_context_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)

    def add_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_embedding_length_out(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)

    def add_features_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)

    def add_posnet_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_posnet_block_count(self, length: int) -> None:
        self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)

    def add_convnext_embedding_length(self, length: int) -> None:
        self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_convnext_block_count(self, length: int) -> None:
        self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)

    def add_shortconv_l_cache(self, length: int) -> None:
        self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)

    def add_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

    def add_leading_dense_block_count(self, length: int) -> None:
        self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)

    def add_full_attention_interval(self, interval: int) -> None:
        self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)

    def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
        if isinstance(length, int):
            self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
        else:
            self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_shared_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_chunk_feed_forward_length(self, length: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool) -> None:
        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

    def add_decoder_start_token_id(self, id: int) -> None:
        self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

    def add_decoder_block_count(self, value: int) -> None:
        self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)

    def add_embedding_length_per_layer_input(self, value: int) -> None:
        self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

    def add_altup_active_idx(self, val: int) -> None:
        self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)

    def add_altup_num_inputs(self, val: int) -> None:
        self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)

    def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
        self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)

    def add_head_count(self, count: int | Sequence[int]) -> None:
        if isinstance(count, int):
            self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
        else:
            self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)

    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
        if isinstance(count, int):
            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
        else:
            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)

    def add_key_length(self, length: int) -> None:
        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)

    def add_value_length(self, length: int) -> None:
        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)

    def add_key_length_mla(self, length: int) -> None:
        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)

    def add_value_length_mla(self, length: int) -> None:
        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)

    def add_max_alibi_bias(self, bias: float) -> None:
        self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)

    def add_clamp_kqv(self, value: float) -> None:
        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

    def add_shared_kv_layers(self, value: int) -> None:
        self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)

    def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
        key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
        if isinstance(value, int):
            self.add_uint32(key, value)
        else:
            self.add_array(key, value)

    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)

    def add_logit_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

    def add_attn_logit_softcapping(self, value: float) -> None:
        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

    def add_router_logit_softcapping(self, value: float) -> None:
        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

    def add_final_logit_softcapping(self, value: float) -> None:
        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

    def add_expert_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

    def add_expert_used_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)

    def add_expert_shared_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)

    def add_expert_group_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)

    def add_expert_group_used_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)

    def add_expert_weights_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

    def add_expert_weights_norm(self, value: bool) -> None:
        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)

    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)

    def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None:
        self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values)

    def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None:
        self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values)

    def add_expert_group_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)

    def add_experts_per_group(self, count: int) -> None:
        self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)

    def add_moe_every_n_layers(self, value: int) -> None:
        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)

    def add_nextn_predict_layers(self, count: int) -> None:
        self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)

    def add_swin_norm(self, value: bool) -> None:
        self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)

    def add_rescale_every_n_layers(self, count: int) -> None:
        self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)

    def add_time_mix_extra_dim(self, dim: int) -> None:
        self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)

    def add_time_decay_extra_dim(self, dim: int) -> None:
        self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)

    def add_residual_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)

    def add_embedding_scale(self, value: float) -> None:
        self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)

    def add_wkv_head_size(self, size: int) -> None:
        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)

    def add_token_shift_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)

    def add_interleave_moe_layer_step(self, value: int) -> None:
        self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)

    def add_layer_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)

    def add_layer_norm_rms_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

    def add_group_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)

    def add_group_norm_groups(self, value: int) -> None:
        self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)

    def add_causal_attention(self, value: bool) -> None:
        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)

    def add_q_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)

    def add_kv_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)

    def add_decay_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)

    def add_iclr_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)

    def add_value_residual_mix_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)

    def add_rope_freq_base_swa(self, value: float) -> None:
        self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)

    def add_gate_lora_rank(self, length: int) -> None:
        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)

    def add_relative_attn_buckets_count(self, value: int) -> None:
        self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)

    def add_sliding_window(self, value: int) -> None:
        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

    def add_attention_scale(self, value: float) -> None:
        self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)

    def add_attn_output_scale(self, value: float) -> None:
        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)

    def add_attn_temperature_length(self, value: int) -> None:
        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)

    def add_attn_temperature_scale(self, value: float) -> None:
        self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)

    def add_pooling_type(self, value: PoolingType) -> None:
        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

    def add_num_deepstack_layers(self, count: int) -> None:
        self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)

    def add_rope_dimension_count(self, count: int) -> None:
        self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

    def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
        self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)

    def add_rope_freq_base(self, value: float) -> None:
        self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

    def add_rope_scaling_type(self, value: RopeScalingType) -> None:
        self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)

    def add_rope_scaling_factor(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)

    def add_rope_scaling_attn_factors(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)

    def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
        self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)

    def add_rope_scaling_finetuned(self, value: bool) -> None:
        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

    def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)

    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)

    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)

    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)

    def add_ssm_conv_kernel(self, value: int) -> None:
        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

    def add_ssm_inner_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)

    def add_ssm_state_size(self, value: int) -> None:
        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)

    def add_ssm_time_step_rank(self, value: int) -> None:
        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

    def add_ssm_group_count(self, value: int) -> None:
        self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)

    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

    def add_kda_head_dim(self, value: int) -> None:
        self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)

    def add_tokenizer_model(self, model: str) -> None:
        self.add_string(Keys.Tokenizer.MODEL, model)

    def add_tokenizer_pre(self, pre: str) -> None:
        self.add_string(Keys.Tokenizer.PRE, pre)

    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
        self.add_array(Keys.Tokenizer.LIST, tokens)

    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
        self.add_array(Keys.Tokenizer.MERGES, merges)

    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
        self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

    def add_token_type_count(self, value: int) -> None:
        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)

    def add_token_scores(self, scores: Sequence[float]) -> None:
        self.add_array(Keys.Tokenizer.SCORES, scores)

    def add_bos_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.BOS_ID, id)

    def add_eos_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.EOS_ID, id)

    def add_unk_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.UNK_ID, id)

    def add_sep_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.SEP_ID, id)

    def add_pad_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.PAD_ID, id)

    def add_mask_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.MASK_ID, id)

    def add_add_bos_token(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_BOS, value)

    def add_add_eos_token(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_EOS, value)

    def add_add_sep_token(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_SEP, value)

    def add_add_space_prefix(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

    def add_remove_extra_whitespaces(self, value: bool) -> None:
        self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)

    def add_precompiled_charsmap(self, charsmap: bytes) -> None:
        self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)

    def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
        if not isinstance(value, str):
            template_default = None
            template_names = set()

            for choice in value:
                name = choice.get('name', '')
                template = choice.get('template')

                # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
                name = ''.join((c if c in ascii_letters + digits else '_' for c in name))

                if name and template is not None:
                    if name == 'default':
                        template_default = template
                    else:
                        template_names.add(name)
                        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)

            if template_names:
                self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))

            if template_default is None:
                return

            value = template_default

        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

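    # add_chat_template() also accepts a list of named templates, e.g.
    # [{"name": "default", "template": "..."}, {"name": "rag", "template": "..."}]
    # (illustrative names). The "default" entry is stored under
    # Keys.Tokenizer.CHAT_TEMPLATE, the rest under Keys.Tokenizer.CHAT_TEMPLATE_N,
    # and their names are listed in Keys.Tokenizer.CHAT_TEMPLATES.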
    def add_eot_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.EOT_ID, id)

    def add_eom_token_id(self, id: int) -> None:
        self.add_uint32(Keys.Tokenizer.EOM_ID, id)

    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)

    # for vision models

    def add_clip_has_vision_encoder(self, value: bool) -> None:
        self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)

    def add_clip_has_audio_encoder(self, value: bool) -> None:
        self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)

    def add_clip_projector_type(self, value: str) -> None:
        self.add_string(Keys.Clip.PROJECTOR_TYPE, value)

    def add_clip_vision_projector_type(self, value: str) -> None:
        self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)

    def add_vision_projection_dim(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)

    def add_vision_patch_size(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)

    def add_vision_embedding_length(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)

    def add_vision_feed_forward_length(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)

    def add_vision_block_count(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)

    def add_vision_head_count(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)

    def add_vision_attention_layernorm_eps(self, value: float) -> None:
        self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)

    def add_vision_image_size(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)

    def add_vision_max_pixels(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value)

    def add_vision_min_pixels(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)

    def add_vision_preproc_image_size(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)

    def add_vision_image_mean(self, values: Sequence[float]) -> None:
        self.add_array(Keys.ClipVision.IMAGE_MEAN, values)

    def add_vision_image_std(self, values: Sequence[float]) -> None:
        self.add_array(Keys.ClipVision.IMAGE_STD, values)

    def add_vision_spatial_merge_size(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)

    def add_vision_use_gelu(self, value: bool) -> None:
        self.add_bool(Keys.ClipVision.USE_GELU, value)

    def add_vision_use_silu(self, value: bool) -> None:
        self.add_bool(Keys.ClipVision.USE_SILU, value)

    def add_vision_projector_scale_factor(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)

    def add_vision_n_wa_pattern(self, value: int) -> None:
        """Add window attention pattern interval for vision models.

        This defines the pattern interval for window attention vs full attention layers.
        For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
        while other layers use window attention.

        Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
        """
        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)

    def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
        """Add explicit layer indexes that use full attention in vision models.

        This specifies the exact layer indices (0-based) that should use full attention
        instead of window attention. All other layers will use window attention.

        Args:
            layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])

        Used by models like YoutuVL where full attention layers are explicitly specified
        rather than following a regular pattern.

        Difference from add_vision_n_wa_pattern:
        - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
        - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
        """
        self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)

    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)

    def add_vision_window_size(self, value: int) -> None:
        self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)

    # audio models

    def add_clip_audio_projector_type(self, value: str) -> None:
        self.add_string(Keys.ClipAudio.PROJECTOR_TYPE, value)

    def add_audio_projection_dim(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)

    def add_audio_embedding_length(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)

    def add_audio_feed_forward_length(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)

    def add_audio_block_count(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)

    def add_audio_head_count(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)

    def add_audio_attention_layernorm_eps(self, value: float) -> None:
        self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)

    def add_audio_num_mel_bins(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)

    def add_audio_stack_factor(self, value: int) -> None:
        self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)

    def add_xielu_alpha_p(self, values: Sequence[float]):
        self.add_array(Keys.xIELU.ALPHA_P, values)

    def add_xielu_alpha_n(self, values: Sequence[float]):
        self.add_array(Keys.xIELU.ALPHA_N, values)

    def add_xielu_beta(self, values: Sequence[float]):
        self.add_array(Keys.xIELU.BETA, values)

    def add_xielu_eps(self, values: Sequence[float]):
        self.add_array(Keys.xIELU.EPS, values)

    # diffusion models

    def add_diffusion_shift_logits(self, value: bool) -> None:
        self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value)

    def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
        pack_prefix = ''
        if not skip_pack_prefix:
            pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
        return struct.pack(f'{pack_prefix}{fmt}', value)

    def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
        kv_data = bytearray()

        if add_vtype:
            kv_data += self._pack("I", vtype)

        pack_fmt = self._simple_value_packing.get(vtype)
        if pack_fmt is not None:
            kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
        elif vtype == GGUFValueType.STRING:
            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
            kv_data += self._pack("Q", len(encoded_val))
            kv_data += encoded_val
        elif vtype == GGUFValueType.ARRAY:

            if not isinstance(val, Sequence):
                raise ValueError("Invalid GGUF metadata array, expecting sequence")

            if len(val) == 0:
                raise ValueError("Invalid GGUF metadata array. Empty array")

            if sub_type is not None:
                ltype = sub_type
            elif isinstance(val, bytes):
                ltype = GGUFValueType.UINT8
            else:
                ltype = GGUFValueType.get_type(val[0])
                if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
                    raise ValueError("All items in a GGUF array should be of the same type")
            kv_data += self._pack("I", ltype)
            kv_data += self._pack("Q", len(val))
            for item in val:
                kv_data += self._pack_val(item, ltype, add_vtype=False)
        else:
            raise ValueError("Invalid GGUF metadata value type or value")

        return kv_data

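    # Serialization note for _pack_val(): a STRING is a uint64 byte length followed
    # by the UTF-8 bytes; an ARRAY is the element type (uint32), the element count
    # (uint64), then each element packed recursively without its own type tag.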
    @staticmethod
    def format_n_bytes_to_str(num: int) -> str:
        if num == 0:
            return "negligible - metadata only"
        fnum = float(num)
        for unit in ("", "K", "M", "G"):
            if abs(fnum) < 1000.0:
                return f"{fnum:3.1f}{unit}"
            fnum /= 1000.0
        return f"{fnum:.1f}T - over 1TB, split recommended"
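

# Minimal usage sketch (illustrative only; the file name, arch and tensor values
# below are made up). It follows the state order documented above and only runs
# when this module is executed directly, e.g. via `python -m gguf.gguf_writer`.
if __name__ == "__main__":
    writer = GGUFWriter("example.gguf", arch="llama")
    writer.add_block_count(1)  # arbitrary example metadata
    writer.add_tensor("tensor_0", np.zeros((4, 4), dtype=np.float32))  # one F32 tensor
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()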