1from __future__ import annotations
   2
   3import logging
   4import os
   5import shutil
   6import struct
   7import sys
   8import tempfile
   9from dataclasses import dataclass
  10from enum import Enum, auto
  11from math import prod
  12from pathlib import Path
  13from io import BufferedWriter
  14from typing import IO, Any, Sequence, Mapping
  15from string import ascii_letters, digits
  16
  17import numpy as np
  18
  19from .constants import (
  20    GGUF_DEFAULT_ALIGNMENT,
  21    GGUF_MAGIC,
  22    GGUF_VERSION,
  23    GGMLQuantizationType,
  24    GGUFEndian,
  25    GGUFValueType,
  26    Keys,
  27    RopeScalingType,
  28    PoolingType,
  29    TokenType,
  30    ExpertGatingFuncType,
  31)
  32
  33from .quants import quant_shape_from_byte_shape
  34
logger = logging.getLogger(__name__)


# Shard file name pattern used when a model is split: "<stem>-<shard no>-of-<shard count>.gguf",
# with the 1-based shard number and the shard count zero-padded to 5 digits.
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
  39
  40
@dataclass
class TensorInfo:
    """Bookkeeping for one tensor registered with ``GGUFWriter``."""
    # Shape as registered: ``raw_shape``/the array's shape, possibly converted by
    # ``quant_shape_from_byte_shape`` for uint8-backed quantized data.
    shape: Sequence[int]
    # GGML type written into the tensor-info section.
    dtype: GGMLQuantizationType
    # Size of the tensor data in bytes, before alignment padding.
    nbytes: int
    # In-memory data waiting to be written; stays None when data is spooled to a
    # temp file, and is reset to None once written by write_tensors_to_file().
    tensor: np.ndarray[Any, Any] | None = None
  47
  48
@dataclass
class GGUFValue:
    """A metadata value paired with the GGUF value type it is serialized as."""
    value: Any
    type: GGUFValueType
    # Forwarded to ``_pack_val(sub_type=...)``; presumably the element type of an
    # ARRAY value — TODO confirm against _pack_val.
    sub_type: GGUFValueType | None = None
  54
  55
class WriterState(Enum):
    """Progress of a ``GGUFWriter`` through its output; each write step checks and advances it."""
    NO_FILE = auto()  # output file(s) not opened; tensor infos can only be added in this state
    EMPTY   = auto()  # file(s) opened, nothing written yet
    HEADER  = auto()  # magic, version, tensor count and KV count written
    KV_DATA = auto()  # metadata key/value section written
    TI_DATA = auto()  # tensor-info section written
    WEIGHTS = auto()  # tensor data written (fully, or partly via write_tensor_data)
  63
  64
  65class GGUFWriter:
    # One open binary file per shard; None before open_output_file() and after close().
    fout: list[BufferedWriter] | None
    # Output path (unsplit) or base path for shard names; None until set.
    path: Path | None
    # Spool for tensor data when use_temp_file is enabled (created lazily by add_tensor).
    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
    # Tensor infos per shard, in insertion order, which is also the data write order.
    tensors: list[dict[str, TensorInfo]]
    # Metadata per shard; add_key_value() always stores into the first dict.
    kv_data: list[dict[str, GGUFValue]]
    state: WriterState
    # struct format characters for the fixed-size GGUF scalar types
    # (presumably consumed by _pack_val, which is outside this view).
    _simple_value_packing = {
        GGUFValueType.UINT8:   "B",
        GGUFValueType.INT8:    "b",
        GGUFValueType.UINT16:  "H",
        GGUFValueType.INT16:   "h",
        GGUFValueType.UINT32:  "I",
        GGUFValueType.INT32:   "i",
        GGUFValueType.FLOAT32: "f",
        GGUFValueType.UINT64:  "Q",
        GGUFValueType.INT64:   "q",
        GGUFValueType.FLOAT64: "d",
        GGUFValueType.BOOL:    "?",
    }
  85
  86    def __init__(
  87        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
  88        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
  89    ):
  90        self.fout = None
  91        self.path = Path(path) if path else None
  92        self.arch = arch
  93        self.endianess = endianess
  94        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
  95        self.use_temp_file = use_temp_file
  96        self.temp_file = None
  97        self.tensors = [{}]
  98        self.kv_data = [{}]
  99        self.split_max_tensors = split_max_tensors
 100        self.split_max_size = split_max_size
 101        self.dry_run = dry_run
 102        self.small_first_shard = small_first_shard
 103        logger.info("gguf: This GGUF file is for {0} Endian only".format(
 104            "Big" if self.endianess == GGUFEndian.BIG else "Little",
 105        ))
 106        self.state = WriterState.NO_FILE
 107
 108        if self.small_first_shard:
 109            self.tensors.append({})
 110
 111        self.add_architecture()
 112
 113    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
 114        total_params = 0
 115        shared_params = 0
 116        expert_params = 0
 117
 118        expert_sum = 0
 119        n_expert_tensors = 0
 120
 121        last_lora_a: tuple[str, TensorInfo] | None = None
 122
 123        for tensors in self.tensors:
 124            for name, info in tensors.items():
 125
 126                shape = info.shape
 127
 128                if name.endswith(".lora_a"):
 129                    last_lora_a = (name, info)
 130                    continue
 131                elif name.endswith(".lora_b"):
 132                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
 133                        # Bail when the LoRA pair can't be found trivially
 134                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
 135                        return 0, 0, 0, 0
 136                    else:
 137                        shape = (*shape[:-1], last_lora_a[1].shape[-1])
 138
 139                size = prod(shape)
 140
 141                if "_exps." in name:
 142                    expert_count = shape[-2 if ".bias" in name else -3]
 143                    expert_params += (size // expert_count)
 144                    expert_sum += expert_count
 145                    n_expert_tensors += 1
 146                else:
 147                    shared_params += size
 148
 149                total_params += size
 150
 151        # Hopefully this should work even for variable-expert-count models
 152        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
 153
 154        # Negate the total to signal it's likely not exact
 155        if last_lora_a is not None:
 156            total_params = -total_params
 157
 158        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
 159        return total_params, shared_params, expert_params, expert_count
 160
 161    def format_shard_names(self, path: Path) -> list[Path]:
 162        if len(self.tensors) == 1:
 163            return [path]
 164        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
 165
 166    def open_output_file(self, path: Path | None = None) -> None:
 167        if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
 168            # allow calling this multiple times as long as the path is the same
 169            return
 170
 171        if self.state is not WriterState.NO_FILE:
 172            raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 173
 174        if path is not None:
 175            self.path = path
 176
 177        if self.path is not None:
 178            filenames = self.print_plan()
 179            self.fout = [open(filename, "wb") for filename in filenames]
 180            self.state = WriterState.EMPTY
 181
 182    def print_plan(self) -> list[Path]:
 183        logger.info("Writing the following files:")
 184        assert self.path is not None
 185        filenames = self.format_shard_names(self.path)
 186        assert len(filenames) == len(self.tensors)
 187        for name, tensors in zip(filenames, self.tensors):
 188            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
 189
 190        if self.dry_run:
 191            logger.info("Dry run, not writing files")
 192            for name in filenames:
 193                print(name)  # noqa: NP100
 194            exit()
 195
 196        return filenames
 197
 198    def add_shard_kv_data(self) -> None:
 199        if len(self.tensors) == 1:
 200            return
 201
 202        total_tensors = sum(len(t) for t in self.tensors)
 203        assert self.fout is not None
 204        total_splits = len(self.fout)
 205        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
 206        for i, kv_data in enumerate(self.kv_data):
 207            kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
 208            kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
 209            kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
 210
 211    def write_header_to_file(self, path: Path | None = None) -> None:
 212        if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
 213            logger.warning("Model fails split requirements, not splitting")
 214
 215        self.open_output_file(path)
 216
 217        if self.state is not WriterState.EMPTY:
 218            raise ValueError(f'Expected output file to be empty, got {self.state}')
 219
 220        assert self.fout is not None
 221        assert len(self.fout) == len(self.tensors)
 222        assert len(self.kv_data) == 1
 223
 224        self.add_shard_kv_data()
 225
 226        for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
 227            fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
 228            fout.write(self._pack("I", GGUF_VERSION))
 229            fout.write(self._pack("Q", len(tensors)))
 230            fout.write(self._pack("Q", len(kv_data)))
 231            fout.flush()
 232        self.state = WriterState.HEADER
 233
 234    def write_kv_data_to_file(self) -> None:
 235        if self.state is not WriterState.HEADER:
 236            raise ValueError(f'Expected output file to contain the header, got {self.state}')
 237        assert self.fout is not None
 238
 239        for fout, kv_data in zip(self.fout, self.kv_data):
 240            kv_bytes = bytearray()
 241
 242            for key, val in kv_data.items():
 243                kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
 244                kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
 245
 246            fout.write(kv_bytes)
 247
 248        self.flush()
 249        self.state = WriterState.KV_DATA
 250
 251    def write_ti_data_to_file(self) -> None:
 252        if self.state is not WriterState.KV_DATA:
 253            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
 254        assert self.fout is not None
 255
 256        for fout, tensors in zip(self.fout, self.tensors):
 257            ti_data = bytearray()
 258            offset_tensor = 0
 259
 260            for name, ti in tensors.items():
 261                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
 262                n_dims = len(ti.shape)
 263                ti_data += self._pack("I", n_dims)
 264                for j in range(n_dims):
 265                    ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
 266                ti_data += self._pack("I", ti.dtype)
 267                ti_data += self._pack("Q", offset_tensor)
 268                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
 269
 270            fout.write(ti_data)
 271            fout.flush()
 272        self.state = WriterState.TI_DATA
 273
 274    def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
 275        if any(key in kv_data for kv_data in self.kv_data):
 276            logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
 277
 278        self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
 279
 280    def add_uint8(self, key: str, val: int) -> None:
 281        self.add_key_value(key,val, GGUFValueType.UINT8)
 282
    # Typed scalar setters: each forwards to add_key_value with a fixed GGUF value type.

    def add_int8(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF INT8 value."""
        self.add_key_value(key, val, GGUFValueType.INT8)

    def add_uint16(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF UINT16 value."""
        self.add_key_value(key, val, GGUFValueType.UINT16)

    def add_int16(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF INT16 value."""
        self.add_key_value(key, val, GGUFValueType.INT16)

    def add_uint32(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF UINT32 value."""
        self.add_key_value(key, val, GGUFValueType.UINT32)

    def add_int32(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF INT32 value."""
        self.add_key_value(key, val, GGUFValueType.INT32)

    def add_float32(self, key: str, val: float) -> None:
        """Store ``val`` under ``key`` as a GGUF FLOAT32 value."""
        self.add_key_value(key, val, GGUFValueType.FLOAT32)

    def add_uint64(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF UINT64 value."""
        self.add_key_value(key, val, GGUFValueType.UINT64)

    def add_int64(self, key: str, val: int) -> None:
        """Store ``val`` under ``key`` as a GGUF INT64 value."""
        self.add_key_value(key, val, GGUFValueType.INT64)

    def add_float64(self, key: str, val: float) -> None:
        """Store ``val`` under ``key`` as a GGUF FLOAT64 value."""
        self.add_key_value(key, val, GGUFValueType.FLOAT64)

    def add_bool(self, key: str, val: bool) -> None:
        """Store ``val`` under ``key`` as a GGUF BOOL value."""
        self.add_key_value(key, val, GGUFValueType.BOOL)
 312
 313    def add_string(self, key: str, val: str) -> None:
 314        if not val:
 315            return
 316        self.add_key_value(key, val, GGUFValueType.STRING)
 317
 318    def add_array(self, key: str, val: Sequence[Any]) -> None:
 319        if len(val) == 0:
 320            return
 321        self.add_key_value(key, val, GGUFValueType.ARRAY)
 322
 323    @staticmethod
 324    def ggml_pad(x: int, n: int) -> int:
 325        return ((x + n - 1) // n) * n
 326
 327    def add_tensor_info(
 328        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
 329        tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
 330    ) -> None:
 331        if self.state is not WriterState.NO_FILE:
 332            raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 333
 334        if any(name in tensors for tensors in self.tensors):
 335            raise ValueError(f'Duplicated tensor name {name!r}')
 336
 337        if raw_dtype is None:
 338            if tensor_dtype == np.float16:
 339                dtype = GGMLQuantizationType.F16
 340            elif tensor_dtype == np.float32:
 341                dtype = GGMLQuantizationType.F32
 342            elif tensor_dtype == np.float64:
 343                dtype = GGMLQuantizationType.F64
 344            elif tensor_dtype == np.int8:
 345                dtype = GGMLQuantizationType.I8
 346            elif tensor_dtype == np.int16:
 347                dtype = GGMLQuantizationType.I16
 348            elif tensor_dtype == np.int32:
 349                dtype = GGMLQuantizationType.I32
 350            elif tensor_dtype == np.int64:
 351                dtype = GGMLQuantizationType.I64
 352            else:
 353                raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
 354        else:
 355            dtype = raw_dtype
 356            if tensor_dtype == np.uint8:
 357                tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
 358
 359        # make sure there is at least one tensor before splitting
 360        if len(self.tensors[-1]) > 0:
 361            if (  # split when over tensor limit
 362                self.split_max_tensors != 0
 363                and len(self.tensors[-1]) >= self.split_max_tensors
 364            ) or (   # split when over size limit
 365                self.split_max_size != 0
 366                and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
 367            ):
 368                self.tensors.append({})
 369
 370        self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
 371
 372    def add_tensor(
 373        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
 374        raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
 375    ) -> None:
 376        # if tensor endianness is not passed, assume it's native to system
 377        if tensor_endianess is None:
 378            tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
 379
 380        if tensor_endianess != self.endianess:
 381            # Don't byteswap inplace since lazy copies cannot handle it
 382            tensor = tensor.byteswap(inplace=False)
 383        if self.use_temp_file and self.temp_file is None:
 384            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
 385            fp.seek(0)
 386            self.temp_file = fp
 387
 388        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
 389        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
 390
 391        if self.temp_file is None:
 392            self.tensors[-1][name].tensor = tensor
 393            return
 394
 395        tensor.tofile(self.temp_file)
 396        self.write_padding(self.temp_file, tensor.nbytes)
 397
 398    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
 399        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
 400        if pad != 0:
 401            fp.write(bytes([0] * pad))
 402
 403    def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
 404        if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
 405            raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
 406        assert self.fout is not None
 407
 408        # if tensor endianness is not passed, assume it's native to system
 409        if tensor_endianess is None:
 410            tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
 411
 412        if tensor_endianess != self.endianess:
 413            # Don't byteswap inplace since lazy copies cannot handle it
 414            tensor = tensor.byteswap(inplace=False)
 415
 416        file_id = -1
 417        for i, tensors in enumerate(self.tensors):
 418            if len(tensors) > 0:
 419                file_id = i
 420                break
 421
 422        fout = self.fout[file_id]
 423
 424        # pop the first tensor info
 425        # TODO: cleaner way to get the first key
 426        first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
 427        ti = self.tensors[file_id].pop(first_tensor_name)
 428        assert ti.nbytes == tensor.nbytes
 429
 430        self.write_padding(fout, fout.tell())
 431        tensor.tofile(fout)
 432        self.write_padding(fout, tensor.nbytes)
 433
 434        self.state = WriterState.WEIGHTS
 435
 436    def write_tensors_to_file(self, *, progress: bool = False) -> None:
 437        self.write_ti_data_to_file()
 438
 439        assert self.fout is not None
 440
 441        for fout in self.fout:
 442            self.write_padding(fout, fout.tell())
 443
 444        if self.temp_file is None:
 445            shard_bar = None
 446            bar = None
 447
 448            if progress:
 449                from tqdm import tqdm
 450
 451                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 452
 453                if len(self.fout) > 1:
 454                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
 455                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 456
 457            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
 458                if shard_bar is not None:
 459                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
 460                    total = sum(ti.nbytes for ti in tensors.values())
 461                    shard_bar.reset(total=(total if total > 0 else None))
 462
 463                # relying on the fact that Python dicts preserve insertion order (since 3.7)
 464                for ti in tensors.values():
 465                    assert ti.tensor is not None  # can only iterate once over the tensors
 466                    assert ti.tensor.nbytes == ti.nbytes
 467                    ti.tensor.tofile(fout)
 468                    if shard_bar is not None:
 469                        shard_bar.update(ti.nbytes)
 470                    if bar is not None:
 471                        bar.update(ti.nbytes)
 472                    self.write_padding(fout, ti.nbytes)
 473                    ti.tensor = None
 474        else:
 475            self.temp_file.seek(0)
 476
 477            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
 478            self.flush()
 479            self.temp_file.close()
 480
 481        self.state = WriterState.WEIGHTS
 482
 483    def flush(self) -> None:
 484        assert self.fout is not None
 485        for fout in self.fout:
 486            fout.flush()
 487
 488    def close(self) -> None:
 489        if self.fout is not None:
 490            for fout in self.fout:
 491                fout.close()
 492            self.fout = None
 493
    def add_type(self, type_name: str) -> None:
        """Set ``Keys.General.TYPE`` (skipped by add_string if ``type_name`` is empty)."""
        self.add_string(Keys.General.TYPE, type_name)

    def add_architecture(self) -> None:
        """Record ``self.arch`` under ``Keys.General.ARCHITECTURE``; called from ``__init__``."""
        self.add_string(Keys.General.ARCHITECTURE, self.arch)

    def add_quantization_version(self, quantization_version: int) -> None:
        """Set ``Keys.General.QUANTIZATION_VERSION`` as a uint32."""
        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
 502
 503    def add_custom_alignment(self, alignment: int) -> None:
 504        self.data_alignment = alignment
 505        self.add_uint32(Keys.General.ALIGNMENT, alignment)
 506
    # File-type and sampling metadata setters: thin wrappers that store one
    # general.* value with the GGUF type named in each docstring.

    def add_file_type(self, ftype: int) -> None:
        """Set ``Keys.General.FILE_TYPE`` as a uint32."""
        self.add_uint32(Keys.General.FILE_TYPE, ftype)

    def add_sampling_sequence(self, sequence: str) -> None:
        """Set ``Keys.General.SAMPLING_SEQUENCE`` (string; skipped if empty)."""
        self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)

    def add_sampling_top_k(self, top_k: int) -> None:
        """Set ``Keys.General.SAMPLING_TOP_K`` as an int32."""
        self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)

    def add_sampling_top_p(self, top_p: float) -> None:
        """Set ``Keys.General.SAMPLING_TOP_P`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)

    def add_sampling_min_p(self, min_p: float) -> None:
        """Set ``Keys.General.SAMPLING_MIN_P`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)

    def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
        """Set ``Keys.General.SAMPLING_XTC_PROBABILITY`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)

    def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
        """Set ``Keys.General.SAMPLING_XTC_THRESHOLD`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)

    def add_sampling_temp(self, temp: float) -> None:
        """Set ``Keys.General.SAMPLING_TEMP`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_TEMP, temp)

    def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
        """Set ``Keys.General.SAMPLING_PENALTY_LAST_N`` as an int32."""
        self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)

    def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
        """Set ``Keys.General.SAMPLING_PENALTY_REPEAT`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)

    def add_sampling_mirostat(self, mirostat: int) -> None:
        """Set ``Keys.General.SAMPLING_MIROSTAT`` as an int32."""
        self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)

    def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
        """Set ``Keys.General.SAMPLING_MIROSTAT_TAU`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)

    def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
        """Set ``Keys.General.SAMPLING_MIROSTAT_ETA`` as a float32."""
        self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
 545
    # general.* string metadata setters. Each forwards to add_string, which
    # silently skips empty values, so unset fields produce no key at all.

    def add_name(self, name: str) -> None:
        """Set ``Keys.General.NAME``."""
        self.add_string(Keys.General.NAME, name)

    def add_author(self, author: str) -> None:
        """Set ``Keys.General.AUTHOR``."""
        self.add_string(Keys.General.AUTHOR, author)

    def add_version(self, version: str) -> None:
        """Set ``Keys.General.VERSION``."""
        self.add_string(Keys.General.VERSION, version)

    def add_organization(self, organization: str) -> None:
        """Set ``Keys.General.ORGANIZATION``."""
        self.add_string(Keys.General.ORGANIZATION, organization)

    def add_finetune(self, finetune: str) -> None:
        """Set ``Keys.General.FINETUNE``."""
        self.add_string(Keys.General.FINETUNE, finetune)

    def add_basename(self, basename: str) -> None:
        """Set ``Keys.General.BASENAME``."""
        self.add_string(Keys.General.BASENAME, basename)

    def add_description(self, description: str) -> None:
        """Set ``Keys.General.DESCRIPTION``."""
        self.add_string(Keys.General.DESCRIPTION, description)

    def add_quantized_by(self, quantized: str) -> None:
        """Set ``Keys.General.QUANTIZED_BY``."""
        self.add_string(Keys.General.QUANTIZED_BY, quantized)

    def add_size_label(self, size_label: str) -> None:
        """Set ``Keys.General.SIZE_LABEL``."""
        self.add_string(Keys.General.SIZE_LABEL, size_label)

    def add_license(self, license: str) -> None:
        """Set ``Keys.General.LICENSE``."""
        self.add_string(Keys.General.LICENSE, license)

    def add_license_name(self, license: str) -> None:
        """Set ``Keys.General.LICENSE_NAME``."""
        self.add_string(Keys.General.LICENSE_NAME, license)

    def add_license_link(self, license: str) -> None:
        """Set ``Keys.General.LICENSE_LINK``."""
        self.add_string(Keys.General.LICENSE_LINK, license)

    def add_url(self, url: str) -> None:
        """Set ``Keys.General.URL``."""
        self.add_string(Keys.General.URL, url)

    def add_doi(self, doi: str) -> None:
        """Set ``Keys.General.DOI``."""
        self.add_string(Keys.General.DOI, doi)

    def add_uuid(self, uuid: str) -> None:
        """Set ``Keys.General.UUID``."""
        self.add_string(Keys.General.UUID, uuid)

    def add_repo_url(self, repo_url: str) -> None:
        """Set ``Keys.General.REPO_URL``."""
        self.add_string(Keys.General.REPO_URL, repo_url)

    def add_source_url(self, url: str) -> None:
        """Set ``Keys.General.SOURCE_URL``."""
        self.add_string(Keys.General.SOURCE_URL, url)

    def add_source_doi(self, doi: str) -> None:
        """Set ``Keys.General.SOURCE_DOI``."""
        self.add_string(Keys.General.SOURCE_DOI, doi)

    def add_source_uuid(self, uuid: str) -> None:
        """Set ``Keys.General.SOURCE_UUID``."""
        self.add_string(Keys.General.SOURCE_UUID, uuid)

    def add_source_repo_url(self, repo_url: str) -> None:
        """Set ``Keys.General.SOURCE_REPO_URL``."""
        self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
 605
    # Indexed base-model / dataset metadata. The *_count setters store a uint32;
    # the per-entry setters format the key template with ``id=source_id`` and
    # store a string (skipped by add_string when empty).

    def add_base_model_count(self, source_count: int) -> None:
        """Set ``Keys.General.BASE_MODEL_COUNT`` as a uint32."""
        self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)

    def add_base_model_name(self, source_id: int, name: str) -> None:
        """Set the name of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)

    def add_base_model_author(self, source_id: int, author: str) -> None:
        """Set the author of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)

    def add_base_model_version(self, source_id: int, version: str) -> None:
        """Set the version of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)

    def add_base_model_organization(self, source_id: int, organization: str) -> None:
        """Set the organization of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)

    def add_base_model_description(self, source_id: int, description: str) -> None:
        """Set the description of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)

    def add_base_model_url(self, source_id: int, url: str) -> None:
        """Set the URL of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)

    def add_base_model_doi(self, source_id: int, doi: str) -> None:
        """Set the DOI of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)

    def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
        """Set the UUID of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)

    def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
        """Set the repository URL of base model ``source_id``."""
        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)

    def add_dataset_count(self, source_count: int) -> None:
        """Set ``Keys.General.DATASET_COUNT`` as a uint32."""
        self.add_uint32(Keys.General.DATASET_COUNT, source_count)

    def add_dataset_name(self, source_id: int, name: str) -> None:
        """Set the name of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)

    def add_dataset_author(self, source_id: int, author: str) -> None:
        """Set the author of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)

    def add_dataset_version(self, source_id: int, version: str) -> None:
        """Set the version of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)

    def add_dataset_organization(self, source_id: int, organization: str) -> None:
        """Set the organization of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)

    def add_dataset_description(self, source_id: int, description: str) -> None:
        """Set the description of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)

    def add_dataset_url(self, source_id: int, url: str) -> None:
        """Set the URL of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)

    def add_dataset_doi(self, source_id: int, doi: str) -> None:
        """Set the DOI of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)

    def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
        """Set the UUID of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)

    def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
        """Set the repository URL of dataset ``source_id``."""
        self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
 665
    def add_tags(self, tags: Sequence[str]) -> None:
        """Set ``Keys.General.TAGS`` as an array (skipped by add_array when empty)."""
        self.add_array(Keys.General.TAGS, tags)

    def add_languages(self, languages: Sequence[str]) -> None:
        """Set ``Keys.General.LANGUAGES`` as an array (skipped by add_array when empty)."""
        self.add_array(Keys.General.LANGUAGES, languages)
 671
    # Architecture-specific hyperparameters: each key template is formatted with
    # ``arch=self.arch`` and the value stored with the GGUF type of the wrapped setter.

    def add_tensor_data_layout(self, layout: str) -> None:
        """Set ``Keys.LLM.TENSOR_DATA_LAYOUT`` for this architecture (string)."""
        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)

    def add_vocab_size(self, size: int) -> None:
        """Set ``Keys.LLM.VOCAB_SIZE`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)

    def add_context_length(self, length: int) -> None:
        """Set ``Keys.LLM.CONTEXT_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)

    def add_embedding_length(self, length: int) -> None:
        """Set ``Keys.LLM.EMBEDDING_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_embedding_length_out(self, length: int) -> None:
        """Set ``Keys.LLM.EMBEDDING_LENGTH_OUT`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)

    def add_features_length(self, length: int) -> None:
        """Set ``Keys.LLM.FEATURES_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)

    def add_posnet_embedding_length(self, length: int) -> None:
        """Set ``Keys.PosNet.EMBEDDING_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_posnet_block_count(self, length: int) -> None:
        """Set ``Keys.PosNet.BLOCK_COUNT`` for this architecture (uint32)."""
        self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)

    def add_convnext_embedding_length(self, length: int) -> None:
        """Set ``Keys.ConvNext.EMBEDDING_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)

    def add_convnext_block_count(self, length: int) -> None:
        """Set ``Keys.ConvNext.BLOCK_COUNT`` for this architecture (uint32)."""
        self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)

    def add_shortconv_l_cache(self, length: int) -> None:
        """Set ``Keys.ShortConv.L_CACHE`` for this architecture (uint32)."""
        self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)

    def add_block_count(self, length: int) -> None:
        """Set ``Keys.LLM.BLOCK_COUNT`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)

    def add_leading_dense_block_count(self, length: int) -> None:
        """Set ``Keys.LLM.LEADING_DENSE_BLOCK_COUNT`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)

    def add_full_attention_interval(self, interval: int) -> None:
        """Set ``Keys.LLM.FULL_ATTENTION_INTERVAL`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
 713
 714    def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
 715        if isinstance(length, int):
 716            self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 717        else:
 718            self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 719
    # More architecture-specific hyperparameters (keys formatted with arch=self.arch).

    def add_expert_feed_forward_length(self, length: int) -> None:
        """Set ``Keys.LLM.EXPERT_FEED_FORWARD_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_shared_feed_forward_length(self, length: int) -> None:
        """Set ``Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_expert_chunk_feed_forward_length(self, length: int) -> None:
        """Set ``Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)

    def add_parallel_residual(self, use: bool) -> None:
        """Set ``Keys.LLM.USE_PARALLEL_RESIDUAL`` for this architecture (bool)."""
        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)

    def add_decoder_start_token_id(self, id: int) -> None:
        """Set ``Keys.LLM.DECODER_START_TOKEN_ID`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

    def add_decoder_block_count(self, value: int) -> None:
        """Set ``Keys.LLM.DECODER_BLOCK_COUNT`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)

    def add_embedding_length_per_layer_input(self, value: int) -> None:
        """Set ``Keys.LLM.EMBD_LENGTH_PER_LAYER_INP`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

    def add_altup_active_idx(self, val: int) -> None:
        """Set ``Keys.LLM.ALTUP_ACTIVE_IDX`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)

    def add_altup_num_inputs(self, val: int) -> None:
        """Set ``Keys.LLM.ALTUP_NUM_INPUTS`` for this architecture (uint32)."""
        self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)

    def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
        """Set ``Keys.LLM.ACTIVATION_SPARSITY_SCALE`` for this architecture (array; skipped when empty)."""
        self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
 749
 750    def add_head_count(self, count: int | Sequence[int]) -> None:
 751        if isinstance(count, int):
 752            self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
 753        else:
 754            self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
 755
 756    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
 757        if isinstance(count, int):
 758            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 759        else:
 760            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 761
 762    def add_key_length(self, length: int) -> None:
 763        self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
 764
 765    def add_value_length(self, length: int) -> None:
 766        self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
 767
 768    def add_key_length_mla(self, length: int) -> None:
 769        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
 770
 771    def add_value_length_mla(self, length: int) -> None:
 772        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
 773
 774    def add_max_alibi_bias(self, bias: float) -> None:
 775        self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
 776
 777    def add_clamp_kqv(self, value: float) -> None:
 778        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 779
 780    def add_shared_kv_layers(self, value: int) -> None:
 781        self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
 782
 783    def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
 784        key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
 785        if isinstance(value, int):
 786            self.add_uint32(key, value)
 787        else:
 788            self.add_array(key, value)
 789
 790    def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None:
 791        self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
 792        self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
 793
 794    def add_logit_scale(self, value: float) -> None:
 795        self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 796
 797    def add_attn_logit_softcapping(self, value: float) -> None:
 798        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 799
 800    def add_router_logit_softcapping(self, value: float) -> None:
 801        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 802
 803    def add_final_logit_softcapping(self, value: float) -> None:
 804        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 805
 806    def add_expert_count(self, count: int) -> None:
 807        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 808
 809    def add_expert_used_count(self, count: int) -> None:
 810        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
 811
 812    def add_expert_shared_count(self, count: int) -> None:
 813        self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
 814
 815    def add_expert_group_count(self, count: int) -> None:
 816        self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
 817
 818    def add_expert_group_used_count(self, count: int) -> None:
 819        self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
 820
 821    def add_expert_weights_scale(self, value: float) -> None:
 822        self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 823
 824    def add_expert_weights_norm(self, value: bool) -> None:
 825        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
 826
 827    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
 828        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 829
 830    def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None:
 831        self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values)
 832
 833    def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None:
 834        self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values)
 835
 836    def add_expert_group_scale(self, value: float) -> None:
 837        self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
 838
 839    def add_experts_per_group(self, count: int) -> None:
 840        self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
 841
 842    def add_moe_every_n_layers(self, value: int) -> None:
 843        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
 844
 845    def add_nextn_predict_layers(self, count: int) -> None:
 846        self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
 847
 848    def add_swin_norm(self, value: bool) -> None:
 849        self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 850
 851    def add_rescale_every_n_layers(self, count: int) -> None:
 852        self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
 853
 854    def add_time_mix_extra_dim(self, dim: int) -> None:
 855        self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
 856
 857    def add_time_decay_extra_dim(self, dim: int) -> None:
 858        self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
 859
 860    def add_residual_scale(self, value: float) -> None:
 861        self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
 862
 863    def add_embedding_scale(self, value: float) -> None:
 864        self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
 865
 866    def add_wkv_head_size(self, size: int) -> None:
 867        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
 868
 869    def add_token_shift_count(self, count: int) -> None:
 870        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
 871
 872    def add_interleave_moe_layer_step(self, value: int) -> None:
 873        self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
 874
 875    def add_layer_norm_eps(self, value: float) -> None:
 876        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 877
 878    def add_layer_norm_rms_eps(self, value: float) -> None:
 879        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
 880
 881    def add_group_norm_eps(self, value: float) -> None:
 882        self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
 883
 884    def add_group_norm_groups(self, value: int) -> None:
 885        self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
 886
 887    def add_causal_attention(self, value: bool) -> None:
 888        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 889
 890    def add_q_lora_rank(self, length: int) -> None:
 891        self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
 892
 893    def add_kv_lora_rank(self, length: int) -> None:
 894        self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
 895
 896    def add_decay_lora_rank(self, length: int) -> None:
 897        self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
 898
 899    def add_iclr_lora_rank(self, length: int) -> None:
 900        self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
 901
 902    def add_value_residual_mix_lora_rank(self, length: int) -> None:
 903        self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
 904
 905    def add_rope_freq_base_swa(self, value: float) -> None:
 906        self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
 907
 908    def add_gate_lora_rank(self, length: int) -> None:
 909        self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
 910
 911    def add_relative_attn_buckets_count(self, value: int) -> None:
 912        self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 913
 914    def add_sliding_window(self, value: int) -> None:
 915        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
 916
 917    def add_attention_scale(self, value: float) -> None:
 918        self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
 919
 920    def add_attn_output_scale(self, value: float) -> None:
 921        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
 922
 923    def add_attn_temperature_length(self, value: int) -> None:
 924        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
 925
 926    def add_attn_temperature_scale(self, value: float) -> None:
 927        self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
 928
 929    def add_pooling_type(self, value: PoolingType) -> None:
 930        self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
 931
 932    def add_num_deepstack_layers(self, count: int) -> None:
 933        self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
 934
 935    def add_rope_dimension_count(self, count: int) -> None:
 936        self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
 937
 938    def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
 939        self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
 940
 941    def add_rope_freq_base(self, value: float) -> None:
 942        self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
 943
 944    def add_rope_scaling_type(self, value: RopeScalingType) -> None:
 945        self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
 946
 947    def add_rope_scaling_factor(self, value: float) -> None:
 948        self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
 949
 950    def add_rope_scaling_attn_factors(self, value: float) -> None:
 951        self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
 952
 953    def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
 954        self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
 955
 956    def add_rope_scaling_finetuned(self, value: bool) -> None:
 957        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 958
 959    def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
 960        self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
 961
 962    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
 963        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
 964
 965    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
 966        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
 967
 968    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
 969        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
 970
 971    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
 972        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
 973
 974    def add_ssm_conv_kernel(self, value: int) -> None:
 975        self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
 976
 977    def add_ssm_inner_size(self, value: int) -> None:
 978        self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
 979
 980    def add_ssm_state_size(self, value: int) -> None:
 981        self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
 982
 983    def add_ssm_time_step_rank(self, value: int) -> None:
 984        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
 985
 986    def add_ssm_group_count(self, value: int) -> None:
 987        self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
 988
 989    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
 990        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
 991
 992    def add_kda_head_dim(self, value: int) -> None:
 993        self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)
 994
 995    def add_tokenizer_model(self, model: str) -> None:
 996        self.add_string(Keys.Tokenizer.MODEL, model)
 997
 998    def add_tokenizer_pre(self, pre: str) -> None:
 999        self.add_string(Keys.Tokenizer.PRE, pre)
1000
1001    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
1002        self.add_array(Keys.Tokenizer.LIST, tokens)
1003
1004    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
1005        self.add_array(Keys.Tokenizer.MERGES, merges)
1006
1007    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
1008        self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
1009
1010    def add_token_type_count(self, value: int) -> None:
1011        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
1012
1013    def add_token_scores(self, scores: Sequence[float]) -> None:
1014        self.add_array(Keys.Tokenizer.SCORES, scores)
1015
1016    def add_bos_token_id(self, id: int) -> None:
1017        self.add_uint32(Keys.Tokenizer.BOS_ID, id)
1018
1019    def add_eos_token_id(self, id: int) -> None:
1020        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
1021
1022    def add_unk_token_id(self, id: int) -> None:
1023        self.add_uint32(Keys.Tokenizer.UNK_ID, id)
1024
1025    def add_sep_token_id(self, id: int) -> None:
1026        self.add_uint32(Keys.Tokenizer.SEP_ID, id)
1027
1028    def add_pad_token_id(self, id: int) -> None:
1029        self.add_uint32(Keys.Tokenizer.PAD_ID, id)
1030
1031    def add_mask_token_id(self, id: int) -> None:
1032        self.add_uint32(Keys.Tokenizer.MASK_ID, id)
1033
1034    def add_add_bos_token(self, value: bool) -> None:
1035        self.add_bool(Keys.Tokenizer.ADD_BOS, value)
1036
1037    def add_add_eos_token(self, value: bool) -> None:
1038        self.add_bool(Keys.Tokenizer.ADD_EOS, value)
1039
1040    def add_add_sep_token(self, value: bool) -> None:
1041        self.add_bool(Keys.Tokenizer.ADD_SEP, value)
1042
1043    def add_add_space_prefix(self, value: bool) -> None:
1044        self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
1045
1046    def add_remove_extra_whitespaces(self, value: bool) -> None:
1047        self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
1048
1049    def add_precompiled_charsmap(self, charsmap: bytes) -> None:
1050        self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
1051
1052    def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
1053        if not isinstance(value, str):
1054            template_default = None
1055            template_names = set()
1056
1057            for choice in value:
1058                name = choice.get('name', '')
1059                template = choice.get('template')
1060
1061                # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
1062                name = ''.join((c if c in ascii_letters + digits else '_' for c in name))
1063
1064                if name and template is not None:
1065                    if name == 'default':
1066                        template_default = template
1067                    else:
1068                        template_names.add(name)
1069                        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)
1070
1071            if template_names:
1072                self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))
1073
1074            if template_default is None:
1075                return
1076
1077            value = template_default
1078
1079        self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
1080
1081    def add_eot_token_id(self, id: int) -> None:
1082        self.add_uint32(Keys.Tokenizer.EOT_ID, id)
1083
1084    def add_eom_token_id(self, id: int) -> None:
1085        self.add_uint32(Keys.Tokenizer.EOM_ID, id)
1086
1087    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
1088        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
1089
1090    # for vision models
1091
1092    def add_clip_has_vision_encoder(self, value: bool) -> None:
1093        self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
1094
1095    def add_clip_has_audio_encoder(self, value: bool) -> None:
1096        self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
1097
1098    def add_clip_projector_type(self, value: str) -> None:
1099        self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
1100
1101    def add_clip_vision_projector_type(self, value: str) -> None:
1102        self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
1103
1104    def add_vision_projection_dim(self, value: int) -> None:
1105        self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
1106
1107    def add_vision_patch_size(self, value: int) -> None:
1108        self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
1109
1110    def add_vision_embedding_length(self, value: int) -> None:
1111        self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
1112
1113    def add_vision_feed_forward_length(self, value: int) -> None:
1114        self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
1115
1116    def add_vision_block_count(self, value: int) -> None:
1117        self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
1118
1119    def add_vision_head_count(self, value: int) -> None:
1120        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
1121
1122    def add_vision_attention_layernorm_eps(self, value: float) -> None:
1123        self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
1124
1125    def add_vision_image_size(self, value: int) -> None:
1126        self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
1127
1128    def add_vision_max_pixels(self, value: int) -> None:
1129        self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value)
1130
1131    def add_vision_min_pixels(self, value: int) -> None:
1132        self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
1133
1134    def add_vision_preproc_image_size(self, value: int) -> None:
1135        self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
1136
1137    def add_vision_image_mean(self, values: Sequence[float]) -> None:
1138        self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
1139
1140    def add_vision_image_std(self, values: Sequence[float]) -> None:
1141        self.add_array(Keys.ClipVision.IMAGE_STD, values)
1142
1143    def add_vision_spatial_merge_size(self, value: int) -> None:
1144        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
1145
1146    def add_vision_use_gelu(self, value: bool) -> None:
1147        self.add_bool(Keys.ClipVision.USE_GELU, value)
1148
1149    def add_vision_use_silu(self, value: bool) -> None:
1150        self.add_bool(Keys.ClipVision.USE_SILU, value)
1151
1152    def add_vision_projector_scale_factor(self, value: int) -> None:
1153        self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
1154
1155    def add_vision_n_wa_pattern(self, value: int) -> None:
1156        """Add window attention pattern interval for vision models.
1157
1158        This defines the pattern interval for window attention vs full attention layers.
1159        For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
1160        while other layers use window attention.
1161
1162        Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
1163        """
1164        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
1165
1166    def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
1167        """Add explicit layer indexes that use full attention in vision models.
1168
1169        This specifies the exact layer indices (0-based) that should use full attention
1170        instead of window attention. All other layers will use window attention.
1171
1172        Args:
1173            layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])
1174
1175        Used by models like YoutuVL where full attention layers are explicitly specified
1176        rather than following a regular pattern.
1177
1178        Difference from add_vision_n_wa_pattern:
1179        - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
1180        - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
1181        """
1182        self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
1183
1184    def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
1185        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
1186
1187    def add_vision_window_size(self, value: int) -> None:
1188        self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
1189
1190    # audio models
1191
1192    def add_clip_audio_projector_type(self, value: str) -> None:
1193        self.add_string(Keys.ClipAudio.PROJECTOR_TYPE, value)
1194
1195    def add_audio_projection_dim(self, value: int) -> None:
1196        self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
1197
1198    def add_audio_embedding_length(self, value: int) -> None:
1199        self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
1200
1201    def add_audio_feed_forward_length(self, value: int) -> None:
1202        self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
1203
1204    def add_audio_block_count(self, value: int) -> None:
1205        self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
1206
1207    def add_audio_head_count(self, value: int) -> None:
1208        self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
1209
1210    def add_audio_attention_layernorm_eps(self, value: float) -> None:
1211        self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
1212
1213    def add_audio_num_mel_bins(self, value: int) -> None:
1214        self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
1215
1216    def add_audio_stack_factor(self, value: int) -> None:
1217        self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
1218
1219    def add_xielu_alpha_p(self, values: Sequence[float]):
1220        self.add_array(Keys.xIELU.ALPHA_P, values)
1221
1222    def add_xielu_alpha_n(self, values: Sequence[float]):
1223        self.add_array(Keys.xIELU.ALPHA_N, values)
1224
1225    def add_xielu_beta(self, values: Sequence[float]):
1226        self.add_array(Keys.xIELU.BETA, values)
1227
1228    def add_xielu_eps(self, values: Sequence[float]):
1229        self.add_array(Keys.xIELU.EPS, values)
1230
1231    # diffusion models
1232
1233    def add_diffusion_shift_logits(self, value: bool) -> None:
1234        self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value)
1235
1236    def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
1237        pack_prefix = ''
1238        if not skip_pack_prefix:
1239            pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
1240        return struct.pack(f'{pack_prefix}{fmt}', value)
1241
1242    def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
1243        kv_data = bytearray()
1244
1245        if add_vtype:
1246            kv_data += self._pack("I", vtype)
1247
1248        pack_fmt = self._simple_value_packing.get(vtype)
1249        if pack_fmt is not None:
1250            kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
1251        elif vtype == GGUFValueType.STRING:
1252            encoded_val = val.encode("utf-8") if isinstance(val, str) else val
1253            kv_data += self._pack("Q", len(encoded_val))
1254            kv_data += encoded_val
1255        elif vtype == GGUFValueType.ARRAY:
1256
1257            if not isinstance(val, Sequence):
1258                raise ValueError("Invalid GGUF metadata array, expecting sequence")
1259
1260            if len(val) == 0:
1261                raise ValueError("Invalid GGUF metadata array. Empty array")
1262
1263            if sub_type is not None:
1264                ltype = sub_type
1265            elif isinstance(val, bytes):
1266                ltype = GGUFValueType.UINT8
1267            else:
1268                ltype = GGUFValueType.get_type(val[0])
1269                if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
1270                    raise ValueError("All items in a GGUF array should be of the same type")
1271            kv_data += self._pack("I", ltype)
1272            kv_data += self._pack("Q", len(val))
1273            for item in val:
1274                kv_data += self._pack_val(item, ltype, add_vtype=False)
1275        else:
1276            raise ValueError("Invalid GGUF metadata value type or value")
1277
1278        return kv_data
1279
1280    @staticmethod
1281    def format_n_bytes_to_str(num: int) -> str:
1282        if num == 0:
1283            return "negligible - metadata only"
1284        fnum = float(num)
1285        for unit in ("", "K", "M", "G"):
1286            if abs(fnum) < 1000.0:
1287                return f"{fnum:3.1f}{unit}"
1288            fnum /= 1000.0
1289        return f"{fnum:.1f}T - over 1TB, split recommended"