Diffstat (limited to 'llama.cpp/gguf-py/gguf')

-rw-r--r--  llama.cpp/gguf-py/gguf/__init__.py                          9
-rw-r--r--  llama.cpp/gguf-py/gguf/constants.py                      3895
-rw-r--r--  llama.cpp/gguf-py/gguf/gguf.py                             15
-rw-r--r--  llama.cpp/gguf-py/gguf/gguf_reader.py                     367
-rw-r--r--  llama.cpp/gguf-py/gguf/gguf_writer.py                    1289
-rw-r--r--  llama.cpp/gguf-py/gguf/lazy.py                            228
-rw-r--r--  llama.cpp/gguf-py/gguf/metadata.py                        731
-rw-r--r--  llama.cpp/gguf-py/gguf/py.typed                             0
-rw-r--r--  llama.cpp/gguf-py/gguf/quants.py                         1318
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py     186
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_dump.py               477
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py        1621
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_hash.py               102
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py       216
-rwxr-xr-x  llama.cpp/gguf-py/gguf/scripts/gguf_set_metadata.py        95
-rw-r--r--  llama.cpp/gguf-py/gguf/tensor_mapping.py                 1948
-rw-r--r--  llama.cpp/gguf-py/gguf/utility.py                         340
-rw-r--r--  llama.cpp/gguf-py/gguf/vocab.py                           891

18 files changed, 13728 insertions(+), 0 deletions(-)
diff --git a/llama.cpp/gguf-py/gguf/__init__.py b/llama.cpp/gguf-py/gguf/__init__.py
new file mode 100644
index 0000000..243defc
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/__init__.py
@@ -0,0 +1,9 @@
+from .constants import *
+from .lazy import *
+from .gguf_reader import *
+from .gguf_writer import *
+from .quants import *
+from .tensor_mapping import *
+from .vocab import *
+from .utility import *
+from .metadata import *
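The star imports above re-export the whole public API at the package root. A minimal usage sketch, assuming the gguf-py package is importable (the model path below is hypothetical):

    import gguf

    print(hex(gguf.GGUF_MAGIC))             # 0x46554747, i.e. b"GGUF" read little-endian
    reader = gguf.GGUFReader("model.gguf")  # GGUFReader is re-exported from gguf_reader
    for field in reader.fields.values():    # fields maps key names to ReaderField entries
        print(field.name)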
diff --git a/llama.cpp/gguf-py/gguf/constants.py b/llama.cpp/gguf-py/gguf/constants.py
new file mode 100644
index 0000000..9dab0df
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/constants.py
@@ -0,0 +1,3895 @@
+from __future__ import annotations
+
+from enum import Enum, IntEnum, auto
+from typing import Any
+
+#
+# constants
+#
+
+GGUF_MAGIC = 0x46554747 # "GGUF"
+GGUF_VERSION = 3
+GGUF_DEFAULT_ALIGNMENT = 32
+GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
+
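These values fix the on-disk header layout: a GGUF file opens with the magic and the version as consecutive little-endian uint32 fields. A hedged sketch of validating them (the path is hypothetical):

    import struct

    with open("model.gguf", "rb") as f:
        magic, version = struct.unpack("<II", f.read(8))  # two little-endian uint32s
    assert magic == 0x46554747  # GGUF_MAGIC spells b"GGUF"
    assert version <= 3         # GGUF_VERSION is 3 here; older files may carry 1 or 2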
+#
+# metadata keys
+#
+
+
+class Keys:
+ class General:
+ TYPE = "general.type"
+ ARCHITECTURE = "general.architecture"
+ QUANTIZATION_VERSION = "general.quantization_version"
+ ALIGNMENT = "general.alignment"
+ FILE_TYPE = "general.file_type"
+
+ # Recommended Sampler Parameters
+ SAMPLING_SEQUENCE = "general.sampling.sequence"
+ SAMPLING_TOP_K = "general.sampling.top_k"
+ SAMPLING_TOP_P = "general.sampling.top_p"
+ SAMPLING_MIN_P = "general.sampling.min_p"
+ SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
+ SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
+ SAMPLING_TEMP = "general.sampling.temp"
+ SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
+ SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
+ SAMPLING_MIROSTAT = "general.sampling.mirostat"
+ SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
+ SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
+
+ # Authorship Metadata
+ NAME = "general.name"
+ AUTHOR = "general.author"
+ VERSION = "general.version"
+ ORGANIZATION = "general.organization"
+
+ FINETUNE = "general.finetune"
+ BASENAME = "general.basename"
+
+ DESCRIPTION = "general.description"
+ QUANTIZED_BY = "general.quantized_by"
+
+ SIZE_LABEL = "general.size_label"
+
+ # Licensing details
+ LICENSE = "general.license"
+ LICENSE_NAME = "general.license.name"
+ LICENSE_LINK = "general.license.link"
+
+ # Typically represents the converted GGUF repo (unless native)
+ URL = "general.url" # Model Website/Paper
+ DOI = "general.doi"
+ UUID = "general.uuid"
+ REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
+
+ # Model Source during conversion
+ SOURCE_URL = "general.source.url" # Model Website/Paper
+ SOURCE_DOI = "general.source.doi"
+ SOURCE_UUID = "general.source.uuid"
+ SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
+
+ # Base Model Source. There can be more than one source if it's a merged
+ # model, as with 'Mistral-7B-Merge-14-v0.1'. This helps trace the
+ # lineage of models as they are finetuned or merged over time.
+ BASE_MODEL_COUNT = "general.base_model.count"
+ BASE_MODEL_NAME = "general.base_model.{id}.name"
+ BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
+ BASE_MODEL_VERSION = "general.base_model.{id}.version"
+ BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
+ BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description"
+ BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
+ BASE_MODEL_DOI = "general.base_model.{id}.doi"
+ BASE_MODEL_UUID = "general.base_model.{id}.uuid"
+ BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
+
+ # Dataset Source
+ DATASET_COUNT = "general.dataset.count"
+ DATASET_NAME = "general.dataset.{id}.name"
+ DATASET_AUTHOR = "general.dataset.{id}.author"
+ DATASET_VERSION = "general.dataset.{id}.version"
+ DATASET_ORGANIZATION = "general.dataset.{id}.organization"
+ DATASET_DESCRIPTION = "general.dataset.{id}.description"
+ DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper
+ DATASET_DOI = "general.dataset.{id}.doi"
+ DATASET_UUID = "general.dataset.{id}.uuid"
+ DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...)
+
+ # Array based KV stores
+ TAGS = "general.tags"
+ LANGUAGES = "general.languages"
+
+ class LLM:
+ VOCAB_SIZE = "{arch}.vocab_size"
+ CONTEXT_LENGTH = "{arch}.context_length"
+ EMBEDDING_LENGTH = "{arch}.embedding_length"
+ EMBEDDING_LENGTH_OUT = "{arch}.embedding_length_out"
+ FEATURES_LENGTH = "{arch}.features_length"
+ BLOCK_COUNT = "{arch}.block_count"
+ LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
+ FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
+ EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
+ EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
+ EXPERT_CHUNK_FEED_FORWARD_LENGTH = "{arch}.expert_chunk_feed_forward_length"
+ USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+ TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+ EXPERT_COUNT = "{arch}.expert_count"
+ EXPERT_USED_COUNT = "{arch}.expert_used_count"
+ EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
+ EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
+ EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
+ EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
+ EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
+ EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
+ EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
+ EXPERTS_PER_GROUP = "{arch}.experts_per_group"
+ MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
+ NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
+ NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
+ POOLING_TYPE = "{arch}.pooling_type"
+ LOGIT_SCALE = "{arch}.logit_scale"
+ DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
+ DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
+ ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
+ ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
+ FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+ SWIN_NORM = "{arch}.swin_norm"
+ RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
+ TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
+ TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
+ RESIDUAL_SCALE = "{arch}.residual_scale"
+ EMBEDDING_SCALE = "{arch}.embedding_scale"
+ TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
+ INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
+ FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
+ ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
+ ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
+ ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
+ EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+ SWIGLU_CLAMP_EXP = "{arch}.swiglu_clamp_exp"
+ SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp"
+ DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
+ DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
+
+ class Attention:
+ HEAD_COUNT = "{arch}.attention.head_count"
+ HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
+ MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
+ CLAMP_KQV = "{arch}.attention.clamp_kqv"
+ KEY_LENGTH = "{arch}.attention.key_length"
+ VALUE_LENGTH = "{arch}.attention.value_length"
+ LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
+ LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+ GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
+ GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
+ CAUSAL = "{arch}.attention.causal"
+ Q_LORA_RANK = "{arch}.attention.q_lora_rank"
+ KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
+ DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank"
+ ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank"
+ VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank"
+ GATE_LORA_RANK = "{arch}.attention.gate_lora_rank"
+ REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+ SLIDING_WINDOW = "{arch}.attention.sliding_window"
+ SCALE = "{arch}.attention.scale"
+ OUTPUT_SCALE = "{arch}.attention.output_scale"
+ TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
+ KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
+ VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
+ SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
+ SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
+ TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
+
+ class Rope:
+ DIMENSION_COUNT = "{arch}.rope.dimension_count"
+ DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
+ FREQ_BASE = "{arch}.rope.freq_base"
+ FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
+ SCALING_TYPE = "{arch}.rope.scaling.type"
+ SCALING_FACTOR = "{arch}.rope.scaling.factor"
+ SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
+ SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
+ SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
+ SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+ SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
+ SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+ SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
+ SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
+
+ class Split:
+ LLM_KV_SPLIT_NO = "split.no"
+ LLM_KV_SPLIT_COUNT = "split.count"
+ LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
+
+ class SSM:
+ CONV_KERNEL = "{arch}.ssm.conv_kernel"
+ INNER_SIZE = "{arch}.ssm.inner_size"
+ STATE_SIZE = "{arch}.ssm.state_size"
+ TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+ GROUP_COUNT = "{arch}.ssm.group_count"
+ DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
+
+ class KDA:
+ HEAD_DIM = "{arch}.kda.head_dim"
+
+ class WKV:
+ HEAD_SIZE = "{arch}.wkv.head_size"
+
+ class PosNet:
+ EMBEDDING_LENGTH = "{arch}.posnet.embedding_length"
+ BLOCK_COUNT = "{arch}.posnet.block_count"
+
+ class ConvNext:
+ EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
+ BLOCK_COUNT = "{arch}.convnext.block_count"
+
+ class Classifier:
+ OUTPUT_LABELS = "{arch}.classifier.output_labels"
+
+ class ShortConv:
+ L_CACHE = "{arch}.shortconv.l_cache"
+
+ class Tokenizer:
+ MODEL = "tokenizer.ggml.model"
+ PRE = "tokenizer.ggml.pre"
+ LIST = "tokenizer.ggml.tokens"
+ TOKEN_TYPE = "tokenizer.ggml.token_type"
+ TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
+ SCORES = "tokenizer.ggml.scores"
+ MERGES = "tokenizer.ggml.merges"
+ BOS_ID = "tokenizer.ggml.bos_token_id"
+ EOS_ID = "tokenizer.ggml.eos_token_id"
+ EOT_ID = "tokenizer.ggml.eot_token_id"
+ EOM_ID = "tokenizer.ggml.eom_token_id"
+ UNK_ID = "tokenizer.ggml.unknown_token_id"
+ SEP_ID = "tokenizer.ggml.seperator_token_id" # "seperator" misspelling is part of the established key name
+ PAD_ID = "tokenizer.ggml.padding_token_id"
+ MASK_ID = "tokenizer.ggml.mask_token_id"
+ ADD_BOS = "tokenizer.ggml.add_bos_token"
+ ADD_EOS = "tokenizer.ggml.add_eos_token"
+ ADD_SEP = "tokenizer.ggml.add_sep_token"
+ ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
+ REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
+ PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
+ HF_JSON = "tokenizer.huggingface.json"
+ RWKV = "tokenizer.rwkv.world"
+ CHAT_TEMPLATE = "tokenizer.chat_template"
+ CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
+ CHAT_TEMPLATES = "tokenizer.chat_templates"
+ # FIM/Infill special tokens constants
+ FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id"
+ FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id"
+ FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id"
+ FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id"
+ FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id"
+ FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id"
+ # deprecated:
+ PREFIX_ID = "tokenizer.ggml.prefix_token_id"
+ SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
+ MIDDLE_ID = "tokenizer.ggml.middle_token_id"
+
+ class Adapter:
+ TYPE = "adapter.type"
+ LORA_ALPHA = "adapter.lora.alpha"
+ LORA_TASK_NAME = "adapter.lora.task_name"
+ LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
+ ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"
+
+ class IMatrix:
+ CHUNK_COUNT = "imatrix.chunk_count"
+ CHUNK_SIZE = "imatrix.chunk_size"
+ DATASETS = "imatrix.datasets"
+
+ class Clip:
+ PROJECTOR_TYPE = "clip.projector_type"
+ HAS_VISION_ENCODER = "clip.has_vision_encoder"
+ HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
+ HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
+
+ class ClipVision:
+ PROJECTOR_TYPE = "clip.vision.projector_type" # for mixed modality models
+ IMAGE_SIZE = "clip.vision.image_size"
+ IMAGE_MIN_PIXELS = "clip.vision.image_min_pixels"
+ IMAGE_MAX_PIXELS = "clip.vision.image_max_pixels"
+ PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
+ PATCH_SIZE = "clip.vision.patch_size"
+ EMBEDDING_LENGTH = "clip.vision.embedding_length"
+ FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
+ PROJECTION_DIM = "clip.vision.projection_dim"
+ BLOCK_COUNT = "clip.vision.block_count"
+ IMAGE_MEAN = "clip.vision.image_mean"
+ IMAGE_STD = "clip.vision.image_std"
+ SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
+ USE_GELU = "clip.use_gelu"
+ USE_SILU = "clip.use_silu"
+ N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
+ WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl
+ IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
+ WINDOW_SIZE = "clip.vision.window_size"
+
+ class Attention:
+ HEAD_COUNT = "clip.vision.attention.head_count"
+ LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
+
+ class Projector:
+ SCALE_FACTOR = "clip.vision.projector.scale_factor"
+
+ class ClipAudio:
+ PROJECTOR_TYPE = "clip.audio.projector_type" # for mixed modality models
+ NUM_MEL_BINS = "clip.audio.num_mel_bins"
+ EMBEDDING_LENGTH = "clip.audio.embedding_length"
+ FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
+ PROJECTION_DIM = "clip.audio.projection_dim"
+ BLOCK_COUNT = "clip.audio.block_count"
+
+ class Attention:
+ HEAD_COUNT = "clip.audio.attention.head_count"
+ LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
+
+ class Projector:
+ STACK_FACTOR = "clip.audio.projector.stack_factor"
+
+ class Diffusion:
+ SHIFT_LOGITS = "diffusion.shift_logits"
+
+ class xIELU:
+ ALPHA_P = "xielu.alpha_p"
+ ALPHA_N = "xielu.alpha_n"
+ BETA = "xielu.beta"
+ EPS = "xielu.eps"
+
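Nearly all of these keys are templates rather than finished strings: the {arch} placeholder is substituted with the architecture name when a file is written or read, and the per-source keys additionally take an {id} index. A small sketch of the intended resolution, using the class as defined above:

    arch_key = Keys.LLM.CONTEXT_LENGTH.format(arch="llama")
    print(arch_key)  # -> "llama.context_length"

    # indexed keys also fill {id}
    print(Keys.General.BASE_MODEL_NAME.format(id=0))  # -> "general.base_model.0.name"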
+
+#
+# recommended mapping of model tensor names for storage in gguf
+#
+
+
+class GGUFType:
+ MODEL = "model"
+ ADAPTER = "adapter"
+ IMATRIX = "imatrix"
+ MMPROJ = "mmproj" # dummy, unused for now
+
+
+class MODEL_ARCH(IntEnum):
+ MMPROJ = auto() # dummy arch for clip.cpp
+ LLAMA = auto()
+ LLAMA4 = auto()
+ DECI = auto()
+ FALCON = auto()
+ FALCON_H1 = auto()
+ BAICHUAN = auto()
+ GROK = auto()
+ GPT2 = auto()
+ GPTJ = auto()
+ GPTNEOX = auto()
+ MPT = auto()
+ STARCODER = auto()
+ REFACT = auto()
+ BERT = auto()
+ MODERN_BERT = auto()
+ NOMIC_BERT = auto()
+ NOMIC_BERT_MOE = auto()
+ NEO_BERT = auto()
+ JINA_BERT_V2 = auto()
+ JINA_BERT_V3 = auto()
+ BLOOM = auto()
+ STABLELM = auto()
+ QWEN = auto()
+ QWEN2 = auto()
+ QWEN2MOE = auto()
+ QWEN2VL = auto()
+ QWEN3 = auto()
+ QWEN3MOE = auto()
+ QWEN3NEXT = auto()
+ QWEN3VL = auto()
+ QWEN3VLMOE = auto()
+ QWEN35 = auto()
+ QWEN35MOE = auto()
+ PHI2 = auto()
+ PHI3 = auto()
+ PHIMOE = auto()
+ PLAMO = auto()
+ PLAMO2 = auto()
+ PLAMO3 = auto()
+ CODESHELL = auto()
+ ORION = auto()
+ INTERNLM2 = auto()
+ MINICPM = auto()
+ MINICPM3 = auto()
+ GEMMA = auto()
+ GEMMA2 = auto()
+ GEMMA3 = auto()
+ GEMMA3N = auto()
+ GEMMA_EMBEDDING = auto()
+ STARCODER2 = auto()
+ RWKV6 = auto()
+ RWKV6QWEN2 = auto()
+ RWKV7 = auto()
+ ARWKV7 = auto()
+ MAMBA = auto()
+ MAMBA2 = auto()
+ JAMBA = auto()
+ XVERSE = auto()
+ COMMAND_R = auto()
+ COHERE2 = auto()
+ DBRX = auto()
+ OLMO = auto()
+ OLMO2 = auto()
+ OLMOE = auto()
+ OPENELM = auto()
+ ARCTIC = auto()
+ DEEPSEEK = auto()
+ DEEPSEEK2 = auto()
+ CHATGLM = auto()
+ GLM4 = auto()
+ GLM4_MOE = auto()
+ BITNET = auto()
+ T5 = auto()
+ T5ENCODER = auto()
+ JAIS = auto()
+ NEMOTRON = auto()
+ NEMOTRON_H = auto()
+ NEMOTRON_H_MOE = auto()
+ EXAONE = auto()
+ EXAONE4 = auto()
+ EXAONE_MOE = auto()
+ GRANITE = auto()
+ GRANITE_MOE = auto()
+ GRANITE_HYBRID = auto()
+ CHAMELEON = auto()
+ WAVTOKENIZER_DEC = auto()
+ PLM = auto()
+ BAILINGMOE = auto()
+ BAILINGMOE2 = auto()
+ DOTS1 = auto()
+ ARCEE = auto()
+ AFMOE = auto()
+ ERNIE4_5 = auto()
+ ERNIE4_5_MOE = auto()
+ HUNYUAN_MOE = auto()
+ HUNYUAN_DENSE = auto()
+ SMOLLM3 = auto()
+ GPT_OSS = auto()
+ LFM2 = auto()
+ LFM2MOE = auto()
+ DREAM = auto()
+ SMALLTHINKER = auto()
+ LLADA = auto()
+ LLADA_MOE = auto()
+ SEED_OSS = auto()
+ GROVEMOE = auto()
+ APERTUS = auto()
+ COGVLM = auto()
+ MINIMAXM2 = auto()
+ RND1 = auto()
+ PANGU_EMBED = auto()
+ MISTRAL3 = auto()
+ MIMO2 = auto()
+ STEP35 = auto()
+ LLAMA_EMBED = auto()
+ MAINCODER = auto()
+ KIMI_LINEAR = auto()
+
+
+class VISION_PROJECTOR_TYPE(IntEnum):
+ MLP = auto()
+ LDP = auto()
+ LDPV2 = auto()
+ RESAMPLER = auto()
+ GLM_EDGE = auto()
+ MERGER = auto()
+ GEMMA3N = auto()
+ GEMMA3 = auto()
+ QWEN3VL = auto()
+ COGVLM = auto()
+
+
+class MODEL_TENSOR(IntEnum):
+ TOKEN_EMBD = auto()
+ TOKEN_EMBD_NORM = auto()
+ TOKEN_TYPES = auto()
+ POS_EMBD = auto()
+ OUTPUT = auto()
+ DENSE_2_OUT = auto() # embeddinggemma 2_Dense
+ DENSE_3_OUT = auto() # embeddinggemma 3_Dense
+ OUTPUT_NORM = auto()
+ ROPE_FREQS = auto()
+ ROPE_FACTORS_LONG = auto()
+ ROPE_FACTORS_SHORT = auto()
+ ATTN_Q = auto()
+ ATTN_K = auto()
+ ATTN_V = auto()
+ ATTN_QKV = auto()
+ ATTN_OUT = auto()
+ ATTN_NORM = auto()
+ ATTN_NORM_2 = auto()
+ ATTN_OUT_NORM = auto()
+ ATTN_POST_NORM = auto()
+ ATTN_ROT_EMBD = auto()
+ ATTN_SINKS = auto()
+ ATTN_GATE = auto()
+ FFN_GATE_INP = auto()
+ FFN_GATE_INP_SHEXP = auto()
+ FFN_NORM = auto()
+ FFN_PRE_NORM = auto()
+ FFN_POST_NORM = auto()
+ FFN_GATE = auto()
+ FFN_DOWN = auto()
+ FFN_UP = auto()
+ FFN_ACT = auto()
+ FFN_NORM_EXP = auto()
+ FFN_GATE_EXP = auto()
+ FFN_DOWN_EXP = auto()
+ FFN_UP_EXP = auto()
+ FFN_GATE_SHEXP = auto()
+ FFN_DOWN_SHEXP = auto()
+ FFN_UP_SHEXP = auto()
+ FFN_GATE_CHEXP = auto()
+ FFN_DOWN_CHEXP = auto()
+ FFN_UP_CHEXP = auto()
+ FFN_EXP_PROBS_B = auto()
+ ATTN_Q_NORM = auto()
+ ATTN_K_NORM = auto()
+ LAYER_OUT_NORM = auto()
+ PER_LAYER_TOKEN_EMBD = auto() # gemma3n
+ PER_LAYER_MODEL_PROJ = auto() # gemma3n
+ PER_LAYER_INP_GATE = auto() # gemma3n
+ PER_LAYER_PROJ = auto() # gemma3n
+ PER_LAYER_PROJ_NORM = auto() # gemma3n
+ PER_LAYER_POST_NORM = auto() # gemma3n
+ ALTUP_PROJ = auto() # gemma3n
+ ALTUP_UNEMBD_PROJ = auto() # gemma3n
+ ALTUP_CORRECT_COEF = auto() # gemma3n
+ ALTUP_CORRECT_SCALE = auto() # gemma3n
+ ALTUP_PREDICT_COEF = auto() # gemma3n
+ ALTUP_ROUTER = auto() # gemma3n
+ ALTUP_ROUTER_NORM = auto() # gemma3n
+ LAUREL_L = auto() # gemma3n
+ LAUREL_R = auto() # gemma3n
+ LAUREL_POST_NORM = auto() # gemma3n
+ SSM_IN = auto()
+ SSM_CONV1D = auto()
+ SSM_X = auto()
+ SSM_DT = auto()
+ SSM_DT_NORM = auto()
+ SSM_A = auto()
+ SSM_B_NORM = auto()
+ SSM_C_NORM = auto()
+ SSM_D = auto()
+ SSM_NORM = auto()
+ SSM_OUT = auto()
+ SSM_ALPHA = auto() # qwen3.5
+ SSM_BETA_ALPHA = auto() # qwen3next
+ SSM_CONV1D_Q = auto() # Kimi Linear
+ SSM_CONV1D_K = auto() # Kimi Linear
+ SSM_CONV1D_V = auto() # Kimi Linear
+ SSM_F_A = auto() # Kimi Linear
+ SSM_F_B = auto() # Kimi Linear
+ SSM_BETA = auto() # Kimi Linear qwen3.5
+ SSM_G_A = auto() # Kimi Linear
+ SSM_G_B = auto() # Kimi Linear
+ TIME_MIX_W0 = auto()
+ TIME_MIX_W1 = auto()
+ TIME_MIX_W2 = auto()
+ TIME_MIX_A0 = auto()
+ TIME_MIX_A1 = auto()
+ TIME_MIX_A2 = auto()
+ TIME_MIX_V0 = auto()
+ TIME_MIX_V1 = auto()
+ TIME_MIX_V2 = auto()
+ TIME_MIX_G1 = auto()
+ TIME_MIX_G2 = auto()
+ TIME_MIX_K_K = auto()
+ TIME_MIX_K_A = auto()
+ TIME_MIX_R_K = auto()
+ TIME_MIX_LERP_X = auto()
+ TIME_MIX_LERP_K = auto()
+ TIME_MIX_LERP_V = auto()
+ TIME_MIX_LERP_R = auto()
+ TIME_MIX_LERP_G = auto()
+ TIME_MIX_LERP_FUSED = auto()
+ TIME_MIX_LERP_W = auto()
+ TIME_MIX_FIRST = auto()
+ TIME_MIX_DECAY = auto()
+ TIME_MIX_DECAY_W1 = auto()
+ TIME_MIX_DECAY_W2 = auto()
+ TIME_MIX_KEY = auto()
+ TIME_MIX_VALUE = auto()
+ TIME_MIX_RECEPTANCE = auto()
+ TIME_MIX_GATE = auto()
+ TIME_MIX_LN = auto()
+ TIME_MIX_OUTPUT = auto()
+ CHANNEL_MIX_LERP_K = auto()
+ CHANNEL_MIX_LERP_R = auto()
+ CHANNEL_MIX_KEY = auto()
+ CHANNEL_MIX_RECEPTANCE = auto()
+ CHANNEL_MIX_VALUE = auto()
+ ATTN_Q_A = auto()
+ ATTN_Q_B = auto()
+ ATTN_KV_A_MQA = auto()
+ ATTN_KV_B = auto()
+ ATTN_K_B = auto()
+ ATTN_V_B = auto()
+ ATTN_Q_A_NORM = auto()
+ ATTN_KV_A_NORM = auto()
+ FFN_SUB_NORM = auto()
+ ATTN_SUB_NORM = auto()
+ DEC_ATTN_NORM = auto()
+ DEC_ATTN_Q = auto()
+ DEC_ATTN_K = auto()
+ DEC_ATTN_V = auto()
+ DEC_ATTN_OUT = auto()
+ DEC_ATTN_REL_B = auto()
+ DEC_CROSS_ATTN_NORM = auto()
+ DEC_CROSS_ATTN_Q = auto()
+ DEC_CROSS_ATTN_K = auto()
+ DEC_CROSS_ATTN_V = auto()
+ DEC_CROSS_ATTN_OUT = auto()
+ DEC_CROSS_ATTN_REL_B = auto()
+ DEC_FFN_NORM = auto()
+ DEC_FFN_GATE = auto()
+ DEC_FFN_DOWN = auto()
+ DEC_FFN_UP = auto()
+ DEC_OUTPUT_NORM = auto()
+ ENC_ATTN_NORM = auto()
+ ENC_ATTN_Q = auto()
+ ENC_ATTN_K = auto()
+ ENC_ATTN_V = auto()
+ ENC_ATTN_OUT = auto()
+ ENC_ATTN_REL_B = auto()
+ ENC_FFN_NORM = auto()
+ ENC_FFN_GATE = auto()
+ ENC_FFN_DOWN = auto()
+ ENC_FFN_UP = auto()
+ ENC_OUTPUT_NORM = auto()
+ CLS = auto() # classifier
+ CLS_OUT = auto() # classifier output projection
+ CONV1D = auto()
+ CONVNEXT_DW = auto()
+ CONVNEXT_NORM = auto()
+ CONVNEXT_PW1 = auto()
+ CONVNEXT_PW2 = auto()
+ CONVNEXT_GAMMA = auto()
+ POSNET_CONV1 = auto()
+ POSNET_CONV2 = auto()
+ POSNET_NORM = auto()
+ POSNET_NORM1 = auto()
+ POSNET_NORM2 = auto()
+ POSNET_ATTN_NORM = auto()
+ POSNET_ATTN_Q = auto()
+ POSNET_ATTN_K = auto()
+ POSNET_ATTN_V = auto()
+ POSNET_ATTN_OUT = auto()
+ SHORTCONV_CONV = auto()
+ SHORTCONV_INPROJ = auto()
+ SHORTCONV_OUTPROJ = auto()
+ VISEXP_ATTN_QKV = auto()
+ VISEXP_ATTN_OUT = auto()
+ VISEXP_GATE = auto()
+ VISEXP_DOWN = auto()
+ VISEXP_UP = auto()
+ # vision
+ V_MMPROJ = auto()
+ V_MMPROJ_FC = auto()
+ V_MMPROJ_MLP = auto()
+ V_MMPROJ_PEG = auto()
+ V_ENC_EMBD_CLS = auto()
+ V_ENC_EMBD_PATCH = auto()
+ V_ENC_EMBD_NORM = auto()
+ V_ENC_EMBD_POS = auto()
+ V_ENC_INPUT_NORM = auto()
+ V_ENC_ATTN_QKV = auto()
+ V_ENC_ATTN_Q = auto()
+ V_ENC_ATTN_Q_NORM = auto()
+ V_ENC_ATTN_K = auto()
+ V_ENC_ATTN_K_NORM = auto()
+ V_ENC_ATTN_V = auto()
+ V_ENC_ATTN_O = auto()
+ V_ENC_ATTN_O_NORM = auto()
+ V_ENC_POST_ATTN_NORM = auto()
+ V_ENC_FFN_UP = auto()
+ V_ENC_FFN_GATE = auto()
+ V_ENC_FFN_DOWN = auto()
+ V_LAYER_SCALE_1 = auto()
+ V_LAYER_SCALE_2 = auto()
+ V_PRE_NORM = auto()
+ V_POST_NORM = auto()
+ V_MM_POST_NORM = auto()
+ V_MM_INP_NORM = auto()
+ V_MM_INP_PROJ = auto() # gemma3
+ V_MM_SOFT_EMB_NORM = auto() # gemma3
+ V_MM_EMBEDDING = auto() # gemma3n
+ V_MM_HARD_EMB_NORM = auto() # gemma3n
+ V_ENC_CONV_STEM = auto() # gemma3n
+ V_ENC_CONV_STEM_NORM = auto() # gemma3n
+ V_ENC_MSFA_EXP = auto() # gemma3n
+ V_ENC_MSFA_EXP_NORM = auto() # gemma3n
+ V_ENC_MSFA_PROJ = auto() # gemma3n
+ V_ENC_MSFA_PROJ_NORM = auto() # gemma3n
+ V_ENC_MSFA_NORM = auto() # gemma3n
+ V_RESMPL_POS_EMBD_K = auto() # minicpmv
+ V_RESMPL_ATTN_Q = auto() # minicpmv
+ V_RESMPL_ATTN_K = auto() # minicpmv
+ V_RESMPL_ATTN_V = auto() # minicpmv
+ V_RESMPL_ATTN_OUT = auto() # minicpmv
+ V_RESMPL_KV = auto() # minicpmv
+ V_RESMPL_KV_NORM = auto() # minicpmv
+ V_RESMPL_POST_NORM = auto() # minicpmv
+ V_RESMPL_Q_NORM = auto() # minicpmv
+ V_RESMPL_PROJ = auto() # minicpmv
+ V_RESMPL_QUERY = auto() # minicpmv
+ V_TOK_EMBD_IMG_BREAK = auto() # pixtral
+ V_MM_PATCH_MERGER = auto() # mistral small 3.1
+ V_DS_NORM = auto() # qwen3vl
+ V_DS_FC1 = auto() # qwen3vl
+ V_DS_FC2 = auto() # qwen3vl
+ V_MM_POST_FC_NORM = auto() # cogvlm
+ V_MM_UP = auto() # cogvlm
+ V_MM_DOWN = auto() # cogvlm
+ V_MM_GATE = auto() # cogvlm
+ V_TOK_BOI = auto() # cogvlm
+ V_TOK_EOI = auto() # cogvlm
+ # audio (mtmd)
+ A_ENC_EMBD_POS = auto()
+ A_ENC_EMBD_NORM = auto()
+ A_ENC_EMBD_TO_LOGITS = auto() # lfm2
+ A_ENC_CONV1D = auto()
+ A_ENC_CONV1D_NORM = auto() # gemma3n
+ A_PRE_NORM = auto()
+ A_POST_NORM = auto()
+ A_ENC_LAYER_PRE_NORM = auto() # gemma3n
+ A_ENC_ATTN_Q = auto()
+ A_ENC_ATTN_K = auto()
+ A_ENC_ATTN_V = auto()
+ A_ENC_PER_DIM_SCALE = auto() # gemma3n
+ A_ENC_INPUT_NORM = auto()
+ A_ENC_OUTPUT = auto()
+ A_ENC_OUTPUT_NORM = auto()
+ A_ENC_FFN_UP = auto()
+ A_ENC_FFN_NORM = auto()
+ A_ENC_FFN_POST_NORM = auto() # gemma3n
+ A_ENC_FFN_SCALE = auto() # gemma3n
+ A_ENC_FFN_GATE = auto()
+ A_ENC_FFN_DOWN = auto()
+ A_ENC_FFN_UP_1 = auto() # lfm2, gemma3n
+ A_ENC_FFN_NORM_1 = auto() # lfm2, gemma3n (pre-norm)
+ A_ENC_FFN_POST_NORM_1 = auto() # gemma3n
+ A_ENC_FFN_SCALE_1 = auto() # gemma3n
+ A_ENC_FFN_GATE_1 = auto() # lfm2, gemma3n
+ A_ENC_FFN_DOWN_1 = auto() # lfm2, gemma3n
+ A_MMPROJ = auto()
+ A_MMPROJ_FC = auto()
+ A_MM_NORM_PRE = auto()
+ A_MM_NORM_MID = auto()
+ A_MM_EMBEDDING = auto() # gemma3n
+ A_MM_HARD_EMB_NORM = auto() # gemma3n
+ A_MM_SOFT_EMB_NORM = auto() # gemma3n
+ A_MM_INP_PROJ = auto() # gemma3n
+ # nextn/mtp
+ NEXTN_EH_PROJ = auto()
+ NEXTN_EMBED_TOKENS = auto()
+ NEXTN_ENORM = auto()
+ NEXTN_HNORM = auto()
+ NEXTN_SHARED_HEAD_HEAD = auto()
+ NEXTN_SHARED_HEAD_NORM = auto()
+ # lfm2 audio
+ A_ENC_NORM_CONV = auto()
+ A_ENC_LINEAR_POS = auto()
+ A_ENC_POS_BIAS_U = auto()
+ A_ENC_POS_BIAS_V = auto()
+ A_ENC_OUT = auto()
+ A_ENC_CONV_DW = auto() # SSM conv
+ A_ENC_CONV_NORM = auto() # SSM conv
+ A_ENC_CONV_PW1 = auto()
+ A_ENC_CONV_PW2 = auto()
+
+
+MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
+ MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
+ MODEL_ARCH.LLAMA: "llama",
+ MODEL_ARCH.LLAMA4: "llama4",
+ MODEL_ARCH.DECI: "deci",
+ MODEL_ARCH.FALCON: "falcon",
+ MODEL_ARCH.BAICHUAN: "baichuan",
+ MODEL_ARCH.GROK: "grok",
+ MODEL_ARCH.GPT2: "gpt2",
+ MODEL_ARCH.GPTJ: "gptj",
+ MODEL_ARCH.GPTNEOX: "gptneox",
+ MODEL_ARCH.MPT: "mpt",
+ MODEL_ARCH.STARCODER: "starcoder",
+ MODEL_ARCH.REFACT: "refact",
+ MODEL_ARCH.BERT: "bert",
+ MODEL_ARCH.MODERN_BERT: "modern-bert",
+ MODEL_ARCH.NOMIC_BERT: "nomic-bert",
+ MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
+ MODEL_ARCH.NEO_BERT: "neo-bert",
+ MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
+ MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3",
+ MODEL_ARCH.BLOOM: "bloom",
+ MODEL_ARCH.STABLELM: "stablelm",
+ MODEL_ARCH.QWEN: "qwen",
+ MODEL_ARCH.QWEN2: "qwen2",
+ MODEL_ARCH.QWEN2MOE: "qwen2moe",
+ MODEL_ARCH.QWEN2VL: "qwen2vl",
+ MODEL_ARCH.QWEN3: "qwen3",
+ MODEL_ARCH.QWEN3MOE: "qwen3moe",
+ MODEL_ARCH.QWEN3NEXT: "qwen3next",
+ MODEL_ARCH.QWEN3VL: "qwen3vl",
+ MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
+ MODEL_ARCH.QWEN35: "qwen35",
+ MODEL_ARCH.QWEN35MOE: "qwen35moe",
+ MODEL_ARCH.PHI2: "phi2",
+ MODEL_ARCH.PHI3: "phi3",
+ MODEL_ARCH.PHIMOE: "phimoe",
+ MODEL_ARCH.PLAMO: "plamo",
+ MODEL_ARCH.PLAMO2: "plamo2",
+ MODEL_ARCH.PLAMO3: "plamo3",
+ MODEL_ARCH.CODESHELL: "codeshell",
+ MODEL_ARCH.ORION: "orion",
+ MODEL_ARCH.INTERNLM2: "internlm2",
+ MODEL_ARCH.MINICPM: "minicpm",
+ MODEL_ARCH.MINICPM3: "minicpm3",
+ MODEL_ARCH.GEMMA: "gemma",
+ MODEL_ARCH.GEMMA2: "gemma2",
+ MODEL_ARCH.GEMMA3: "gemma3",
+ MODEL_ARCH.GEMMA3N: "gemma3n",
+ MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
+ MODEL_ARCH.STARCODER2: "starcoder2",
+ MODEL_ARCH.RWKV6: "rwkv6",
+ MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
+ MODEL_ARCH.RWKV7: "rwkv7",
+ MODEL_ARCH.ARWKV7: "arwkv7",
+ MODEL_ARCH.MAMBA: "mamba",
+ MODEL_ARCH.MAMBA2: "mamba2",
+ MODEL_ARCH.JAMBA: "jamba",
+ MODEL_ARCH.XVERSE: "xverse",
+ MODEL_ARCH.COMMAND_R: "command-r",
+ MODEL_ARCH.COHERE2: "cohere2",
+ MODEL_ARCH.DBRX: "dbrx",
+ MODEL_ARCH.OLMO: "olmo",
+ MODEL_ARCH.OLMO2: "olmo2",
+ MODEL_ARCH.OLMOE: "olmoe",
+ MODEL_ARCH.OPENELM: "openelm",
+ MODEL_ARCH.ARCTIC: "arctic",
+ MODEL_ARCH.DEEPSEEK: "deepseek",
+ MODEL_ARCH.DEEPSEEK2: "deepseek2",
+ MODEL_ARCH.CHATGLM: "chatglm",
+ MODEL_ARCH.GLM4: "glm4",
+ MODEL_ARCH.GLM4_MOE: "glm4moe",
+ MODEL_ARCH.BITNET: "bitnet",
+ MODEL_ARCH.T5: "t5",
+ MODEL_ARCH.T5ENCODER: "t5encoder",
+ MODEL_ARCH.JAIS: "jais",
+ MODEL_ARCH.NEMOTRON: "nemotron",
+ MODEL_ARCH.NEMOTRON_H: "nemotron_h",
+ MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
+ MODEL_ARCH.EXAONE: "exaone",
+ MODEL_ARCH.EXAONE4: "exaone4",
+ MODEL_ARCH.EXAONE_MOE: "exaone-moe",
+ MODEL_ARCH.GRANITE: "granite",
+ MODEL_ARCH.GRANITE_MOE: "granitemoe",
+ MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
+ MODEL_ARCH.CHAMELEON: "chameleon",
+ MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
+ MODEL_ARCH.PLM: "plm",
+ MODEL_ARCH.BAILINGMOE: "bailingmoe",
+ MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
+ MODEL_ARCH.DOTS1: "dots1",
+ MODEL_ARCH.ARCEE: "arcee",
+ MODEL_ARCH.AFMOE: "afmoe",
+ MODEL_ARCH.ERNIE4_5: "ernie4_5",
+ MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
+ MODEL_ARCH.FALCON_H1: "falcon-h1",
+ MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
+ MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
+ MODEL_ARCH.SMOLLM3: "smollm3",
+ MODEL_ARCH.GPT_OSS: "gpt-oss",
+ MODEL_ARCH.LFM2: "lfm2",
+ MODEL_ARCH.LFM2MOE: "lfm2moe",
+ MODEL_ARCH.DREAM: "dream",
+ MODEL_ARCH.SMALLTHINKER: "smallthinker",
+ MODEL_ARCH.LLADA: "llada",
+ MODEL_ARCH.LLADA_MOE: "llada-moe",
+ MODEL_ARCH.SEED_OSS: "seed_oss",
+ MODEL_ARCH.GROVEMOE: "grovemoe",
+ MODEL_ARCH.APERTUS: "apertus",
+ MODEL_ARCH.MINIMAXM2: "minimax-m2",
+ MODEL_ARCH.COGVLM: "cogvlm",
+ MODEL_ARCH.RND1: "rnd1",
+ MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
+ MODEL_ARCH.MISTRAL3: "mistral3",
+ MODEL_ARCH.MIMO2: "mimo2",
+ MODEL_ARCH.STEP35: "step35",
+ MODEL_ARCH.LLAMA_EMBED: "llama-embed",
+ MODEL_ARCH.MAINCODER: "maincoder",
+ MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
+}
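The enum-to-string table above is what lands in "general.architecture"; going back from a string is a plain inverted lookup. A sketch (ARCH_BY_NAME is an illustrative helper, not part of the module):

    name = MODEL_ARCH_NAMES[MODEL_ARCH.QWEN2]          # -> "qwen2"

    ARCH_BY_NAME = {v: k for k, v in MODEL_ARCH_NAMES.items()}
    assert ARCH_BY_NAME["qwen2"] is MODEL_ARCH.QWEN2   # values are unique, so the reverse map is unambiguous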
+
+VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
+ VISION_PROJECTOR_TYPE.MLP: "mlp",
+ VISION_PROJECTOR_TYPE.LDP: "ldp",
+ VISION_PROJECTOR_TYPE.LDPV2: "ldpv2",
+ VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
+ VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter",
+ VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger",
+ VISION_PROJECTOR_TYPE.GEMMA3: "gemma3",
+}
+
+TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
+ MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+ MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
+ MODEL_TENSOR.TOKEN_TYPES: "token_types",
+ MODEL_TENSOR.POS_EMBD: "position_embd",
+ MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+ MODEL_TENSOR.OUTPUT: "output",
+ MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
+ MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 3_Dense
+ MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+ MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
+ MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
+ MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+ MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
+ MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
+ MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+ MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+ MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+ MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+ MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+ MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
+ MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
+ MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
+ MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
+ MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
+ MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
+ MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
+ MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+ MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
+ MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
+ MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+ MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+ MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+ MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
+ MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
+ MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
+ MODEL_TENSOR.FFN_GATE_CHEXP: "blk.{bid}.ffn_gate_chexps",
+ MODEL_TENSOR.FFN_DOWN_CHEXP: "blk.{bid}.ffn_down_chexps",
+ MODEL_TENSOR.FFN_UP_CHEXP: "blk.{bid}.ffn_up_chexps",
+ MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
+ MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
+ MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
+ MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
+ MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+ MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
+ MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
+ MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
+ MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
+ MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
+ MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n
+ MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n
+ MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n
+ MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n
+ MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n
+ MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n
+ MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n
+ MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n
+ MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n
+ MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n
+ MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n
+ MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n
+ MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n
+ MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
+ MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
+ MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
+ MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+ MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
+ MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+ MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
+ MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
+ MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+ MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
+ MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+ MODEL_TENSOR.SSM_ALPHA: "blk.{bid}.ssm_alpha", # qwen3.5
+ MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
+ MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
+ MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
+ MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
+ MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
+ MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
+ MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear qwen3.5
+ MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
+ MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
+ MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
+ MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
+ MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
+ MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0",
+ MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1",
+ MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2",
+ MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0",
+ MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1",
+ MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2",
+ MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1",
+ MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2",
+ MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k",
+ MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a",
+ MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k",
+ MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x",
+ MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k",
+ MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
+ MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
+ MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
+ MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
+ MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
+ MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1",
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2",
+ MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key",
+ MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value",
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance",
+ MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate",
+ MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln",
+ MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output",
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k",
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r",
+ MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key",
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance",
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value",
+ MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
+ MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
+ MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
+ MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+ MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+ MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
+ MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
+ MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
+ MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
+ MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
+ MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
+ MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q",
+ MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k",
+ MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v",
+ MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o",
+ MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b",
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm",
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q",
+ MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k",
+ MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v",
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o",
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b",
+ MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm",
+ MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
+ MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
+ MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
+ MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
+ MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
+ MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
+ MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v",
+ MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o",
+ MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b",
+ MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm",
+ MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
+ MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
+ MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
+ MODEL_TENSOR.CLS: "cls",
+ MODEL_TENSOR.CLS_OUT: "cls.output",
+ MODEL_TENSOR.CONV1D: "conv1d",
+ MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
+ MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm",
+ MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1",
+ MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2",
+ MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma",
+ MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1",
+ MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2",
+ MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm",
+ MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1",
+ MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2",
+ MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm",
+ MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q",
+ MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
+ MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
+ MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
+ MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv",
+ MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj",
+ MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj",
+ MODEL_TENSOR.VISEXP_ATTN_QKV: "blk.{bid}.vis_attn_qkv",
+ MODEL_TENSOR.VISEXP_ATTN_OUT: "blk.{bid}.vis_attn_output",
+ MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
+ MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
+ MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
+ # vision
+ MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
+ MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
+ MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}",
+ MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
+ MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
+ MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+ MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
+ MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
+ MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
+ MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
+ MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
+ MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
+ MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
+ MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
+ MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
+ MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
+ MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
+ MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
+ MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
+ MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
+ MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
+ MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
+ MODEL_TENSOR.V_POST_NORM: "v.post_ln",
+ MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
+ MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
+ MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
+ MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", # gemma3n
+ MODEL_TENSOR.V_MM_EMBEDDING: "mm.embedding", # gemma3n
+ MODEL_TENSOR.V_MM_HARD_EMB_NORM: "mm.hard_emb_norm", # gemma3n
+ MODEL_TENSOR.V_ENC_CONV_STEM: "v.conv_stem.conv", # gemma3n
+ MODEL_TENSOR.V_ENC_CONV_STEM_NORM: "v.conv_stem.bn", # gemma3n
+ MODEL_TENSOR.V_ENC_MSFA_EXP: "v.msfa.ffn.pw_exp.conv", # gemma3n
+ MODEL_TENSOR.V_ENC_MSFA_EXP_NORM: "v.msfa.ffn.pw_exp.bn", # gemma3n
+ MODEL_TENSOR.V_ENC_MSFA_PROJ: "v.msfa.ffn.pw_proj.conv", # gemma3n
+ MODEL_TENSOR.V_ENC_MSFA_PROJ_NORM: "v.msfa.ffn.pw_proj.bn", # gemma3n
+ MODEL_TENSOR.V_ENC_MSFA_NORM: "v.msfa.norm", # gemma3n
+ MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
+ MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
+ MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
+ MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v",
+ MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out",
+ MODEL_TENSOR.V_RESMPL_KV: "resampler.kv",
+ MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv",
+ MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post",
+ MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
+ MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
+ MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
+ MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
+ MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
+ MODEL_TENSOR.V_DS_NORM: "v.deepstack.{bid}.norm",
+ MODEL_TENSOR.V_DS_FC1: "v.deepstack.{bid}.fc1",
+ MODEL_TENSOR.V_DS_FC2: "v.deepstack.{bid}.fc2",
+ MODEL_TENSOR.V_MM_POST_FC_NORM: "mm.post_fc_norm", # cogvlm
+ MODEL_TENSOR.V_MM_UP: "mm.up",
+ MODEL_TENSOR.V_MM_DOWN: "mm.down",
+ MODEL_TENSOR.V_MM_GATE: "mm.gate",
+ MODEL_TENSOR.V_TOK_BOI: "v.boi",
+ MODEL_TENSOR.V_TOK_EOI: "v.eoi",
+ # audio (mtmd)
+ # note: all audio tensor names must use prefix "a." or "mm.a."
+ MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
+ MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
+ MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
+ MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
+ MODEL_TENSOR.A_ENC_CONV1D_NORM: "a.conv1d.{bid}.norm",
+ MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
+ MODEL_TENSOR.A_POST_NORM: "a.post_ln",
+ MODEL_TENSOR.A_ENC_LAYER_PRE_NORM: "a.blk.{bid}.layer_pre_norm",
+ MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
+ MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
+ MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
+ MODEL_TENSOR.A_ENC_PER_DIM_SCALE: "a.blk.{bid}.per_dim_scale",
+ MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
+ MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
+ MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
+ MODEL_TENSOR.A_ENC_FFN_NORM: "a.blk.{bid}.ffn_norm",
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM: "a.blk.{bid}.ffn_post_norm",
+ MODEL_TENSOR.A_ENC_FFN_SCALE: "a.blk.{bid}.ffn_scale",
+ MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
+ MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
+ MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
+ MODEL_TENSOR.A_ENC_FFN_NORM_1: "a.blk.{bid}.ffn_norm_1",
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM_1: "a.blk.{bid}.ffn_post_norm_1",
+ MODEL_TENSOR.A_ENC_FFN_SCALE_1: "a.blk.{bid}.ffn_scale_1",
+ MODEL_TENSOR.A_ENC_FFN_UP_1: "a.blk.{bid}.ffn_up_1",
+ MODEL_TENSOR.A_ENC_FFN_GATE_1: "a.blk.{bid}.ffn_gate_1",
+ MODEL_TENSOR.A_ENC_FFN_DOWN_1: "a.blk.{bid}.ffn_down_1",
+ MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
+ MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
+ MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
+ MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
+ MODEL_TENSOR.A_MM_INP_PROJ: "mm.a.input_projection", # gemma3n
+ MODEL_TENSOR.A_MM_SOFT_EMB_NORM: "mm.a.soft_emb_norm", # gemma3n
+ MODEL_TENSOR.A_MM_EMBEDDING: "mm.a.embedding", # gemma3n
+ MODEL_TENSOR.A_MM_HARD_EMB_NORM: "mm.a.hard_emb_norm", # gemma3n
+ # lfm2 audio
+ MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
+ MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
+ MODEL_TENSOR.A_ENC_POS_BIAS_U: "a.blk.{bid}.pos_bias_u",
+ MODEL_TENSOR.A_ENC_POS_BIAS_V: "a.blk.{bid}.pos_bias_v",
+ MODEL_TENSOR.A_ENC_OUT: "a.pre_encode.out",
+ MODEL_TENSOR.A_ENC_CONV_DW: "a.blk.{bid}.conv_dw",
+ MODEL_TENSOR.A_ENC_CONV_NORM: "a.blk.{bid}.conv_norm",
+ MODEL_TENSOR.A_ENC_CONV_PW1: "a.blk.{bid}.conv_pw1",
+ MODEL_TENSOR.A_ENC_CONV_PW2: "a.blk.{bid}.conv_pw2",
+ # NextN/MTP
+ MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
+ MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
+ MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
+}
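Per-layer tensor names carry a {bid} (block id) placeholder, while global tensors such as token_embd have none; a concrete name is produced with an ordinary format call, and in practice a suffix like ".weight" or ".bias" is appended to the stem. A sketch using the table above:

    print(TENSOR_NAMES[MODEL_TENSOR.ATTN_Q].format(bid=12))  # -> "blk.12.attn_q"
    print(TENSOR_NAMES[MODEL_TENSOR.TOKEN_EMBD])             # -> "token_embd"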
+
+MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+ MODEL_ARCH.MMPROJ: [
+ MODEL_TENSOR.V_MMPROJ,
+ MODEL_TENSOR.V_MMPROJ_FC,
+ MODEL_TENSOR.V_MMPROJ_MLP,
+ MODEL_TENSOR.V_MMPROJ_PEG,
+ MODEL_TENSOR.V_ENC_EMBD_CLS,
+ MODEL_TENSOR.V_ENC_EMBD_PATCH,
+ MODEL_TENSOR.V_ENC_EMBD_NORM,
+ MODEL_TENSOR.V_ENC_EMBD_POS,
+ MODEL_TENSOR.V_ENC_INPUT_NORM,
+ MODEL_TENSOR.V_ENC_ATTN_QKV,
+ MODEL_TENSOR.V_ENC_ATTN_Q,
+ MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
+ MODEL_TENSOR.V_ENC_ATTN_K,
+ MODEL_TENSOR.V_ENC_ATTN_K_NORM,
+ MODEL_TENSOR.V_ENC_ATTN_V,
+ MODEL_TENSOR.V_ENC_ATTN_O,
+ MODEL_TENSOR.V_ENC_ATTN_O_NORM,
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
+ MODEL_TENSOR.V_ENC_FFN_UP,
+ MODEL_TENSOR.V_ENC_FFN_GATE,
+ MODEL_TENSOR.V_ENC_FFN_DOWN,
+ MODEL_TENSOR.V_LAYER_SCALE_1,
+ MODEL_TENSOR.V_LAYER_SCALE_2,
+ MODEL_TENSOR.V_PRE_NORM,
+ MODEL_TENSOR.V_POST_NORM,
+ MODEL_TENSOR.V_MM_POST_NORM,
+ MODEL_TENSOR.V_MM_INP_PROJ,
+ MODEL_TENSOR.V_MM_INP_NORM,
+ MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
+ MODEL_TENSOR.V_MM_EMBEDDING,
+ MODEL_TENSOR.V_MM_HARD_EMB_NORM,
+ MODEL_TENSOR.V_ENC_CONV_STEM,
+ MODEL_TENSOR.V_ENC_CONV_STEM_NORM,
+ MODEL_TENSOR.V_ENC_MSFA_EXP,
+ MODEL_TENSOR.V_ENC_MSFA_EXP_NORM,
+ MODEL_TENSOR.V_ENC_MSFA_PROJ,
+ MODEL_TENSOR.V_ENC_MSFA_PROJ_NORM,
+ MODEL_TENSOR.V_ENC_MSFA_NORM,
+ MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
+ MODEL_TENSOR.V_RESMPL_ATTN_Q,
+ MODEL_TENSOR.V_RESMPL_ATTN_K,
+ MODEL_TENSOR.V_RESMPL_ATTN_V,
+ MODEL_TENSOR.V_RESMPL_ATTN_OUT,
+ MODEL_TENSOR.V_RESMPL_KV,
+ MODEL_TENSOR.V_RESMPL_KV_NORM,
+ MODEL_TENSOR.V_RESMPL_POST_NORM,
+ MODEL_TENSOR.V_RESMPL_Q_NORM,
+ MODEL_TENSOR.V_RESMPL_PROJ,
+ MODEL_TENSOR.V_RESMPL_QUERY,
+ MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
+ MODEL_TENSOR.V_MM_PATCH_MERGER,
+ MODEL_TENSOR.V_DS_NORM,
+ MODEL_TENSOR.V_DS_FC1,
+ MODEL_TENSOR.V_DS_FC2,
+ MODEL_TENSOR.V_MM_POST_FC_NORM,
+ MODEL_TENSOR.V_MM_UP,
+ MODEL_TENSOR.V_MM_DOWN,
+ MODEL_TENSOR.V_MM_GATE,
+ MODEL_TENSOR.V_TOK_BOI,
+ MODEL_TENSOR.V_TOK_EOI,
+ # audio
+ MODEL_TENSOR.A_ENC_EMBD_POS,
+ MODEL_TENSOR.A_ENC_EMBD_NORM,
+ MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
+ MODEL_TENSOR.A_ENC_CONV1D,
+ MODEL_TENSOR.A_ENC_CONV1D_NORM,
+ MODEL_TENSOR.A_PRE_NORM,
+ MODEL_TENSOR.A_POST_NORM,
+ MODEL_TENSOR.A_ENC_LAYER_PRE_NORM,
+ MODEL_TENSOR.A_ENC_ATTN_Q,
+ MODEL_TENSOR.A_ENC_ATTN_K,
+ MODEL_TENSOR.A_ENC_ATTN_V,
+ MODEL_TENSOR.A_ENC_PER_DIM_SCALE,
+ MODEL_TENSOR.A_ENC_INPUT_NORM,
+ MODEL_TENSOR.A_ENC_OUTPUT,
+ MODEL_TENSOR.A_ENC_OUTPUT_NORM,
+ MODEL_TENSOR.A_ENC_FFN_NORM,
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM,
+ MODEL_TENSOR.A_ENC_FFN_SCALE,
+ MODEL_TENSOR.A_ENC_FFN_UP,
+ MODEL_TENSOR.A_ENC_FFN_GATE,
+ MODEL_TENSOR.A_ENC_FFN_DOWN,
+ MODEL_TENSOR.A_ENC_FFN_NORM_1,
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM_1,
+ MODEL_TENSOR.A_ENC_FFN_SCALE_1,
+ MODEL_TENSOR.A_ENC_FFN_UP_1,
+ MODEL_TENSOR.A_ENC_FFN_GATE_1,
+ MODEL_TENSOR.A_ENC_FFN_DOWN_1,
+ MODEL_TENSOR.A_MMPROJ,
+ MODEL_TENSOR.A_MMPROJ_FC,
+ MODEL_TENSOR.A_MM_NORM_PRE,
+ MODEL_TENSOR.A_MM_NORM_MID,
+ MODEL_TENSOR.A_ENC_NORM_CONV,
+ MODEL_TENSOR.A_ENC_LINEAR_POS,
+ MODEL_TENSOR.A_ENC_POS_BIAS_U,
+ MODEL_TENSOR.A_ENC_POS_BIAS_V,
+ MODEL_TENSOR.A_ENC_OUT,
+ MODEL_TENSOR.A_ENC_CONV_DW,
+ MODEL_TENSOR.A_ENC_CONV_NORM,
+ MODEL_TENSOR.A_ENC_CONV_PW1,
+ MODEL_TENSOR.A_ENC_CONV_PW2,
+ MODEL_TENSOR.A_MM_INP_PROJ,
+ MODEL_TENSOR.A_MM_SOFT_EMB_NORM,
+ MODEL_TENSOR.A_MM_EMBEDDING,
+ MODEL_TENSOR.A_MM_HARD_EMB_NORM,
+ ],
+ MODEL_ARCH.LLAMA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.LLAMA4: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.DECI: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.GROK: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ ],
+ MODEL_ARCH.GPTNEOX: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.FALCON: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_NORM_2,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.BAICHUAN: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.STARCODER: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.BERT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.TOKEN_TYPES,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ MODEL_TENSOR.CLS,
+ MODEL_TENSOR.CLS_OUT,
+ ],
+ MODEL_ARCH.MODERN_BERT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.CLS,
+ MODEL_TENSOR.CLS_OUT,
+ ],
+ MODEL_ARCH.NOMIC_BERT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.TOKEN_TYPES,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ ],
+ MODEL_ARCH.NOMIC_BERT_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.TOKEN_TYPES,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ ],
+ MODEL_ARCH.NEO_BERT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
+ MODEL_TENSOR.CLS,
+ MODEL_TENSOR.CLS_OUT,
+ ],
+ MODEL_ARCH.JINA_BERT_V2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.TOKEN_TYPES,
+ MODEL_TENSOR.ATTN_NORM_2,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ MODEL_TENSOR.CLS,
+ ],
+ MODEL_ARCH.JINA_BERT_V3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.TOKEN_TYPES,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ ],
+ MODEL_ARCH.MPT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_ACT,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.POS_EMBD,
+ ],
+ MODEL_ARCH.GPTJ: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.REFACT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.BLOOM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.STABLELM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ ],
+ MODEL_ARCH.QWEN: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.QWEN2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.DREAM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.LLADA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.QWEN2VL: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.QWEN2MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.QWEN3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.QWEN3MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.QWEN3NEXT: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_GATE,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_BETA_ALPHA,
+        MODEL_TENSOR.SSM_OUT,
+ ],
+ MODEL_ARCH.QWEN3VL: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.QWEN3VLMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.QWEN35: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_GATE,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_BETA,
+ MODEL_TENSOR.SSM_ALPHA,
+        MODEL_TENSOR.SSM_OUT,
+ ],
+ MODEL_ARCH.QWEN35MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_GATE,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_BETA,
+ MODEL_TENSOR.SSM_ALPHA,
+        MODEL_TENSOR.SSM_OUT,
+ ],
+ MODEL_ARCH.PLAMO: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.PLAMO2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_X,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.SSM_DT_NORM,
+ MODEL_TENSOR.SSM_B_NORM,
+ MODEL_TENSOR.SSM_C_NORM,
+ ],
+ MODEL_ARCH.PLAMO3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.GPT2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.PHI2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.PHI3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.PHIMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.CODESHELL: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.POS_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.ORION: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.INTERNLM2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.MINICPM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.MINICPM3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FACTORS_LONG,
+ MODEL_TENSOR.ROPE_FACTORS_SHORT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q_A,
+ MODEL_TENSOR.ATTN_Q_B,
+ MODEL_TENSOR.ATTN_KV_A_MQA,
+ MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_Q_A_NORM,
+ MODEL_TENSOR.ATTN_KV_A_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.GEMMA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_NORM,
+ ],
+ MODEL_ARCH.GEMMA2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.GEMMA3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.GEMMA3N: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ # altup / laurel
+ MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
+ MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
+ MODEL_TENSOR.PER_LAYER_INP_GATE,
+ MODEL_TENSOR.PER_LAYER_PROJ,
+ MODEL_TENSOR.PER_LAYER_PROJ_NORM,
+ MODEL_TENSOR.PER_LAYER_POST_NORM,
+ MODEL_TENSOR.ALTUP_PROJ,
+ MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
+ MODEL_TENSOR.ALTUP_CORRECT_COEF,
+ MODEL_TENSOR.ALTUP_CORRECT_SCALE,
+ MODEL_TENSOR.ALTUP_PREDICT_COEF,
+ MODEL_TENSOR.ALTUP_ROUTER,
+ MODEL_TENSOR.ALTUP_ROUTER_NORM,
+ MODEL_TENSOR.LAUREL_L,
+ MODEL_TENSOR.LAUREL_R,
+ MODEL_TENSOR.LAUREL_POST_NORM,
+ ],
+ MODEL_ARCH.GEMMA_EMBEDDING: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.DENSE_2_OUT,
+ MODEL_TENSOR.DENSE_3_OUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.STARCODER2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.RWKV6: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_NORM_2,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_LERP_X,
+ MODEL_TENSOR.TIME_MIX_LERP_K,
+ MODEL_TENSOR.TIME_MIX_LERP_V,
+ MODEL_TENSOR.TIME_MIX_LERP_R,
+ MODEL_TENSOR.TIME_MIX_LERP_G,
+ MODEL_TENSOR.TIME_MIX_LERP_W,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+ MODEL_TENSOR.TIME_MIX_FIRST,
+ MODEL_TENSOR.TIME_MIX_DECAY,
+ MODEL_TENSOR.TIME_MIX_DECAY_W1,
+ MODEL_TENSOR.TIME_MIX_DECAY_W2,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_GATE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R,
+ MODEL_TENSOR.CHANNEL_MIX_KEY,
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
+ MODEL_TENSOR.CHANNEL_MIX_VALUE,
+ ],
+ MODEL_ARCH.RWKV6QWEN2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_LERP_X,
+ MODEL_TENSOR.TIME_MIX_LERP_K,
+ MODEL_TENSOR.TIME_MIX_LERP_V,
+ MODEL_TENSOR.TIME_MIX_LERP_R,
+ MODEL_TENSOR.TIME_MIX_LERP_G,
+ MODEL_TENSOR.TIME_MIX_LERP_W,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+ MODEL_TENSOR.TIME_MIX_FIRST,
+ MODEL_TENSOR.TIME_MIX_DECAY,
+ MODEL_TENSOR.TIME_MIX_DECAY_W1,
+ MODEL_TENSOR.TIME_MIX_DECAY_W2,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_GATE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.RWKV7: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_NORM_2,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+ MODEL_TENSOR.TIME_MIX_W0,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_A0,
+ MODEL_TENSOR.TIME_MIX_A1,
+ MODEL_TENSOR.TIME_MIX_A2,
+ MODEL_TENSOR.TIME_MIX_V0,
+ MODEL_TENSOR.TIME_MIX_V1,
+ MODEL_TENSOR.TIME_MIX_V2,
+ MODEL_TENSOR.TIME_MIX_G1,
+ MODEL_TENSOR.TIME_MIX_G2,
+ MODEL_TENSOR.TIME_MIX_K_K,
+ MODEL_TENSOR.TIME_MIX_K_A,
+ MODEL_TENSOR.TIME_MIX_R_K,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K,
+ MODEL_TENSOR.CHANNEL_MIX_KEY,
+ MODEL_TENSOR.CHANNEL_MIX_VALUE,
+ ],
+ MODEL_ARCH.ARWKV7: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.TIME_MIX_LERP_FUSED,
+ MODEL_TENSOR.TIME_MIX_W0,
+ MODEL_TENSOR.TIME_MIX_W1,
+ MODEL_TENSOR.TIME_MIX_W2,
+ MODEL_TENSOR.TIME_MIX_A0,
+ MODEL_TENSOR.TIME_MIX_A1,
+ MODEL_TENSOR.TIME_MIX_A2,
+ MODEL_TENSOR.TIME_MIX_V0,
+ MODEL_TENSOR.TIME_MIX_V1,
+ MODEL_TENSOR.TIME_MIX_V2,
+ MODEL_TENSOR.TIME_MIX_G1,
+ MODEL_TENSOR.TIME_MIX_G2,
+ MODEL_TENSOR.TIME_MIX_K_K,
+ MODEL_TENSOR.TIME_MIX_K_A,
+ MODEL_TENSOR.TIME_MIX_R_K,
+ MODEL_TENSOR.TIME_MIX_KEY,
+ MODEL_TENSOR.TIME_MIX_VALUE,
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE,
+ MODEL_TENSOR.TIME_MIX_LN,
+ MODEL_TENSOR.TIME_MIX_OUTPUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.MAMBA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_X,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_OUT,
+ ],
+ MODEL_ARCH.MAMBA2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ ],
+ MODEL_ARCH.JAMBA: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_X,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_DT_NORM,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_B_NORM,
+ MODEL_TENSOR.SSM_C_NORM,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.XVERSE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.COMMAND_R: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ ],
+ MODEL_ARCH.COHERE2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.DBRX: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.OLMO: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.OLMO2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.SEED_OSS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ ],
+ MODEL_ARCH.OLMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ ],
+ MODEL_ARCH.OPENELM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.ARCTIC: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_NORM_EXP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.DEEPSEEK: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.DEEPSEEK2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_A,
+ MODEL_TENSOR.ATTN_Q_B,
+ MODEL_TENSOR.ATTN_KV_A_MQA,
+ MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_K_B,
+ MODEL_TENSOR.ATTN_V_B,
+ MODEL_TENSOR.ATTN_Q_A_NORM,
+ MODEL_TENSOR.ATTN_KV_A_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.ERNIE4_5_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.PLM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_KV_A_MQA,
+ MODEL_TENSOR.ATTN_KV_A_NORM,
+ MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_DOWN,
+ ],
+    MODEL_ARCH.CHATGLM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+    MODEL_ARCH.GLM4: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.GLM4_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ # NextN/MTP tensors - preserved but unused
+ MODEL_TENSOR.NEXTN_EH_PROJ,
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+ MODEL_TENSOR.NEXTN_ENORM,
+ MODEL_TENSOR.NEXTN_HNORM,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+ ],
+ MODEL_ARCH.BITNET: [
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_SUB_NORM,
+ MODEL_TENSOR.FFN_SUB_NORM,
+ ],
+ MODEL_ARCH.T5: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.DEC_ATTN_NORM,
+ MODEL_TENSOR.DEC_ATTN_Q,
+ MODEL_TENSOR.DEC_ATTN_K,
+ MODEL_TENSOR.DEC_ATTN_V,
+ MODEL_TENSOR.DEC_ATTN_OUT,
+ MODEL_TENSOR.DEC_ATTN_REL_B,
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM,
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q,
+ MODEL_TENSOR.DEC_CROSS_ATTN_K,
+ MODEL_TENSOR.DEC_CROSS_ATTN_V,
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT,
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B,
+ MODEL_TENSOR.DEC_FFN_NORM,
+ MODEL_TENSOR.DEC_FFN_GATE,
+ MODEL_TENSOR.DEC_FFN_DOWN,
+ MODEL_TENSOR.DEC_FFN_UP,
+ MODEL_TENSOR.DEC_OUTPUT_NORM,
+ MODEL_TENSOR.ENC_ATTN_NORM,
+ MODEL_TENSOR.ENC_ATTN_Q,
+ MODEL_TENSOR.ENC_ATTN_K,
+ MODEL_TENSOR.ENC_ATTN_V,
+ MODEL_TENSOR.ENC_ATTN_OUT,
+ MODEL_TENSOR.ENC_ATTN_REL_B,
+ MODEL_TENSOR.ENC_FFN_NORM,
+ MODEL_TENSOR.ENC_FFN_GATE,
+ MODEL_TENSOR.ENC_FFN_DOWN,
+ MODEL_TENSOR.ENC_FFN_UP,
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
+ ],
+ MODEL_ARCH.T5ENCODER: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ENC_ATTN_NORM,
+ MODEL_TENSOR.ENC_ATTN_Q,
+ MODEL_TENSOR.ENC_ATTN_K,
+ MODEL_TENSOR.ENC_ATTN_V,
+ MODEL_TENSOR.ENC_ATTN_OUT,
+ MODEL_TENSOR.ENC_ATTN_REL_B,
+ MODEL_TENSOR.ENC_FFN_NORM,
+ MODEL_TENSOR.ENC_FFN_GATE,
+ MODEL_TENSOR.ENC_FFN_DOWN,
+ MODEL_TENSOR.ENC_FFN_UP,
+ MODEL_TENSOR.ENC_OUTPUT_NORM,
+ ],
+ MODEL_ARCH.JAIS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.NEMOTRON: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.NEMOTRON_H: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.NEMOTRON_H_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ # experts
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ # shared expert
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.EXAONE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.EXAONE4: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
+ MODEL_ARCH.EXAONE_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ # NextN/MTP tensors - preserved but unused
+ MODEL_TENSOR.NEXTN_EH_PROJ,
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+ MODEL_TENSOR.NEXTN_ENORM,
+ MODEL_TENSOR.NEXTN_HNORM,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+ ],
+ MODEL_ARCH.GRANITE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.GRANITE_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ ],
+ MODEL_ARCH.GRANITE_HYBRID: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ # MoE
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ # Dense
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.CHAMELEON: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.WAVTOKENIZER_DEC: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.CONV1D,
+ MODEL_TENSOR.CONVNEXT_DW,
+ MODEL_TENSOR.CONVNEXT_NORM,
+ MODEL_TENSOR.CONVNEXT_PW1,
+ MODEL_TENSOR.CONVNEXT_PW2,
+ MODEL_TENSOR.CONVNEXT_GAMMA,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.POSNET_CONV1,
+ MODEL_TENSOR.POSNET_CONV2,
+ MODEL_TENSOR.POSNET_NORM,
+ MODEL_TENSOR.POSNET_NORM1,
+ MODEL_TENSOR.POSNET_NORM2,
+ MODEL_TENSOR.POSNET_ATTN_NORM,
+ MODEL_TENSOR.POSNET_ATTN_Q,
+ MODEL_TENSOR.POSNET_ATTN_K,
+ MODEL_TENSOR.POSNET_ATTN_V,
+ MODEL_TENSOR.POSNET_ATTN_OUT,
+ ],
+ MODEL_ARCH.BAILINGMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.BAILINGMOE2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.NEXTN_EH_PROJ,
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+ MODEL_TENSOR.NEXTN_ENORM,
+ MODEL_TENSOR.NEXTN_HNORM,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+ MODEL_TENSOR.LAYER_OUT_NORM,
+ ],
+ MODEL_ARCH.DOTS1: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.ARCEE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.AFMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_GATE,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_PRE_NORM,
+ MODEL_TENSOR.FFN_POST_NORM,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.ERNIE4_5: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.FALCON_H1: [
+ # Token embedding
+ MODEL_TENSOR.TOKEN_EMBD,
+
+ # Input layernorm
+ MODEL_TENSOR.ATTN_NORM,
+
+ # Attention components
+ MODEL_TENSOR.ATTN_Q, # Query projection
+ MODEL_TENSOR.ATTN_K, # Key projection
+ MODEL_TENSOR.ATTN_V, # Value projection
+ MODEL_TENSOR.ATTN_OUT, # Output projection
+
+ # SSM components (Mamba2 specific)
+ MODEL_TENSOR.SSM_IN, # Input projection for SSM
+ MODEL_TENSOR.SSM_CONV1D, # Convolution layer
+ MODEL_TENSOR.SSM_DT, # Delta time projection
+ MODEL_TENSOR.SSM_A, # A parameter (log form)
+ MODEL_TENSOR.SSM_D, # D parameter
+ MODEL_TENSOR.SSM_NORM, # Normalization in SSM
+ MODEL_TENSOR.SSM_OUT, # Output projection
+
+ # Pre-feedforward layernorm
+ MODEL_TENSOR.FFN_PRE_NORM,
+
+ # Feed-forward network components
+ MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
+ MODEL_TENSOR.FFN_DOWN, # Down projection
+ MODEL_TENSOR.FFN_UP, # Up projection
+
+ # Post-feedforward layernorm
+ MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
+ MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
+ ],
+ MODEL_ARCH.HUNYUAN_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ MODEL_ARCH.HUNYUAN_DENSE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.SMOLLM3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.GPT_OSS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_SINKS,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.LFM2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.SHORTCONV_CONV,
+ MODEL_TENSOR.SHORTCONV_INPROJ,
+ MODEL_TENSOR.SHORTCONV_OUTPROJ,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.ATTN_NORM, # operator_norm
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.DENSE_2_OUT, # LFM2-ColBert-350M
+ ],
+ MODEL_ARCH.LFM2MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.TOKEN_EMBD_NORM,
+ MODEL_TENSOR.SHORTCONV_CONV,
+ MODEL_TENSOR.SHORTCONV_INPROJ,
+ MODEL_TENSOR.SHORTCONV_OUTPROJ,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.ATTN_NORM, # operator_norm
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.SMALLTHINKER: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.APERTUS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.LLADA_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ ],
+ MODEL_ARCH.GROVEMOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_CHEXP,
+ MODEL_TENSOR.FFN_DOWN_CHEXP,
+ MODEL_TENSOR.FFN_UP_CHEXP,
+ ],
+ MODEL_ARCH.MINIMAXM2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.COGVLM: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.VISEXP_ATTN_QKV,
+ MODEL_TENSOR.VISEXP_ATTN_OUT,
+ MODEL_TENSOR.VISEXP_GATE,
+ MODEL_TENSOR.VISEXP_UP,
+ MODEL_TENSOR.VISEXP_DOWN,
+ ],
+ MODEL_ARCH.RND1: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.PANGU_EMBED: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.MISTRAL3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.MIMO2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_SINKS,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.STEP35: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_GATE,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ ],
+ MODEL_ARCH.LLAMA_EMBED: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
+ MODEL_ARCH.MAINCODER: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
+ MODEL_ARCH.KIMI_LINEAR: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_Q_A,
+ MODEL_TENSOR.ATTN_Q_B,
+ MODEL_TENSOR.ATTN_KV_A_MQA,
+ MODEL_TENSOR.ATTN_KV_B,
+ MODEL_TENSOR.ATTN_K_B,
+ MODEL_TENSOR.ATTN_V_B,
+ MODEL_TENSOR.ATTN_Q_A_NORM,
+ MODEL_TENSOR.ATTN_KV_A_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.SSM_CONV1D_Q,
+ MODEL_TENSOR.SSM_CONV1D_K,
+ MODEL_TENSOR.SSM_CONV1D_V,
+ MODEL_TENSOR.SSM_F_A,
+ MODEL_TENSOR.SSM_F_B,
+ MODEL_TENSOR.SSM_BETA,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_G_A,
+ MODEL_TENSOR.SSM_G_B,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.FFN_EXP_PROBS_B,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
+ # TODO
+}
+
+# tensors that will not be serialized
+MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+ MODEL_ARCH.LLAMA: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.DECI: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.BAICHUAN: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.QWEN: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.CODESHELL: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.ORION: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.STARCODER2: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.XVERSE: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.DEEPSEEK: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.DEEPSEEK2: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.CHATGLM: [
+ MODEL_TENSOR.ROPE_FREQS,
+ ],
+ MODEL_ARCH.NEMOTRON: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+ MODEL_ARCH.BAILINGMOE: [
+ MODEL_TENSOR.ROPE_FREQS,
+ ],
+ MODEL_ARCH.PANGU_EMBED: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
+}
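+
+# Illustrative sketch, not part of the upstream API: a conversion script could
+# consult MODEL_TENSOR_SKIP like this when deciding whether to write a tensor.
+# The helper name _should_serialize is hypothetical.
+def _should_serialize(arch: MODEL_ARCH, tensor: MODEL_TENSOR) -> bool:
+    # skipped tensors (e.g. RoPE frequencies) are recomputed at load time,
+    # so they are omitted from the serialized GGUF file
+    return tensor not in MODEL_TENSOR_SKIP.get(arch, [])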
+
+#
+# types
+#
+
+
+class TokenType(IntEnum):
+ NORMAL = 1
+ UNKNOWN = 2
+ CONTROL = 3
+ USER_DEFINED = 4
+ UNUSED = 5
+ BYTE = 6
+
+
+class RopeScalingType(Enum):
+ NONE = 'none'
+ LINEAR = 'linear'
+ YARN = 'yarn'
+ LONGROPE = 'longrope'
+
+
+class PoolingType(IntEnum):
+ NONE = 0
+ MEAN = 1
+ CLS = 2
+ LAST = 3
+ RANK = 4
+
+
+class GGMLQuantizationType(IntEnum):
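+    # note: gaps in the numbering (e.g. 4-5, 31-33, 36-38) belong to quant
+    # types that were removed from ggml; the values kept here are expected to
+    # match the corresponding GGML_TYPE_* values in ggml.h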
+ F32 = 0
+ F16 = 1
+ Q4_0 = 2
+ Q4_1 = 3
+ Q5_0 = 6
+ Q5_1 = 7
+ Q8_0 = 8
+ Q8_1 = 9
+ Q2_K = 10
+ Q3_K = 11
+ Q4_K = 12
+ Q5_K = 13
+ Q6_K = 14
+ Q8_K = 15
+ IQ2_XXS = 16
+ IQ2_XS = 17
+ IQ3_XXS = 18
+ IQ1_S = 19
+ IQ4_NL = 20
+ IQ3_S = 21
+ IQ2_S = 22
+ IQ4_XS = 23
+ I8 = 24
+ I16 = 25
+ I32 = 26
+ I64 = 27
+ F64 = 28
+ IQ1_M = 29
+ BF16 = 30
+ TQ1_0 = 34
+ TQ2_0 = 35
+ MXFP4 = 39
+
+
+class ExpertGatingFuncType(IntEnum):
+ SOFTMAX = 1
+ SIGMOID = 2
+
+
+# TODO: add GGMLFileType from ggml_ftype in ggml.h
+
+
+# from llama_ftype in llama.h
+# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
+class LlamaFileType(IntEnum):
+ ALL_F32 = 0
+ MOSTLY_F16 = 1 # except 1d tensors
+ MOSTLY_Q4_0 = 2 # except 1d tensors
+ MOSTLY_Q4_1 = 3 # except 1d tensors
+ # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
+ # MOSTLY_Q4_2 = 5 # support has been removed
+ # MOSTLY_Q4_3 = 6 # support has been removed
+ MOSTLY_Q8_0 = 7 # except 1d tensors
+ MOSTLY_Q5_0 = 8 # except 1d tensors
+ MOSTLY_Q5_1 = 9 # except 1d tensors
+ MOSTLY_Q2_K = 10 # except 1d tensors
+ MOSTLY_Q3_K_S = 11 # except 1d tensors
+ MOSTLY_Q3_K_M = 12 # except 1d tensors
+ MOSTLY_Q3_K_L = 13 # except 1d tensors
+ MOSTLY_Q4_K_S = 14 # except 1d tensors
+ MOSTLY_Q4_K_M = 15 # except 1d tensors
+ MOSTLY_Q5_K_S = 16 # except 1d tensors
+ MOSTLY_Q5_K_M = 17 # except 1d tensors
+ MOSTLY_Q6_K = 18 # except 1d tensors
+ MOSTLY_IQ2_XXS = 19 # except 1d tensors
+ MOSTLY_IQ2_XS = 20 # except 1d tensors
+ MOSTLY_Q2_K_S = 21 # except 1d tensors
+ MOSTLY_IQ3_XS = 22 # except 1d tensors
+ MOSTLY_IQ3_XXS = 23 # except 1d tensors
+ MOSTLY_IQ1_S = 24 # except 1d tensors
+ MOSTLY_IQ4_NL = 25 # except 1d tensors
+ MOSTLY_IQ3_S = 26 # except 1d tensors
+ MOSTLY_IQ3_M = 27 # except 1d tensors
+ MOSTLY_IQ2_S = 28 # except 1d tensors
+ MOSTLY_IQ2_M = 29 # except 1d tensors
+ MOSTLY_IQ4_XS = 30 # except 1d tensors
+ MOSTLY_IQ1_M = 31 # except 1d tensors
+ MOSTLY_BF16 = 32 # except 1d tensors
+ # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack
+ # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack
+ # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
+ MOSTLY_TQ1_0 = 36 # except 1d tensors
+ MOSTLY_TQ2_0 = 37 # except 1d tensors
+
+ GUESSED = 1024 # not specified in the model file
+
+
+class GGUFEndian(IntEnum):
+ LITTLE = 0
+ BIG = 1
+
+
+class GGUFValueType(IntEnum):
+ UINT8 = 0
+ INT8 = 1
+ UINT16 = 2
+ INT16 = 3
+ UINT32 = 4
+ INT32 = 5
+ FLOAT32 = 6
+ BOOL = 7
+ STRING = 8
+ ARRAY = 9
+ UINT64 = 10
+ INT64 = 11
+ FLOAT64 = 12
+
+ @staticmethod
+ def get_type(val: Any) -> GGUFValueType:
+ if isinstance(val, (str, bytes, bytearray)):
+ return GGUFValueType.STRING
+ elif isinstance(val, list):
+ return GGUFValueType.ARRAY
+ elif isinstance(val, float):
+ return GGUFValueType.FLOAT32
+ elif isinstance(val, bool):
+ return GGUFValueType.BOOL
+ elif isinstance(val, int):
+ return GGUFValueType.INT32
+ # TODO: need help with 64-bit types in Python
+ else:
+ raise ValueError(f"Unknown type: {type(val)}")
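+
+
+# A small illustrative check, not upstream code: get_type() maps native Python
+# values to GGUF value types. bool is tested before int because bool is a
+# subclass of int in Python.
+def _demo_get_type() -> None:
+    assert GGUFValueType.get_type("llama") is GGUFValueType.STRING
+    assert GGUFValueType.get_type([1, 2]) is GGUFValueType.ARRAY
+    assert GGUFValueType.get_type(0.5) is GGUFValueType.FLOAT32
+    assert GGUFValueType.get_type(True) is GGUFValueType.BOOL
+    assert GGUFValueType.get_type(7) is GGUFValueType.INT32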
+
+
+class VisionProjectorType:
+ GEMMA3 = "gemma3"
+ GEMMA3NV = "gemma3nv"
+ GEMMA3NA = "gemma3na"
+ IDEFICS3 = "idefics3"
+ PIXTRAL = "pixtral"
+ LLAMA4 = "llama4"
+ QWEN2VL = "qwen2vl_merger"
+ QWEN25VL = "qwen2.5vl_merger"
+ QWEN3VL = "qwen3vl_merger"
+ ULTRAVOX = "ultravox"
+ INTERNVL = "internvl"
+ QWEN2A = "qwen2a" # audio
+ GLMA = "glma" # audio
+ QWEN25O = "qwen2.5o" # omni
+ VOXTRAL = "voxtral"
+ LFM2 = "lfm2"
+ KIMIVL = "kimivl"
+ KIMIK25 = "kimik25"
+ LIGHTONOCR = "lightonocr"
+ COGVLM = "cogvlm"
+ JANUS_PRO = "janus_pro"
+ LFM2A = "lfm2a" # audio
+ MUSIC_FLAMINGO = "musicflamingo" # audio
+ GLM4V = "glm4v"
+ YOUTUVL = "youtuvl"
+
+
+# Items here are (block size, type size)
+QK_K = 256
+GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
+ GGMLQuantizationType.F32: (1, 4),
+ GGMLQuantizationType.F16: (1, 2),
+ GGMLQuantizationType.Q4_0: (32, 2 + 16),
+ GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
+ GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
+ GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
+ GGMLQuantizationType.Q8_0: (32, 2 + 32),
+ GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
+ GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
+ GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
+ GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
+ GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
+ GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
+ GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
+ GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
+ GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
+ GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),
+ GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
+ GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
+ GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
+ GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
+ GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
+ GGMLQuantizationType.I8: (1, 1),
+ GGMLQuantizationType.I16: (1, 2),
+ GGMLQuantizationType.I32: (1, 4),
+ GGMLQuantizationType.I64: (1, 8),
+ GGMLQuantizationType.F64: (1, 8),
+ GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
+ GGMLQuantizationType.BF16: (1, 2),
+ GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
+ GGMLQuantizationType.TQ2_0: (256, 2 + 64),
+ GGMLQuantizationType.MXFP4: (32, 1 + 16),
+}
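+
+# Illustrative sketch, not upstream code: with the (block size, type size)
+# pairs above, the byte size of quantized data whose element count is a
+# multiple of the block size can be computed as follows. The helper name
+# _quant_byte_size is hypothetical.
+def _quant_byte_size(qtype: GGMLQuantizationType, n_elements: int) -> int:
+    block_size, type_size = GGML_QUANT_SIZES[qtype]
+    assert n_elements % block_size == 0, "element count must be a multiple of the block size"
+    return n_elements // block_size * type_size
+
+# e.g. a 4096 x 4096 matrix in Q4_0: 4096 * 4096 // 32 * 18 = 9437184 bytes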
+
+
+# Aliases for backward compatibility.
+
+# general
+KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT
+KEY_GENERAL_NAME = Keys.General.NAME
+KEY_GENERAL_AUTHOR = Keys.General.AUTHOR
+KEY_GENERAL_URL = Keys.General.URL
+KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION
+KEY_GENERAL_LICENSE = Keys.General.LICENSE
+KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
+KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
+
+# LLM
+KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE
+KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT
+
+# attention
+KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
+
+# RoPE
+KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
+KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
+
+# SSM
+KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
+KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
+KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
+KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
+KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
+
+# KDA
+KEY_KDA_HEAD_DIM = Keys.KDA.HEAD_DIM
+
+# tokenization
+KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
+KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
+KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
+KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
+KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES
+KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
+KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
+KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
+KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
+KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
+KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
+KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
+KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
+KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
+KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
+KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
+
+KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID
+KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID
+KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID
+KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID
+KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID
+KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID
+
+# deprecated
+KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID
+KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
+KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
diff --git a/llama.cpp/gguf-py/gguf/gguf.py b/llama.cpp/gguf-py/gguf/gguf.py
new file mode 100644
index 0000000..651a81e
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/gguf.py
@@ -0,0 +1,15 @@
+# This file is left for compatibility. If you want to use the GGUF API from
+# Python, don't import gguf/gguf.py directly; import the gguf package instead.
+# If you're looking for examples, see the examples/ directory for gguf-py.
+
+import importlib
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+# Compatibility for people trying to import gguf/gguf.py directly instead of as a package.
+importlib.invalidate_caches()
+import gguf # noqa: E402
+
+importlib.reload(gguf)
diff --git a/llama.cpp/gguf-py/gguf/gguf_reader.py b/llama.cpp/gguf-py/gguf/gguf_reader.py
new file mode 100644
index 0000000..d87e8f7
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/gguf_reader.py
@@ -0,0 +1,367 @@
+#
+# GGUF file reading/modification support. For API usage information,
+# please see the files in the scripts/ directory for some fairly simple examples.
+#
+from __future__ import annotations
+
+import logging
+import os
+import sys
+from collections import OrderedDict
+from typing import Any, Literal, NamedTuple, TypeVar, Union
+
+import numpy as np
+import numpy.typing as npt
+
+from .quants import quant_shape_to_byte_shape
+
+if __name__ == "__main__":
+ from pathlib import Path
+
+ # Allow running file in package as a script.
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf.constants import (
+ GGML_QUANT_SIZES,
+ GGUF_DEFAULT_ALIGNMENT,
+ GGUF_MAGIC,
+ GGUF_VERSION,
+ GGMLQuantizationType,
+ GGUFValueType,
+ GGUFEndian,
+)
+
+logger = logging.getLogger(__name__)
+
+READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
+
+
+class ReaderField(NamedTuple):
+ # Offset to start of this field.
+ offset: int
+
+ # Name of the field (not necessarily from file data).
+ name: str
+
+ # Data parts. Some types have multiple components, such as strings
+ # that consist of a length followed by the string data.
+ parts: list[npt.NDArray[Any]] = []
+
+ # Indexes into parts that we can call the actual data. For example
+ # an array of strings will be populated with indexes to the actual
+ # string data.
+ data: list[int] = [-1]
+
+ types: list[GGUFValueType] = []
+
+ def contents(self, index_or_slice: int | slice = slice(None)) -> Any:
+ if self.types:
+ to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731
+ main_type = self.types[0]
+
+ if main_type == GGUFValueType.ARRAY:
+ sub_type = self.types[-1]
+
+ if sub_type == GGUFValueType.STRING:
+ indices = self.data[index_or_slice]
+
+ if isinstance(index_or_slice, int):
+ return to_string(self.parts[indices]) # type: ignore
+ else:
+ return [to_string(self.parts[idx]) for idx in indices] # type: ignore
+ else:
+ # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too
+
+ # Check if it's unsafe to perform slice optimization on data
+ # if any(True for idx in self.data if len(self.parts[idx]) != 1):
+ # optim_slice = slice(None)
+ # else:
+ # optim_slice = index_or_slice
+ # index_or_slice = slice(None)
+
+ # if isinstance(optim_slice, int):
+ # return self.parts[self.data[optim_slice]].tolist()[0]
+ # else:
+ # return [pv for idx in self.data[optim_slice] for pv in self.parts[idx].tolist()][index_or_slice]
+
+ if isinstance(index_or_slice, int):
+ return self.parts[self.data[index_or_slice]].tolist()[0]
+ else:
+ return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()]
+
+ if main_type == GGUFValueType.STRING:
+ return to_string(self.parts[-1])
+ else:
+ return self.parts[-1].tolist()[0]
+
+ return None
+
+
+class ReaderTensor(NamedTuple):
+ name: str
+ tensor_type: GGMLQuantizationType
+ shape: npt.NDArray[np.uint32]
+ n_elements: int
+ n_bytes: int
+ data_offset: int
+ data: npt.NDArray[Any]
+ field: ReaderField
+
+
+class GGUFReader:
+ # I - same as host, S - swapped
+ byte_order: Literal['I', 'S'] = 'I'
+ alignment: int = GGUF_DEFAULT_ALIGNMENT
+ data_offset: int
+
+ # Note: Internal helper, API may change.
+ gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
+ GGUFValueType.UINT8: np.uint8,
+ GGUFValueType.INT8: np.int8,
+ GGUFValueType.UINT16: np.uint16,
+ GGUFValueType.INT16: np.int16,
+ GGUFValueType.UINT32: np.uint32,
+ GGUFValueType.INT32: np.int32,
+ GGUFValueType.FLOAT32: np.float32,
+ GGUFValueType.UINT64: np.uint64,
+ GGUFValueType.INT64: np.int64,
+ GGUFValueType.FLOAT64: np.float64,
+ GGUFValueType.BOOL: np.bool_,
+ }
+
+ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
+ self.data = np.memmap(path, mode = mode)
+ offs = 0
+
+ # Check for GGUF magic
+ if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
+ raise ValueError('GGUF magic invalid')
+ offs += 4
+
+ # Check GGUF version
+ temp_version = self._get(offs, np.uint32)
+ if temp_version[0] & 65535 == 0:
+ # If we get 0 here that means it's (probably) a GGUF file created for
+ # the opposite byte order of the machine this script is running on.
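+            # (GGUF versions are small, so the low 16 bits of a byteswapped
+            # version are always zero: e.g. version 3 read with the wrong byte
+            # order appears as 0x03000000, and 0x03000000 & 65535 == 0.)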
+ self.byte_order = 'S'
+ temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order))
+ version = temp_version[0]
+ if version not in READER_SUPPORTED_VERSIONS:
+            raise ValueError(f'Sorry, file appears to be version {version}, which we cannot handle')
+ if sys.byteorder == "little":
+ # Host is little endian
+ host_endian = GGUFEndian.LITTLE
+ swapped_endian = GGUFEndian.BIG
+ else:
+ # Sorry PDP or other weird systems that don't use BE or LE.
+ host_endian = GGUFEndian.BIG
+ swapped_endian = GGUFEndian.LITTLE
+ self.endianess = swapped_endian if self.byte_order == "S" else host_endian
+ self.fields: OrderedDict[str, ReaderField] = OrderedDict()
+ self.tensors: list[ReaderTensor] = []
+ offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
+
+ # Check tensor count and kv count
+ temp_counts = self._get(offs, np.uint64, 2)
+ offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
+ offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
+ tensor_count, kv_count = temp_counts
+ offs = self._build_fields(offs, kv_count)
+
+ # Build Tensor Info Fields
+ offs, tensors_fields = self._build_tensor_info(offs, tensor_count)
+ new_align = self.fields.get('general.alignment')
+ if new_align is not None:
+ if new_align.types != [GGUFValueType.UINT32]:
+ raise ValueError('Bad type for general.alignment field')
+ self.alignment = new_align.parts[-1][0]
+ padding = offs % self.alignment
+ if padding != 0:
+ offs += self.alignment - padding
+ self.data_offset = offs
+ self._build_tensors(offs, tensors_fields)
+
+ _DT = TypeVar('_DT', bound = npt.DTypeLike)
+
+ # Fetch a key/value metadata field by key.
+ def get_field(self, key: str) -> Union[ReaderField, None]:
+ return self.fields.get(key, None)
+
+ # Fetch a tensor from the list by index.
+ def get_tensor(self, idx: int) -> ReaderTensor:
+ return self.tensors[idx]
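+
+    # A minimal usage sketch (the file name and field value are illustrative):
+    #
+    #   reader = GGUFReader("model.gguf")
+    #   arch = reader.get_field("general.architecture")
+    #   if arch is not None:
+    #       print(arch.contents())  # e.g. "llama"
+    #   for tensor in reader.tensors:
+    #       print(tensor.name, tensor.shape, tensor.tensor_type.name)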
+
+ def _get(
+ self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
+ ) -> npt.NDArray[Any]:
+ count = int(count)
+ itemsize = int(np.empty([], dtype = dtype).itemsize)
+ end_offs = offset + itemsize * count
+ arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
+ return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order))
+
+ def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
+ if field.name in self.fields:
+ # TODO: add option to generate error on duplicate keys
+ # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
+
+ logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
+ self.fields[field.name + '_{}'.format(field.offset)] = field
+ else:
+ self.fields[field.name] = field
+ return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
+
+ def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
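+        # GGUF strings are encoded as a uint64 length followed by that many
+        # UTF-8 bytes; both the length and the raw bytes are returned.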
+ slen = self._get(offset, np.uint64)
+ return slen, self._get(offset + 8, np.uint8, slen[0])
+
+ def _get_field_parts(
+ self, orig_offs: int, raw_type: int,
+ ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
+ offs = orig_offs
+ types: list[GGUFValueType] = []
+ gtype = GGUFValueType(raw_type)
+ types.append(gtype)
+ # Handle strings.
+ if gtype == GGUFValueType.STRING:
+ sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
+ size = sum(int(part.nbytes) for part in sparts)
+ return size, sparts, [1], types
+ # Check if it's a simple scalar type.
+ nptype = self.gguf_scalar_to_np.get(gtype)
+ if nptype is not None:
+ val = self._get(offs, nptype)
+ return int(val.nbytes), [val], [0], types
+ # Handle arrays.
+ if gtype == GGUFValueType.ARRAY:
+ raw_itype = self._get(offs, np.uint32)
+ offs += int(raw_itype.nbytes)
+ alen = self._get(offs, np.uint64)
+ offs += int(alen.nbytes)
+ aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
+ data_idxs: list[int] = []
+ # FIXME: Handle multi-dimensional arrays properly instead of flattening
+ for idx in range(alen[0]):
+ curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
+ if idx == 0:
+ types += curr_types
+ idxs_offs = len(aparts)
+ aparts += curr_parts
+ data_idxs += (idx + idxs_offs for idx in curr_idxs)
+ offs += curr_size
+ return offs - orig_offs, aparts, data_idxs, types
+ # We can't deal with this one.
+ raise ValueError(f'Unknown/unhandled field type {gtype}')
+
+ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
+ offs = orig_offs
+
+ # Get Tensor Name
+ name_len, name_data = self._get_str(offs)
+ offs += int(name_len.nbytes + name_data.nbytes)
+
+ # Get Tensor Dimensions Count
+ n_dims = self._get(offs, np.uint32)
+ offs += int(n_dims.nbytes)
+
+ # Get Tensor Dimension Array
+ dims = self._get(offs, np.uint64, n_dims[0])
+ offs += int(dims.nbytes)
+
+ # Get Tensor Encoding Scheme Type
+ raw_dtype = self._get(offs, np.uint32)
+ offs += int(raw_dtype.nbytes)
+
+ # Get Tensor Offset
+ offset_tensor = self._get(offs, np.uint64)
+ offs += int(offset_tensor.nbytes)
+
+ return ReaderField(
+ orig_offs,
+ str(bytes(name_data), encoding = 'utf-8'),
+ [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
+ [1, 3, 4, 5],
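+            # data indexes into parts: 1 = name_data, 3 = dims, 4 = raw_dtype, 5 = offset_tensor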
+ )
+
+ def _build_fields(self, offs: int, count: int) -> int:
+ for _ in range(count):
+ orig_offs = offs
+ kv_klen, kv_kdata = self._get_str(offs)
+ offs += int(kv_klen.nbytes + kv_kdata.nbytes)
+ raw_kv_type = self._get(offs, np.uint32)
+ offs += int(raw_kv_type.nbytes)
+ parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
+ idxs_offs = len(parts)
+ field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
+ parts += field_parts
+ self._push_field(ReaderField(
+ orig_offs,
+ str(bytes(kv_kdata), encoding = 'utf-8'),
+ parts,
+ [idx + idxs_offs for idx in field_idxs],
+ field_types,
+ ), skip_sum = True)
+ offs += field_size
+ return offs
+
+ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
+ tensor_fields = []
+ for _ in range(count):
+ field = self._get_tensor_info_field(offs)
+ offs += sum(int(part.nbytes) for part in field.parts)
+ tensor_fields.append(field)
+ return offs, tensor_fields
+
+ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
+ tensors = []
+        tensor_names = set() # keep track of names to prevent duplicate tensors
+ for field in fields:
+ _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+            # check if a tensor with the same name is already in the list
+ tensor_name = str(bytes(name_data), encoding = 'utf-8')
+ if tensor_name in tensor_names:
+ raise ValueError(f'Found duplicated tensor with name {tensor_name}')
+ tensor_names.add(tensor_name)
+ ggml_type = GGMLQuantizationType(raw_dtype[0])
+ n_elems = int(np.prod(dims))
+ np_dims = tuple(reversed(dims.tolist()))
+ block_size, type_size = GGML_QUANT_SIZES[ggml_type]
+ n_bytes = n_elems * type_size // block_size
+ data_offs = int(start_offs + offset_tensor[0])
+ item_type: npt.DTypeLike
+ if ggml_type == GGMLQuantizationType.F16:
+ item_count = n_elems
+ item_type = np.float16
+ elif ggml_type == GGMLQuantizationType.F32:
+ item_count = n_elems
+ item_type = np.float32
+ elif ggml_type == GGMLQuantizationType.F64:
+ item_count = n_elems
+ item_type = np.float64
+ elif ggml_type == GGMLQuantizationType.I8:
+ item_count = n_elems
+ item_type = np.int8
+ elif ggml_type == GGMLQuantizationType.I16:
+ item_count = n_elems
+ item_type = np.int16
+ elif ggml_type == GGMLQuantizationType.I32:
+ item_count = n_elems
+ item_type = np.int32
+ elif ggml_type == GGMLQuantizationType.I64:
+ item_count = n_elems
+ item_type = np.int64
+ else:
+ item_count = n_bytes
+ item_type = np.uint8
+ np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
+ tensors.append(ReaderTensor(
+ name = tensor_name,
+ tensor_type = ggml_type,
+ shape = dims,
+ n_elements = n_elems,
+ n_bytes = n_bytes,
+ data_offset = data_offs,
+ data = self._get(data_offs, item_type, item_count).reshape(np_dims),
+ field = field,
+ ))
+ self.tensors = tensors
diff --git a/llama.cpp/gguf-py/gguf/gguf_writer.py b/llama.cpp/gguf-py/gguf/gguf_writer.py
new file mode 100644
index 0000000..a237537
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/gguf_writer.py
@@ -0,0 +1,1289 @@
+from __future__ import annotations
+
+import logging
+import os
+import shutil
+import struct
+import sys
+import tempfile
+from dataclasses import dataclass
+from enum import Enum, auto
+from math import prod
+from pathlib import Path
+from io import BufferedWriter
+from typing import IO, Any, Sequence, Mapping
+from string import ascii_letters, digits
+
+import numpy as np
+
+from .constants import (
+ GGUF_DEFAULT_ALIGNMENT,
+ GGUF_MAGIC,
+ GGUF_VERSION,
+ GGMLQuantizationType,
+ GGUFEndian,
+ GGUFValueType,
+ Keys,
+ RopeScalingType,
+ PoolingType,
+ TokenType,
+ ExpertGatingFuncType,
+)
+
+from .quants import quant_shape_from_byte_shape
+
+logger = logging.getLogger(__name__)
+
+
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
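+# e.g. SHARD_NAME_FORMAT.format("model", 1, 3) -> "model-00001-of-00003.gguf"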
+
+
+@dataclass
+class TensorInfo:
+ shape: Sequence[int]
+ dtype: GGMLQuantizationType
+ nbytes: int
+ tensor: np.ndarray[Any, Any] | None = None
+
+
+@dataclass
+class GGUFValue:
+ value: Any
+ type: GGUFValueType
+ sub_type: GGUFValueType | None = None
+
+
+class WriterState(Enum):
+ NO_FILE = auto()
+ EMPTY = auto()
+ HEADER = auto()
+ KV_DATA = auto()
+ TI_DATA = auto()
+ WEIGHTS = auto()
+
+
+class GGUFWriter:
+ fout: list[BufferedWriter] | None
+ path: Path | None
+ temp_file: tempfile.SpooledTemporaryFile[bytes] | None
+ tensors: list[dict[str, TensorInfo]]
+ kv_data: list[dict[str, GGUFValue]]
+ state: WriterState
+ _simple_value_packing = {
+ GGUFValueType.UINT8: "B",
+ GGUFValueType.INT8: "b",
+ GGUFValueType.UINT16: "H",
+ GGUFValueType.INT16: "h",
+ GGUFValueType.UINT32: "I",
+ GGUFValueType.INT32: "i",
+ GGUFValueType.FLOAT32: "f",
+ GGUFValueType.UINT64: "Q",
+ GGUFValueType.INT64: "q",
+ GGUFValueType.FLOAT64: "d",
+ GGUFValueType.BOOL: "?",
+ }
+
+ def __init__(
+ self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
+ split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
+ ):
+ self.fout = None
+ self.path = Path(path) if path else None
+ self.arch = arch
+ self.endianess = endianess
+ self.data_alignment = GGUF_DEFAULT_ALIGNMENT
+ self.use_temp_file = use_temp_file
+ self.temp_file = None
+ self.tensors = [{}]
+ self.kv_data = [{}]
+ self.split_max_tensors = split_max_tensors
+ self.split_max_size = split_max_size
+ self.dry_run = dry_run
+ self.small_first_shard = small_first_shard
+ logger.info("gguf: This GGUF file is for {0} Endian only".format(
+ "Big" if self.endianess == GGUFEndian.BIG else "Little",
+ ))
+ self.state = WriterState.NO_FILE
+
+ if self.small_first_shard:
+ self.tensors.append({})
+
+ self.add_architecture()
+
+ def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+ total_params = 0
+ shared_params = 0
+ expert_params = 0
+
+ expert_sum = 0
+ n_expert_tensors = 0
+
+ last_lora_a: tuple[str, TensorInfo] | None = None
+
+ for tensors in self.tensors:
+ for name, info in tensors.items():
+
+ shape = info.shape
+
+ if name.endswith(".lora_a"):
+ last_lora_a = (name, info)
+ continue
+ elif name.endswith(".lora_b"):
+ if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
+ # Bail when the LoRA pair can't be found trivially
+ logger.warning("can't measure LoRA size correctly, tensor order is unusual")
+ return 0, 0, 0, 0
+ else:
+ shape = (*shape[:-1], last_lora_a[1].shape[-1])
+
+ size = prod(shape)
+
+ if "_exps." in name:
+ expert_count = shape[-2 if ".bias" in name else -3]
+ expert_params += (size // expert_count)
+ expert_sum += expert_count
+ n_expert_tensors += 1
+ else:
+ shared_params += size
+
+ total_params += size
+
+ # Hopefully this should work even for variable-expert-count models
+ expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+ # Negate the total to signal it's likely not exact
+ if last_lora_a is not None:
+ total_params = -total_params
+
+ # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+ return total_params, shared_params, expert_params, expert_count
+
+ def format_shard_names(self, path: Path) -> list[Path]:
+ if len(self.tensors) == 1:
+ return [path]
+ return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
+
+ def open_output_file(self, path: Path | None = None) -> None:
+ if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
+ # allow calling this multiple times as long as the path is the same
+ return
+
+ if self.state is not WriterState.NO_FILE:
+ raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
+
+ if path is not None:
+ self.path = path
+
+ if self.path is not None:
+ filenames = self.print_plan()
+ self.fout = [open(filename, "wb") for filename in filenames]
+ self.state = WriterState.EMPTY
+
+ def print_plan(self) -> list[Path]:
+ logger.info("Writing the following files:")
+ assert self.path is not None
+ filenames = self.format_shard_names(self.path)
+ assert len(filenames) == len(self.tensors)
+ for name, tensors in zip(filenames, self.tensors):
+ logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
+
+ if self.dry_run:
+ logger.info("Dry run, not writing files")
+ for name in filenames:
+ print(name) # noqa: NP100
+ exit()
+
+ return filenames
+
+ def add_shard_kv_data(self) -> None:
+ if len(self.tensors) == 1:
+ return
+
+ total_tensors = sum(len(t) for t in self.tensors)
+ assert self.fout is not None
+ total_splits = len(self.fout)
+ self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
+ for i, kv_data in enumerate(self.kv_data):
+ kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
+ kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
+ kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
+
+ def write_header_to_file(self, path: Path | None = None) -> None:
+ if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
+ logger.warning("Model fails split requirements, not splitting")
+
+ self.open_output_file(path)
+
+ if self.state is not WriterState.EMPTY:
+ raise ValueError(f'Expected output file to be empty, got {self.state}')
+
+ assert self.fout is not None
+ assert len(self.fout) == len(self.tensors)
+ assert len(self.kv_data) == 1
+
+ self.add_shard_kv_data()
+
+ for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
+ fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
+ fout.write(self._pack("I", GGUF_VERSION))
+ fout.write(self._pack("Q", len(tensors)))
+ fout.write(self._pack("Q", len(kv_data)))
+ fout.flush()
+ self.state = WriterState.HEADER
+
+ def write_kv_data_to_file(self) -> None:
+ if self.state is not WriterState.HEADER:
+ raise ValueError(f'Expected output file to contain the header, got {self.state}')
+ assert self.fout is not None
+
+ for fout, kv_data in zip(self.fout, self.kv_data):
+ kv_bytes = bytearray()
+
+ for key, val in kv_data.items():
+ kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
+ kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
+
+ fout.write(kv_bytes)
+
+ self.flush()
+ self.state = WriterState.KV_DATA
+
+ def write_ti_data_to_file(self) -> None:
+ if self.state is not WriterState.KV_DATA:
+ raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+ assert self.fout is not None
+
+ for fout, tensors in zip(self.fout, self.tensors):
+ ti_data = bytearray()
+ offset_tensor = 0
+
+ for name, ti in tensors.items():
+ ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
+ n_dims = len(ti.shape)
+ ti_data += self._pack("I", n_dims)
+ for j in range(n_dims):
+ ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
+ ti_data += self._pack("I", ti.dtype)
+ ti_data += self._pack("Q", offset_tensor)
+ offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+
+ fout.write(ti_data)
+ fout.flush()
+ self.state = WriterState.TI_DATA
+
+ def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
+ if any(key in kv_data for kv_data in self.kv_data):
+ logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
+
+ self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
+
+ def add_uint8(self, key: str, val: int) -> None:
+        self.add_key_value(key, val, GGUFValueType.UINT8)
+
+ def add_int8(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.INT8)
+
+ def add_uint16(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.UINT16)
+
+ def add_int16(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.INT16)
+
+ def add_uint32(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.UINT32)
+
+ def add_int32(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.INT32)
+
+ def add_float32(self, key: str, val: float) -> None:
+ self.add_key_value(key, val, GGUFValueType.FLOAT32)
+
+ def add_uint64(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.UINT64)
+
+ def add_int64(self, key: str, val: int) -> None:
+ self.add_key_value(key, val, GGUFValueType.INT64)
+
+ def add_float64(self, key: str, val: float) -> None:
+ self.add_key_value(key, val, GGUFValueType.FLOAT64)
+
+ def add_bool(self, key: str, val: bool) -> None:
+ self.add_key_value(key, val, GGUFValueType.BOOL)
+
+ def add_string(self, key: str, val: str) -> None:
+ if not val:
+ return
+ self.add_key_value(key, val, GGUFValueType.STRING)
+
+ def add_array(self, key: str, val: Sequence[Any]) -> None:
+ if len(val) == 0:
+ return
+ self.add_key_value(key, val, GGUFValueType.ARRAY)
+
+ @staticmethod
+ def ggml_pad(x: int, n: int) -> int:
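+        # Round x up to the next multiple of n, e.g. ggml_pad(33, 32) == 64.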
+ return ((x + n - 1) // n) * n
+
+ def add_tensor_info(
+ self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
+ tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
+ ) -> None:
+ if self.state is not WriterState.NO_FILE:
+ raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
+
+ if any(name in tensors for tensors in self.tensors):
+ raise ValueError(f'Duplicated tensor name {name!r}')
+
+ if raw_dtype is None:
+ if tensor_dtype == np.float16:
+ dtype = GGMLQuantizationType.F16
+ elif tensor_dtype == np.float32:
+ dtype = GGMLQuantizationType.F32
+ elif tensor_dtype == np.float64:
+ dtype = GGMLQuantizationType.F64
+ elif tensor_dtype == np.int8:
+ dtype = GGMLQuantizationType.I8
+ elif tensor_dtype == np.int16:
+ dtype = GGMLQuantizationType.I16
+ elif tensor_dtype == np.int32:
+ dtype = GGMLQuantizationType.I32
+ elif tensor_dtype == np.int64:
+ dtype = GGMLQuantizationType.I64
+ else:
+ raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
+ else:
+ dtype = raw_dtype
+ if tensor_dtype == np.uint8:
+ tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
+
+ # make sure there is at least one tensor before splitting
+ if len(self.tensors[-1]) > 0:
+ if ( # split when over tensor limit
+ self.split_max_tensors != 0
+ and len(self.tensors[-1]) >= self.split_max_tensors
+ ) or ( # split when over size limit
+ self.split_max_size != 0
+ and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
+ ):
+ self.tensors.append({})
+
+ self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
+
+ def add_tensor(
+ self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+ raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
+ ) -> None:
+        # if tensor endianness is not passed, assume it's native to the system
+ if tensor_endianess is None:
+ tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
+
+ if tensor_endianess != self.endianess:
+ # Don't byteswap inplace since lazy copies cannot handle it
+ tensor = tensor.byteswap(inplace=False)
+ if self.use_temp_file and self.temp_file is None:
+ fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
+ fp.seek(0)
+ self.temp_file = fp
+
+ shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
+ self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
+
+ if self.temp_file is None:
+ self.tensors[-1][name].tensor = tensor
+ return
+
+ tensor.tofile(self.temp_file)
+ self.write_padding(self.temp_file, tensor.nbytes)
+
+ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
+ pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
+ if pad != 0:
+ fp.write(bytes([0] * pad))
+
+ def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
+ if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
+ raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
+ assert self.fout is not None
+
+        # if tensor endianness is not passed, assume it's native to the system
+ if tensor_endianess is None:
+ tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
+
+ if tensor_endianess != self.endianess:
+ # Don't byteswap inplace since lazy copies cannot handle it
+ tensor = tensor.byteswap(inplace=False)
+
+ file_id = -1
+ for i, tensors in enumerate(self.tensors):
+ if len(tensors) > 0:
+ file_id = i
+ break
+
+ fout = self.fout[file_id]
+
+        # pop the first tensor info
+        first_tensor_name = next(iter(self.tensors[file_id]))
+ ti = self.tensors[file_id].pop(first_tensor_name)
+ assert ti.nbytes == tensor.nbytes
+
+ self.write_padding(fout, fout.tell())
+ tensor.tofile(fout)
+ self.write_padding(fout, tensor.nbytes)
+
+ self.state = WriterState.WEIGHTS
+
+ def write_tensors_to_file(self, *, progress: bool = False) -> None:
+ self.write_ti_data_to_file()
+
+ assert self.fout is not None
+
+ for fout in self.fout:
+ self.write_padding(fout, fout.tell())
+
+ if self.temp_file is None:
+ shard_bar = None
+ bar = None
+
+ if progress:
+ from tqdm import tqdm
+
+ total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
+
+ if len(self.fout) > 1:
+ shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
+ bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+ for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+ if shard_bar is not None:
+ shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
+ total = sum(ti.nbytes for ti in tensors.values())
+ shard_bar.reset(total=(total if total > 0 else None))
+
+ # relying on the fact that Python dicts preserve insertion order (since 3.7)
+ for ti in tensors.values():
+ assert ti.tensor is not None # can only iterate once over the tensors
+ assert ti.tensor.nbytes == ti.nbytes
+ ti.tensor.tofile(fout)
+ if shard_bar is not None:
+ shard_bar.update(ti.nbytes)
+ if bar is not None:
+ bar.update(ti.nbytes)
+ self.write_padding(fout, ti.nbytes)
+ ti.tensor = None
+ else:
+ self.temp_file.seek(0)
+
+ shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
+ self.flush()
+ self.temp_file.close()
+
+ self.state = WriterState.WEIGHTS
+
+ def flush(self) -> None:
+ assert self.fout is not None
+ for fout in self.fout:
+ fout.flush()
+
+ def close(self) -> None:
+ if self.fout is not None:
+ for fout in self.fout:
+ fout.close()
+ self.fout = None
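+
+    # A minimal end-to-end sketch (the file name, metadata, and tensor
+    # contents here are illustrative, not prescribed by this module):
+    #
+    #   writer = GGUFWriter("out.gguf", arch="llama")
+    #   writer.add_block_count(32)
+    #   writer.add_tensor("token_embd.weight", np.zeros((4096, 32000), dtype=np.float32))
+    #   writer.write_header_to_file()
+    #   writer.write_kv_data_to_file()
+    #   writer.write_tensors_to_file()
+    #   writer.close()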
+
+ def add_type(self, type_name: str) -> None:
+ self.add_string(Keys.General.TYPE, type_name)
+
+ def add_architecture(self) -> None:
+ self.add_string(Keys.General.ARCHITECTURE, self.arch)
+
+ def add_quantization_version(self, quantization_version: int) -> None:
+ self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+ def add_custom_alignment(self, alignment: int) -> None:
+ self.data_alignment = alignment
+ self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+ def add_file_type(self, ftype: int) -> None:
+ self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+ def add_sampling_sequence(self, sequence: str) -> None:
+ self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
+
+ def add_sampling_top_k(self, top_k: int) -> None:
+ self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
+
+ def add_sampling_top_p(self, top_p: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
+
+ def add_sampling_min_p(self, min_p: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
+
+ def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
+
+ def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
+
+ def add_sampling_temp(self, temp: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_TEMP, temp)
+
+ def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
+ self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
+
+ def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
+
+ def add_sampling_mirostat(self, mirostat: int) -> None:
+ self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
+
+ def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
+
+ def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
+ self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
+
+ def add_name(self, name: str) -> None:
+ self.add_string(Keys.General.NAME, name)
+
+ def add_author(self, author: str) -> None:
+ self.add_string(Keys.General.AUTHOR, author)
+
+ def add_version(self, version: str) -> None:
+ self.add_string(Keys.General.VERSION, version)
+
+ def add_organization(self, organization: str) -> None:
+ self.add_string(Keys.General.ORGANIZATION, organization)
+
+ def add_finetune(self, finetune: str) -> None:
+ self.add_string(Keys.General.FINETUNE, finetune)
+
+ def add_basename(self, basename: str) -> None:
+ self.add_string(Keys.General.BASENAME, basename)
+
+ def add_description(self, description: str) -> None:
+ self.add_string(Keys.General.DESCRIPTION, description)
+
+ def add_quantized_by(self, quantized: str) -> None:
+ self.add_string(Keys.General.QUANTIZED_BY, quantized)
+
+ def add_size_label(self, size_label: str) -> None:
+ self.add_string(Keys.General.SIZE_LABEL, size_label)
+
+ def add_license(self, license: str) -> None:
+ self.add_string(Keys.General.LICENSE, license)
+
+ def add_license_name(self, license: str) -> None:
+ self.add_string(Keys.General.LICENSE_NAME, license)
+
+ def add_license_link(self, license: str) -> None:
+ self.add_string(Keys.General.LICENSE_LINK, license)
+
+ def add_url(self, url: str) -> None:
+ self.add_string(Keys.General.URL, url)
+
+ def add_doi(self, doi: str) -> None:
+ self.add_string(Keys.General.DOI, doi)
+
+ def add_uuid(self, uuid: str) -> None:
+ self.add_string(Keys.General.UUID, uuid)
+
+ def add_repo_url(self, repo_url: str) -> None:
+ self.add_string(Keys.General.REPO_URL, repo_url)
+
+ def add_source_url(self, url: str) -> None:
+ self.add_string(Keys.General.SOURCE_URL, url)
+
+ def add_source_doi(self, doi: str) -> None:
+ self.add_string(Keys.General.SOURCE_DOI, doi)
+
+ def add_source_uuid(self, uuid: str) -> None:
+ self.add_string(Keys.General.SOURCE_UUID, uuid)
+
+ def add_source_repo_url(self, repo_url: str) -> None:
+ self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
+
+ def add_base_model_count(self, source_count: int) -> None:
+ self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
+
+ def add_base_model_name(self, source_id: int, name: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
+
+ def add_base_model_author(self, source_id: int, author: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
+
+ def add_base_model_version(self, source_id: int, version: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
+
+ def add_base_model_organization(self, source_id: int, organization: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
+
+ def add_base_model_description(self, source_id: int, description: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)
+
+ def add_base_model_url(self, source_id: int, url: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
+
+ def add_base_model_doi(self, source_id: int, doi: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
+
+ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
+
+ def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
+ self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
+
+ def add_dataset_count(self, source_count: int) -> None:
+ self.add_uint32(Keys.General.DATASET_COUNT, source_count)
+
+ def add_dataset_name(self, source_id: int, name: str) -> None:
+ self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
+
+ def add_dataset_author(self, source_id: int, author: str) -> None:
+ self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
+
+ def add_dataset_version(self, source_id: int, version: str) -> None:
+ self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
+
+ def add_dataset_organization(self, source_id: int, organization: str) -> None:
+ self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)
+
+ def add_dataset_description(self, source_id: int, description: str) -> None:
+ self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)
+
+ def add_dataset_url(self, source_id: int, url: str) -> None:
+ self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
+
+ def add_dataset_doi(self, source_id: int, doi: str) -> None:
+ self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
+
+ def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
+ self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
+
+ def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
+ self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
+
+ def add_tags(self, tags: Sequence[str]) -> None:
+ self.add_array(Keys.General.TAGS, tags)
+
+ def add_languages(self, languages: Sequence[str]) -> None:
+ self.add_array(Keys.General.LANGUAGES, languages)
+
+ def add_tensor_data_layout(self, layout: str) -> None:
+ self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+
+ def add_vocab_size(self, size: int) -> None:
+ self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
+
+ def add_context_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
+
+ def add_embedding_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+ def add_embedding_length_out(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)
+
+ def add_features_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)
+
+ def add_posnet_embedding_length(self, length: int) -> None:
+ self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+ def add_posnet_block_count(self, length: int) -> None:
+ self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)
+
+ def add_convnext_embedding_length(self, length: int) -> None:
+ self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+ def add_convnext_block_count(self, length: int) -> None:
+ self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)
+
+ def add_shortconv_l_cache(self, length: int) -> None:
+ self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length)
+
+ def add_block_count(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
+
+ def add_leading_dense_block_count(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
+
+ def add_full_attention_interval(self, interval: int) -> None:
+ self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
+
+ def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
+ if isinstance(length, int):
+ self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+ else:
+ self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+ def add_expert_feed_forward_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+ def add_expert_shared_feed_forward_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+ def add_expert_chunk_feed_forward_length(self, length: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+ def add_parallel_residual(self, use: bool) -> None:
+ self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+
+ def add_decoder_start_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
+
+ def add_decoder_block_count(self, value: int) -> None:
+ self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
+
+ def add_embedding_length_per_layer_input(self, value: int) -> None:
+ self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
+
+ def add_altup_active_idx(self, val: int) -> None:
+ self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
+
+ def add_altup_num_inputs(self, val: int) -> None:
+ self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
+
+ def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
+ self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
+
+ def add_head_count(self, count: int | Sequence[int]) -> None:
+ if isinstance(count, int):
+ self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+ else:
+ self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+
+ def add_head_count_kv(self, count: int | Sequence[int]) -> None:
+ if isinstance(count, int):
+ self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+ else:
+ self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+
+ def add_key_length(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
+
+ def add_value_length(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
+
+ def add_key_length_mla(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
+
+ def add_value_length_mla(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+
+ def add_max_alibi_bias(self, bias: float) -> None:
+ self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
+
+ def add_clamp_kqv(self, value: float) -> None:
+ self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
+
+ def add_shared_kv_layers(self, value: int) -> None:
+ self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+
+ def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
+ key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
+ if isinstance(value, int):
+ self.add_uint32(key, value)
+ else:
+ self.add_array(key, value)
+
+    def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
+ self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+ self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
+ def add_logit_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
+
+ def add_attn_logit_softcapping(self, value: float) -> None:
+ self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+ def add_router_logit_softcapping(self, value: float) -> None:
+ self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+ def add_final_logit_softcapping(self, value: float) -> None:
+ self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+ def add_expert_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
+
+ def add_expert_used_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
+
+ def add_expert_shared_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
+
+ def add_expert_group_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
+
+ def add_expert_group_used_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
+
+ def add_expert_weights_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
+
+ def add_expert_weights_norm(self, value: bool) -> None:
+ self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+ def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
+ self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
+
+ def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None:
+ self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values)
+
+ def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None:
+ self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values)
+
+ def add_expert_group_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
+
+ def add_experts_per_group(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
+
+ def add_moe_every_n_layers(self, value: int) -> None:
+ self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
+
+ def add_nextn_predict_layers(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
+
+ def add_swin_norm(self, value: bool) -> None:
+ self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
+
+ def add_rescale_every_n_layers(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
+
+ def add_time_mix_extra_dim(self, dim: int) -> None:
+ self.add_uint32(Keys.LLM.TIME_MIX_EXTRA_DIM.format(arch=self.arch), dim)
+
+ def add_time_decay_extra_dim(self, dim: int) -> None:
+ self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim)
+
+ def add_residual_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value)
+
+ def add_embedding_scale(self, value: float) -> None:
+ self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value)
+
+ def add_wkv_head_size(self, size: int) -> None:
+ self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)
+
+ def add_token_shift_count(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)
+
+ def add_interleave_moe_layer_step(self, value: int) -> None:
+ self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value)
+
+ def add_layer_norm_eps(self, value: float) -> None:
+ self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
+
+ def add_layer_norm_rms_eps(self, value: float) -> None:
+ self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
+
+ def add_group_norm_eps(self, value: float) -> None:
+ self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
+
+ def add_group_norm_groups(self, value: int) -> None:
+ self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
+
+ def add_causal_attention(self, value: bool) -> None:
+ self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
+
+ def add_q_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
+
+ def add_kv_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
+
+ def add_decay_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.DECAY_LORA_RANK.format(arch=self.arch), length)
+
+ def add_iclr_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length)
+
+ def add_value_residual_mix_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
+
+ def add_rope_freq_base_swa(self, value: float) -> None:
+ self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
+
+ def add_gate_lora_rank(self, length: int) -> None:
+ self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
+
+ def add_relative_attn_buckets_count(self, value: int) -> None:
+ self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
+
+ def add_sliding_window(self, value: int) -> None:
+ self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
+ def add_attention_scale(self, value: float) -> None:
+ self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
+
+ def add_attn_output_scale(self, value: float) -> None:
+ self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+ def add_attn_temperature_length(self, value: int) -> None:
+ self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
+ def add_attn_temperature_scale(self, value: float) -> None:
+ self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
+
+ def add_pooling_type(self, value: PoolingType) -> None:
+ self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
+
+ def add_num_deepstack_layers(self, count: int) -> None:
+ self.add_uint32(Keys.LLM.NUM_DEEPSTACK_LAYERS.format(arch=self.arch), count)
+
+ def add_rope_dimension_count(self, count: int) -> None:
+ self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
+
+ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
+ self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
+
+ def add_rope_freq_base(self, value: float) -> None:
+ self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
+
+ def add_rope_scaling_type(self, value: RopeScalingType) -> None:
+ self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
+
+ def add_rope_scaling_factor(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
+
+ def add_rope_scaling_attn_factors(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
+
+ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
+ self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+ def add_rope_scaling_finetuned(self, value: bool) -> None:
+ self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
+
+ def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
+
+ def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+ def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+ def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+ def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+ self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
+ def add_ssm_conv_kernel(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
+
+ def add_ssm_inner_size(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.INNER_SIZE.format(arch=self.arch), value)
+
+ def add_ssm_state_size(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.STATE_SIZE.format(arch=self.arch), value)
+
+ def add_ssm_time_step_rank(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
+
+ def add_ssm_group_count(self, value: int) -> None:
+ self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
+
+ def add_ssm_dt_b_c_rms(self, value: bool) -> None:
+ self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
+
+ def add_kda_head_dim(self, value: int) -> None:
+ self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)
+
+ def add_tokenizer_model(self, model: str) -> None:
+ self.add_string(Keys.Tokenizer.MODEL, model)
+
+ def add_tokenizer_pre(self, pre: str) -> None:
+ self.add_string(Keys.Tokenizer.PRE, pre)
+
+ def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+ self.add_array(Keys.Tokenizer.LIST, tokens)
+
+ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+ self.add_array(Keys.Tokenizer.MERGES, merges)
+
+ def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
+ self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
+
+ def add_token_type_count(self, value: int) -> None:
+ self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
+
+ def add_token_scores(self, scores: Sequence[float]) -> None:
+ self.add_array(Keys.Tokenizer.SCORES, scores)
+
+ def add_bos_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.BOS_ID, id)
+
+ def add_eos_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+
+ def add_unk_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.UNK_ID, id)
+
+ def add_sep_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.SEP_ID, id)
+
+ def add_pad_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.PAD_ID, id)
+
+ def add_mask_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.MASK_ID, id)
+
+ def add_add_bos_token(self, value: bool) -> None:
+ self.add_bool(Keys.Tokenizer.ADD_BOS, value)
+
+ def add_add_eos_token(self, value: bool) -> None:
+ self.add_bool(Keys.Tokenizer.ADD_EOS, value)
+
+ def add_add_sep_token(self, value: bool) -> None:
+ self.add_bool(Keys.Tokenizer.ADD_SEP, value)
+
+ def add_add_space_prefix(self, value: bool) -> None:
+ self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
+
+ def add_remove_extra_whitespaces(self, value: bool) -> None:
+ self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
+
+ def add_precompiled_charsmap(self, charsmap: bytes) -> None:
+ self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
+
+ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
+ if not isinstance(value, str):
+ template_default = None
+ template_names = set()
+
+ for choice in value:
+ name = choice.get('name', '')
+ template = choice.get('template')
+
+            # Allowing non-alphanumerical characters in the template name is probably not a good idea, so filter them out
+ name = ''.join((c if c in ascii_letters + digits else '_' for c in name))
+
+ if name and template is not None:
+ if name == 'default':
+ template_default = template
+ else:
+ template_names.add(name)
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)
+
+ if template_names:
+ self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))
+
+ if template_default is None:
+ return
+
+ value = template_default
+
+ self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
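+ # Illustrative sketch (not additional API): given a writer `w`, passing a
+ # list of named templates stores the "default" entry under
+ # tokenizer.chat_template and the rest under tokenizer.chat_template.<name>;
+ # the "rag" name below is hypothetical:
+ # w.add_chat_template([
+ # {"name": "default", "template": "..."},
+ # {"name": "rag", "template": "..."},
+ # ]) # also records tokenizer.chat_templates = ["rag"]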
+
+ def add_eot_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.EOT_ID, id)
+
+ def add_eom_token_id(self, id: int) -> None:
+ self.add_uint32(Keys.Tokenizer.EOM_ID, id)
+
+ def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
+ self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
+
+ # for vision models
+
+ def add_clip_has_vision_encoder(self, value: bool) -> None:
+ self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
+
+ def add_clip_has_audio_encoder(self, value: bool) -> None:
+ self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
+
+ def add_clip_projector_type(self, value: str) -> None:
+ self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
+
+ def add_clip_vision_projector_type(self, value: str) -> None:
+ self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
+
+ def add_vision_projection_dim(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
+
+ def add_vision_patch_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
+
+ def add_vision_embedding_length(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
+
+ def add_vision_feed_forward_length(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
+
+ def add_vision_block_count(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
+
+ def add_vision_head_count(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
+
+ def add_vision_attention_layernorm_eps(self, value: float) -> None:
+ self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
+
+ def add_vision_image_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
+
+ def add_vision_max_pixels(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.IMAGE_MAX_PIXELS, value)
+
+ def add_vision_min_pixels(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
+
+ def add_vision_preproc_image_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
+
+ def add_vision_image_mean(self, values: Sequence[float]) -> None:
+ self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
+
+ def add_vision_image_std(self, values: Sequence[float]) -> None:
+ self.add_array(Keys.ClipVision.IMAGE_STD, values)
+
+ def add_vision_spatial_merge_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
+
+ def add_vision_use_gelu(self, value: bool) -> None:
+ self.add_bool(Keys.ClipVision.USE_GELU, value)
+
+ def add_vision_use_silu(self, value: bool) -> None:
+ self.add_bool(Keys.ClipVision.USE_SILU, value)
+
+ def add_vision_projector_scale_factor(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
+
+ def add_vision_n_wa_pattern(self, value: int) -> None:
+ """Add window attention pattern interval for vision models.
+
+ This defines the pattern interval for window attention vs full attention layers.
+ For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention,
+ while other layers use window attention.
+
+ Used by models like Qwen2.5-VL where full attention layers follow a regular pattern.
+ """
+ self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
+ def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None:
+ """Add explicit layer indexes that use full attention in vision models.
+
+ This specifies the exact layer indices (0-based) that should use full attention
+ instead of window attention. All other layers will use window attention.
+
+ Args:
+ layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15])
+
+ Used by models like YoutuVL where full attention layers are explicitly specified
+ rather than following a regular pattern.
+
+ Difference from add_vision_n_wa_pattern:
+ - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention)
+ - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern)
+ """
+ self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers)
+
+ def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
+ self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
+
+ def add_vision_window_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
+
+ # audio models
+
+ def add_clip_audio_projector_type(self, value: str) -> None:
+ self.add_string(Keys.ClipAudio.PROJECTOR_TYPE, value)
+
+ def add_audio_projection_dim(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
+
+ def add_audio_embedding_length(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
+
+ def add_audio_feed_forward_length(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
+
+ def add_audio_block_count(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
+
+ def add_audio_head_count(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
+
+ def add_audio_attention_layernorm_eps(self, value: float) -> None:
+ self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
+
+ def add_audio_num_mel_bins(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
+
+ def add_audio_stack_factor(self, value: int) -> None:
+ self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
+
+ def add_xielu_alpha_p(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.ALPHA_P, values)
+
+ def add_xielu_alpha_n(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.ALPHA_N, values)
+
+ def add_xielu_beta(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.BETA, values)
+
+ def add_xielu_eps(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.EPS, values)
+
+ # diffusion models
+
+ def add_diffusion_shift_logits(self, value: bool) -> None:
+ self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value)
+
+ def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
+ pack_prefix = ''
+ if not skip_pack_prefix:
+ pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
+ return struct.pack(f'{pack_prefix}{fmt}', value)
+
+ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
+ kv_data = bytearray()
+
+ if add_vtype:
+ kv_data += self._pack("I", vtype)
+
+ pack_fmt = self._simple_value_packing.get(vtype)
+ if pack_fmt is not None:
+ kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
+ elif vtype == GGUFValueType.STRING:
+ encoded_val = val.encode("utf-8") if isinstance(val, str) else val
+ kv_data += self._pack("Q", len(encoded_val))
+ kv_data += encoded_val
+ elif vtype == GGUFValueType.ARRAY:
+
+ if not isinstance(val, Sequence):
+ raise ValueError("Invalid GGUF metadata array, expecting sequence")
+
+ if len(val) == 0:
+ raise ValueError("Invalid GGUF metadata array. Empty array")
+
+ if sub_type is not None:
+ ltype = sub_type
+ elif isinstance(val, bytes):
+ ltype = GGUFValueType.UINT8
+ else:
+ ltype = GGUFValueType.get_type(val[0])
+ if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
+ raise ValueError("All items in a GGUF array should be of the same type")
+ kv_data += self._pack("I", ltype)
+ kv_data += self._pack("Q", len(val))
+ for item in val:
+ kv_data += self._pack_val(item, ltype, add_vtype=False)
+ else:
+ raise ValueError("Invalid GGUF metadata value type or value")
+
+ return kv_data
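+ # Serialized layout (per the packing above): with add_vtype, a string packs as
+ # uint32 value-type | uint64 byte-length | raw UTF-8 bytes,
+ # and an array packs as
+ # uint32 value-type | uint32 item-type | uint64 item-count | packed items
+ # (array items carry no per-item type prefix)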
+
+ @staticmethod
+ def format_n_bytes_to_str(num: int) -> str:
+ if num == 0:
+ return "negligible - metadata only"
+ fnum = float(num)
+ for unit in ("", "K", "M", "G"):
+ if abs(fnum) < 1000.0:
+ return f"{fnum:3.1f}{unit}"
+ fnum /= 1000.0
+ return f"{fnum:.1f}T - over 1TB, split recommended"
diff --git a/llama.cpp/gguf-py/gguf/lazy.py b/llama.cpp/gguf-py/gguf/lazy.py
new file mode 100644
index 0000000..c126f09
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/lazy.py
@@ -0,0 +1,228 @@
+from __future__ import annotations
+from abc import ABC, ABCMeta, abstractmethod
+
+import logging
+from typing import Any, Callable
+
+import numpy as np
+from numpy.typing import DTypeLike
+
+
+logger = logging.getLogger(__name__)
+
+
+class LazyMeta(ABCMeta):
+
+ def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
+ def __getattr__(self, name: str) -> Any:
+ meta_attr = getattr(self._meta, name)
+ if callable(meta_attr):
+ return type(self)._wrap_fn(
+ (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
+ use_self=self,
+ )
+ elif isinstance(meta_attr, self._tensor_type):
+ # e.g. self.T with torch.Tensor should still be wrapped
+ return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
+ else:
+ # no need to wrap non-tensor properties,
+ # and they likely don't depend on the actual contents of the tensor
+ return meta_attr
+
+ namespace["__getattr__"] = __getattr__
+
+ # need to make a builder for the wrapped wrapper to copy the name,
+ # or else it fails with very cryptic error messages,
+ # because somehow the same string would end up in every closure
+ def mk_wrap(op_name: str, *, meta_noop: bool = False):
+ # need to wrap the wrapper to get self
+ def wrapped_special_op(self, *args, **kwargs):
+ return type(self)._wrap_fn(
+ getattr(type(self)._tensor_type, op_name),
+ meta_noop=meta_noop,
+ )(self, *args, **kwargs)
+ return wrapped_special_op
+
+ # special methods bypass __getattr__, so they need to be added manually
+ # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+ # NOTE: doing this from a metaclass is very convenient
+ # TODO: make this even more comprehensive
+ for binary_op in (
+ "lt", "le", "eq", "ne", "ge", "gt",
+ "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
+ "or", "pow", "rshift", "sub", "truediv", "xor",
+ "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
+ "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
+ ):
+ attr_name = f"__{binary_op}__"
+ # evaluation on the meta tensor is needed in case there's broadcasting
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+ for unary_op in ("not", "abs", "invert", "neg", "pos"):
+ attr_name = f"__{unary_op}__"
+ # the result of these operators usually has the same shape and dtype as the input,
+ # so evaluation on the meta tensor can be skipped.
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
+
+ for special_op in (
+ "getitem", "setitem", "len",
+ ):
+ attr_name = f"__{special_op}__"
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+ return super().__new__(cls, name, bases, namespace, **kwargs)
+
+
+# Tree of lazy tensors
+class LazyBase(ABC, metaclass=LazyMeta):
+ _tensor_type: type
+ _meta: Any
+ _data: Any | None
+ _args: tuple
+ _kwargs: dict[str, Any]
+ _func: Callable[[Any], Any] | None
+
+ def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
+ super().__init__()
+ self._meta = meta
+ self._data = data
+ self._args = args
+ self._kwargs = kwargs if kwargs is not None else {}
+ self._func = func
+ assert self._func is not None or self._data is not None
+
+ def __init_subclass__(cls) -> None:
+ if "_tensor_type" not in cls.__dict__:
+ raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
+ return super().__init_subclass__()
+
+ @staticmethod
+ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+ # TODO: dict and set
+ if isinstance(o, (list, tuple)):
+ L = []
+ for item in o:
+ L.append(LazyBase._recurse_apply(item, fn))
+ if isinstance(o, tuple):
+ L = tuple(L)
+ return L
+ elif isinstance(o, LazyBase):
+ return fn(o)
+ else:
+ return o
+
+ @classmethod
+ def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
+ def wrapped_fn(*args, **kwargs):
+ if kwargs is None:
+ kwargs = {}
+ args = ((use_self,) if use_self is not None else ()) + args
+
+ meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+ # TODO: maybe handle tensors in kwargs too
+
+ if isinstance(meta_noop, bool) and not meta_noop:
+ try:
+ res = fn(*meta_args, **kwargs)
+ except NotImplementedError:
+ # running some operations on PyTorch's Meta tensors can cause this exception
+ res = None
+ else:
+ # some operators don't need to actually run on the meta tensors
+ assert len(args) > 0
+ res = args[0]
+ assert isinstance(res, cls)
+ res = res._meta
+ # allow operations to override the dtype and shape
+ if meta_noop is not True:
+ if isinstance(meta_noop, tuple):
+ dtype, shape = meta_noop
+ assert callable(shape)
+ res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
+ else:
+ res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
+
+ if isinstance(res, cls._tensor_type):
+ return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+ elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
+ # share the evaluation between lazy tuple elements
+ shared_args: list = [args, None]
+
+ def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
+ assert len(a) == 2
+ if a[1] is None:
+ a[1] = fn(*a[0], **kw)
+ return a[1][i]
+ return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res)))
+ else:
+ del res # not needed
+ # non-tensor return likely relies on the contents of the args
+ # (e.g. the result of torch.equal)
+ eager_args = cls.to_eager(args)
+ return fn(*eager_args, **kwargs)
+ return wrapped_fn
+
+ @classmethod
+ def to_eager(cls, t: Any) -> Any:
+ def simple_to_eager(_t: LazyBase) -> Any:
+ if _t._data is not None:
+ return _t._data
+
+ # NOTE: there's a recursion limit in Python (usually 1000)
+
+ assert _t._func is not None
+ _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+ _t._data = _t._func(*_t._args, **_t._kwargs)
+ # sanity check
+ assert _t._data is not None
+ assert _t._data.dtype == _t._meta.dtype
+ assert _t._data.shape == _t._meta.shape
+
+ return _t._data
+
+ # recurse into lists and/or tuples, keeping their structure
+ return cls._recurse_apply(t, simple_to_eager)
+
+ @classmethod
+ def eager_to_meta(cls, t: Any) -> Any:
+ return cls.meta_with_dtype_and_shape(t.dtype, t.shape)
+
+ # must be overridden, meta tensor init is backend-specific
+ @classmethod
+ @abstractmethod
+ def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass
+
+ @classmethod
+ def from_eager(cls, t: Any) -> Any:
+ if type(t) is cls:
+ # already lazy
+ return t
+ elif isinstance(t, cls._tensor_type):
+ return cls(meta=cls.eager_to_meta(t), data=t)
+ else:
+ return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
+
+
+class LazyNumpyTensor(LazyBase):
+ _tensor_type = np.ndarray
+
+ shape: tuple[int, ...] # Makes the type checker happy in quants.py
+
+ @classmethod
+ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
+ # The initial idea was to use np.nan as the fill value,
+ # but non-float types like np.int16 can't use that.
+ # So zero it is.
+ cheat = np.zeros(1, dtype)
+ return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))
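+ # e.g. a (4096, 4096) float32 meta tensor reports that shape and dtype while
+ # a single zero-strided element backs it, so no large allocation happens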
+
+ def astype(self, dtype, *args, **kwargs):
+ meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
+ full_args = (self, dtype,) + args
+ return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
+
+ def tofile(self, *args, **kwargs):
+ eager = LazyNumpyTensor.to_eager(self)
+ return eager.tofile(*args, **kwargs)
+
+ # TODO: __array_function__
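+
+# Illustrative flow (a sketch, not part of this module): operations on a
+# LazyNumpyTensor build a deferred graph; nothing is computed until the data
+# is actually needed:
+#
+# lazy = LazyNumpyTensor.from_eager(np.ones((8, 8), dtype=np.float32))
+# lazy = (lazy * 2).astype(np.float16) # still lazy, only meta is tracked
+# lazy.tofile("out.bin") # evaluation happens here (hypothetical filename)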
diff --git a/llama.cpp/gguf-py/gguf/metadata.py b/llama.cpp/gguf-py/gguf/metadata.py
new file mode 100644
index 0000000..e0d478c
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/metadata.py
@@ -0,0 +1,731 @@
+from __future__ import annotations
+
+import re
+import json
+import yaml
+import logging
+from pathlib import Path
+from typing import Any, Literal, Optional
+from dataclasses import dataclass
+
+from .constants import Keys
+
+import gguf
+
+logger = logging.getLogger("metadata")
+
+
+@dataclass
+class Metadata:
+ # Recommended Sampler Parameters to be written to GGUF KV Store
+ sampling_sequence: Optional[str] = None
+ sampling_top_k: Optional[int] = None
+ sampling_top_p: Optional[float] = None
+ sampling_min_p: Optional[float] = None
+ sampling_xtc_probability: Optional[float] = None
+ sampling_xtc_threshold: Optional[float] = None
+ sampling_temp: Optional[float] = None
+ sampling_penalty_last_n: Optional[int] = None
+ sampling_penalty_repeat: Optional[float] = None
+ sampling_mirostat: Optional[int] = None
+ sampling_mirostat_tau: Optional[float] = None
+ sampling_mirostat_eta: Optional[float] = None
+
+ # Authorship Metadata to be written to GGUF KV Store
+ name: Optional[str] = None
+ author: Optional[str] = None
+ version: Optional[str] = None
+ organization: Optional[str] = None
+ finetune: Optional[str] = None
+ basename: Optional[str] = None
+ description: Optional[str] = None
+ quantized_by: Optional[str] = None
+ size_label: Optional[str] = None
+ url: Optional[str] = None
+ doi: Optional[str] = None
+ uuid: Optional[str] = None
+ repo_url: Optional[str] = None
+ source_url: Optional[str] = None
+ source_doi: Optional[str] = None
+ source_uuid: Optional[str] = None
+ source_repo_url: Optional[str] = None
+ license: Optional[str] = None
+ license_name: Optional[str] = None
+ license_link: Optional[str] = None
+ base_models: Optional[list[dict]] = None
+ tags: Optional[list[str]] = None
+ languages: Optional[list[str]] = None
+ datasets: Optional[list[dict]] = None
+
+ @staticmethod
+ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
+ # This grabs as much contextual authorship metadata as possible from the model repository,
+ # making any conversions required to match the gguf kv store metadata format,
+ # as well as giving users the ability to override any authorship metadata that may be incorrect
+
+ # Create a new Metadata instance
+ metadata = Metadata()
+
+ model_card = Metadata.load_model_card(model_path)
+ hf_params = Metadata.load_hf_parameters(model_path)
+ gen_config = Metadata.load_generation_config(model_path)
+ # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
+
+ # heuristics
+ metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
+
+ if gen_config:
+ metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
+ metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
+ metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
+ metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
+ metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
+ metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
+ metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
+ metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
+ metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
+ metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
+ metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
+ metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
+
+ # Metadata Override File Provided
+ # This is based on LLM_KV_NAMES mapping in llama.cpp
+ metadata_override = Metadata.load_metadata_override(metadata_override_path)
+
+ metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
+ metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
+ metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
+ metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
+ metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
+ metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
+ metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
+ metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
+ metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
+ metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
+ metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
+ metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
+
+ metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
+ metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
+ metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
+ metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
+
+ metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
+ metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
+
+ metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
+ metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
+
+ metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
+ metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
+ metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
+
+ metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
+ metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
+ metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
+ metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
+
+ metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
+ metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
+ metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
+ metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
+
+ # Base models are received here as an array of models
+ metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
+
+ # Datasets are received here as an array of datasets
+ metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
+
+ metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
+ metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
+
+ # Direct Metadata Override (via direct cli argument)
+ if model_name is not None:
+ metadata.name = model_name
+
+ return metadata
+
+ @staticmethod
+ def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
+ if metadata_override_path is None or not metadata_override_path.is_file():
+ return {}
+
+ with open(metadata_override_path, "r", encoding="utf-8") as f:
+ return json.load(f)
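+ # An override file is plain JSON keyed by GGUF KV names, e.g. (illustrative):
+ # {"general.name": "My Model", "general.license": "apache-2.0"}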
+
+ @staticmethod
+ def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
+ if model_path is None or not model_path.is_dir():
+ return {}
+
+ model_card_path = model_path / "README.md"
+
+ if not model_card_path.is_file():
+ return {}
+
+ # The model card metadata is assumed to always be in YAML (frontmatter)
+ # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
+ yaml_content: str = ""
+ with open(model_card_path, "r", encoding="utf-8") as f:
+ content = f.read()
+ lines = content.splitlines()
+ lines_yaml = []
+ if len(lines) == 0:
+ # Empty file
+ return {}
+ if lines[0] != "---":
+ # No frontmatter
+ return {}
+ for line in lines[1:]:
+ if line == "---":
+ break # End of frontmatter
+ else:
+ lines_yaml.append(line)
+ yaml_content = "\n".join(lines_yaml) + "\n"
+
+ # Quick hack to fix the Norway problem
+ # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
+ yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
+ # YAML should use 2 spaces instead of tabs;
+ # this issue came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
+ # (I've also sent a PR to fix the model card)
+ yaml_content = yaml_content.replace("\t", "  ")
+
+ if yaml_content:
+ data = yaml.safe_load(yaml_content)
+ if isinstance(data, dict):
+ return data
+ else:
+ logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
+ return {}
+ else:
+ return {}
+
+ @staticmethod
+ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
+ if model_path is None or not model_path.is_dir():
+ return {}
+
+ config_path = model_path / "config.json"
+
+ if not config_path.is_file():
+ return {}
+
+ with open(config_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+ @staticmethod
+ def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
+ if model_path is None or not model_path.is_dir():
+ return {}
+
+ generation_config_path = model_path / "generation_config.json"
+
+ if not generation_config_path.is_file():
+ return {}
+
+ try:
+ with open(generation_config_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+ except (json.JSONDecodeError, IOError):
+ # not all models have valid generation_config.json
+ return {}
+
+ @staticmethod
+ def id_to_title(string):
+ # Convert capitalization into title form unless acronym or version number
+ return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
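+ # e.g. id_to_title("Mixtral-8x7B-Instruct-v0.1") -> "Mixtral 8x7B Instruct v0.1"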
+
+ @staticmethod
+ def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
+ # Hugging Face often stores model IDs as '<org>/<model name>',
+ # so parse it and apply some heuristics to the model name components where possible
+
+ if model_id is None:
+ # model ID missing
+ return None, None, None, None, None, None
+
+ if ' ' in model_id:
+ # model ID is actually a normal human sentence,
+ # which means it's most likely just a plain model name,
+ # not part of the Hugging Face naming standard
+ return model_id, None, None, None, None, None
+
+ if '/' in model_id:
+ # model ID (huggingface style)
+ org_component, model_full_name_component = model_id.split('/', 1)
+ else:
+ # model ID but missing org components
+ org_component, model_full_name_component = None, model_id
+
+ # Check if we erroneously matched against './' or '../' etc...
+ if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
+ org_component = None
+
+ name_parts: list[str] = model_full_name_component.split('-')
+
+ # Remove empty parts
+ for i in reversed(range(len(name_parts))):
+ if len(name_parts[i]) == 0:
+ del name_parts[i]
+
+ name_types: list[
+ set[Literal["basename", "size_label", "finetune", "version", "type"]]
+ ] = [set() for _ in name_parts]
+
+ # Annotate the name
+ for i, part in enumerate(name_parts):
+ # Version
+ if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
+ name_types[i].add("version")
+ # Quant type (should not be there for base models, but still annotated)
+ elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
+ name_types[i].add("type")
+ name_parts[i] = part.upper()
+ # Model size
+ elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
+ part = part.replace("_", ".")
+ # Handle weird bloom-7b1 notation
+ if part[-1].isdecimal():
+ part = part[:-2] + "." + part[-1] + part[-2]
+ # Normalize the size suffixes
+ if len(part) > 1 and part[-2].isdecimal():
+ if part[-1] in "kmbt":
+ part = part[:-1] + part[-1].upper()
+ if total_params != 0:
+ try:
+ label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
+ # Only use it as a size label if it's close or bigger than the model size
+ # Note that LoRA adapters don't necessarily include all layers,
+ # so this is why bigger label sizes are accepted.
+ # Do not use the size label when it's smaller than 1/8 of the model size
+ if (total_params < 0 and label_params < abs(total_params) // 8) or (
+ # Check both directions when the current model isn't a LoRA adapter
+ total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
+ ):
+ # Likely a context length
+ name_types[i].add("finetune")
+ # Lowercase the size when it's a context length
+ part = part[:-1] + part[-1].lower()
+ except ValueError:
+ # Failed to convert the size label to float, use it anyway
+ pass
+ if len(name_types[i]) == 0:
+ name_types[i].add("size_label")
+ name_parts[i] = part
+ # Some easy to recognize finetune names
+ elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
+ if total_params < 0 and part.lower() == "lora":
+ # ignore redundant "lora" in the finetune part when the output is a lora adapter
+ name_types[i].add("type")
+ else:
+ name_types[i].add("finetune")
+
+ # Ignore word-based size labels when there is at least a number-based one present
+ # TODO: should word-based size labels always be removed instead?
+ if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
+ for n, t in zip(name_parts, name_types):
+ if "size_label" in t:
+ if all(c.isalpha() for c in n):
+ t.remove("size_label")
+
+ at_start = True
+ # Find the basename through the annotated name
+ for part, t in zip(name_parts, name_types):
+ if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
+ t.add("basename")
+ else:
+ if at_start:
+ at_start = False
+ if len(t) == 0:
+ t.add("finetune")
+
+ # Remove the basename annotation from trailing version
+ for part, t in zip(reversed(name_parts), reversed(name_types)):
+ if "basename" in t and len(t) > 1:
+ t.remove("basename")
+ else:
+ break
+
+ basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
+ # Deduplicate size labels using an order-preserving 'dict' ('set' does not preserve insertion order)
+ size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
+ finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
+ # TODO: should the basename version always be excluded?
+ # NOTE: multiple finetune versions are joined together
+ version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
+
+ if size_label is None and finetune is None and version is None:
+ # Too ambiguous, output nothing
+ basename = None
+
+ return model_full_name_component, org_component, basename, finetune, version, size_label
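+ # For example (illustrative), "mistralai/Mistral-7B-Instruct-v0.2" parses into:
+ # name="Mistral-7B-Instruct-v0.2", org="mistralai", basename="Mistral",
+ # finetune="Instruct", version="v0.2", size_label="7B"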
+
+ @staticmethod
+ def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
+ # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+
+ # Model Card Heuristics
+ ########################
+ if model_card is not None:
+
+ def use_model_card_metadata(metadata_key: str, model_card_key: str):
+ if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
+ setattr(metadata, metadata_key, model_card.get(model_card_key))
+
+ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
+ # Note: will append to rather than replace any existing value
+ tags_value = model_card.get(model_card_key, None)
+ if tags_value is None:
+ return
+
+ current_value = getattr(metadata, metadata_key, None)
+ if current_value is None:
+ current_value = []
+
+ if isinstance(tags_value, str):
+ current_value.append(tags_value)
+ elif isinstance(tags_value, list):
+ current_value.extend(tags_value)
+
+ setattr(metadata, metadata_key, current_value)
+
+ # llama.cpp's direct internal convention
+ # (definitely not part of the Hugging Face formal/informal standard)
+ #########################################
+ use_model_card_metadata("name", "name")
+ use_model_card_metadata("author", "author")
+ use_model_card_metadata("version", "version")
+ use_model_card_metadata("organization", "organization")
+ use_model_card_metadata("description", "description")
+ use_model_card_metadata("finetune", "finetune")
+ use_model_card_metadata("basename", "basename")
+ use_model_card_metadata("size_label", "size_label")
+ use_model_card_metadata("source_url", "url")
+ use_model_card_metadata("source_doi", "doi")
+ use_model_card_metadata("source_uuid", "uuid")
+ use_model_card_metadata("source_repo_url", "repo_url")
+
+ # llama.cpp's Hugging Face style convention
+ # (definitely not part of the Hugging Face formal/informal standard... but with model_ prepended to match their style)
+ ###########################################
+ use_model_card_metadata("name", "model_name")
+ use_model_card_metadata("author", "model_author")
+ use_model_card_metadata("version", "model_version")
+ use_model_card_metadata("organization", "model_organization")
+ use_model_card_metadata("description", "model_description")
+ use_model_card_metadata("finetune", "model_finetune")
+ use_model_card_metadata("basename", "model_basename")
+ use_model_card_metadata("size_label", "model_size_label")
+ use_model_card_metadata("source_url", "model_url")
+ use_model_card_metadata("source_doi", "model_doi")
+ use_model_card_metadata("source_uuid", "model_uuid")
+ use_model_card_metadata("source_repo_url", "model_repo_url")
+
+ # Hugging Face Direct Convention
+ #################################
+
+ # Not part of the Hugging Face model card standard, but some model creators use it,
+ # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+ use_model_card_metadata("name", "model_name")
+ use_model_card_metadata("author", "model_creator")
+ use_model_card_metadata("basename", "model_type")
+
+ if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
+ # This represents the parent models that this is based on
+ # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
+ # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
+ metadata_base_models = []
+ base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
+
+ if base_model_value is not None:
+ if isinstance(base_model_value, str):
+ metadata_base_models.append(base_model_value)
+ elif isinstance(base_model_value, list):
+ metadata_base_models.extend(base_model_value)
+
+ if metadata.base_models is None:
+ metadata.base_models = []
+
+ for model_id in metadata_base_models:
+ # NOTE: model size of base model is assumed to be similar to the size of the current model
+ base_model = {}
+ if isinstance(model_id, str):
+ if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
+ base_model["repo_url"] = model_id
+
+ # Check if Hugging Face ID is present in URL
+ if "huggingface.co" in model_id:
+ match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
+ if match:
+ model_id_component = match.group(1)
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
+
+ # Populate model dictionary with extracted components
+ if model_full_name_component is not None:
+ base_model["name"] = Metadata.id_to_title(model_full_name_component)
+ if org_component is not None:
+ base_model["organization"] = Metadata.id_to_title(org_component)
+ if version is not None:
+ base_model["version"] = version
+
+ else:
+ # Likely a Hugging Face ID
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+
+ # Populate model dictionary with extracted components
+ if model_full_name_component is not None:
+ base_model["name"] = Metadata.id_to_title(model_full_name_component)
+ if org_component is not None:
+ base_model["organization"] = Metadata.id_to_title(org_component)
+ if version is not None:
+ base_model["version"] = version
+ if org_component is not None and model_full_name_component is not None:
+ base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+
+ elif isinstance(model_id, dict):
+ base_model = model_id
+
+ else:
+ logger.error(f"base model entry '{str(model_id)}' not in a known format")
+
+ metadata.base_models.append(base_model)
+
+ if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
+ # This represents the datasets that this was trained from
+ metadata_datasets = []
+ dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
+
+ if dataset_value is not None:
+ if isinstance(dataset_value, str):
+ metadata_datasets.append(dataset_value)
+ elif isinstance(dataset_value, list):
+ metadata_datasets.extend(dataset_value)
+
+ if metadata.datasets is None:
+ metadata.datasets = []
+
+ for dataset_id in metadata_datasets:
+ # NOTE: dataset IDs are parsed with the same heuristics as model IDs
+ dataset = {}
+ if isinstance(dataset_id, str):
+ if dataset_id.startswith(("http://", "https://", "ssh://")):
+ dataset["repo_url"] = dataset_id
+
+ # Check if Hugging Face ID is present in URL
+ if "huggingface.co" in dataset_id:
+ match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
+ if match:
+ dataset_id_component = match.group(1)
+ dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
+
+ # Populate dataset dictionary with extracted components
+ if dataset_name_component is not None:
+ dataset["name"] = Metadata.id_to_title(dataset_name_component)
+ if org_component is not None:
+ dataset["organization"] = Metadata.id_to_title(org_component)
+ if version is not None:
+ dataset["version"] = version
+
+ else:
+ # Likely a Hugging Face ID
+ dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
+
+ # Populate dataset dictionary with extracted components
+ if dataset_name_component is not None:
+ dataset["name"] = Metadata.id_to_title(dataset_name_component)
+ if org_component is not None:
+ dataset["organization"] = Metadata.id_to_title(org_component)
+ if version is not None:
+ dataset["version"] = version
+ if org_component is not None and dataset_name_component is not None:
+ dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
+
+ elif isinstance(dataset_id, dict):
+ dataset = dataset_id
+
+ else:
+ logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
+
+ metadata.datasets.append(dataset)
+
+ use_model_card_metadata("license", "license")
+ use_model_card_metadata("license_name", "license_name")
+ use_model_card_metadata("license_link", "license_link")
+
+ use_array_model_card_metadata("tags", "tags")
+ use_array_model_card_metadata("tags", "pipeline_tag")
+
+ use_array_model_card_metadata("languages", "languages")
+ use_array_model_card_metadata("languages", "language")
+
+ # Hugging Face Parameter Heuristics
+ ####################################
+
+ if hf_params is not None:
+
+ hf_name_or_path = hf_params.get("_name_or_path")
+ if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
+ # Use _name_or_path only if it's actually a model name and not some filesystem path
+ # e.g. 'meta-llama/Llama-2-7b-hf'
+ model_id = hf_name_or_path
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+ if metadata.name is None and model_full_name_component is not None:
+ metadata.name = Metadata.id_to_title(model_full_name_component)
+ if metadata.organization is None and org_component is not None:
+ metadata.organization = Metadata.id_to_title(org_component)
+ if metadata.basename is None and basename is not None:
+ metadata.basename = basename
+ if metadata.finetune is None and finetune is not None:
+ metadata.finetune = finetune
+ if metadata.version is None and version is not None:
+ metadata.version = version
+ if metadata.size_label is None and size_label is not None:
+ metadata.size_label = size_label
+
+ # Directory Folder Name Fallback Heuristics
+ ############################################
+ if model_path is not None:
+ model_id = model_path.name
+ model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+ if metadata.name is None and model_full_name_component is not None:
+ metadata.name = Metadata.id_to_title(model_full_name_component)
+ if metadata.organization is None and org_component is not None:
+ metadata.organization = Metadata.id_to_title(org_component)
+ if metadata.basename is None and basename is not None:
+ metadata.basename = basename
+ if metadata.finetune is None and finetune is not None:
+ metadata.finetune = finetune
+ if metadata.version is None and version is not None:
+ metadata.version = version
+ if metadata.size_label is None and size_label is not None:
+ metadata.size_label = size_label
+
+ return metadata
+
+ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
+ assert self.name is not None
+
+ if self.sampling_sequence is not None:
+ gguf_writer.add_sampling_sequence(self.sampling_sequence)
+ if self.sampling_top_k is not None:
+ gguf_writer.add_sampling_top_k(self.sampling_top_k)
+ if self.sampling_top_p is not None:
+ gguf_writer.add_sampling_top_p(self.sampling_top_p)
+ if self.sampling_min_p is not None:
+ gguf_writer.add_sampling_min_p(self.sampling_min_p)
+ if self.sampling_xtc_probability is not None:
+ gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
+ if self.sampling_xtc_threshold is not None:
+ gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
+ if self.sampling_temp is not None:
+ gguf_writer.add_sampling_temp(self.sampling_temp)
+ if self.sampling_penalty_last_n is not None:
+ gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
+ if self.sampling_penalty_repeat is not None:
+ gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
+ if self.sampling_mirostat is not None:
+ gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
+ if self.sampling_mirostat_tau is not None:
+ gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
+ if self.sampling_mirostat_eta is not None:
+ gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
+
+ gguf_writer.add_name(self.name)
+
+ if self.author is not None:
+ gguf_writer.add_author(self.author)
+ if self.version is not None:
+ gguf_writer.add_version(self.version)
+ if self.organization is not None:
+ gguf_writer.add_organization(self.organization)
+
+ if self.finetune is not None:
+ gguf_writer.add_finetune(self.finetune)
+ if self.basename is not None:
+ gguf_writer.add_basename(self.basename)
+
+ if self.description is not None:
+ gguf_writer.add_description(self.description)
+ if self.quantized_by is not None:
+ gguf_writer.add_quantized_by(self.quantized_by)
+
+ if self.size_label is not None:
+ gguf_writer.add_size_label(self.size_label)
+
+ if self.license is not None:
+ if isinstance(self.license, list):
+ gguf_writer.add_license(",".join(self.license))
+ else:
+ gguf_writer.add_license(self.license)
+ if self.license_name is not None:
+ gguf_writer.add_license_name(self.license_name)
+ if self.license_link is not None:
+ gguf_writer.add_license_link(self.license_link)
+
+ if self.url is not None:
+ gguf_writer.add_url(self.url)
+ if self.doi is not None:
+ gguf_writer.add_doi(self.doi)
+ if self.uuid is not None:
+ gguf_writer.add_uuid(self.uuid)
+ if self.repo_url is not None:
+ gguf_writer.add_repo_url(self.repo_url)
+
+ if self.source_url is not None:
+ gguf_writer.add_source_url(self.source_url)
+ if self.source_doi is not None:
+ gguf_writer.add_source_doi(self.source_doi)
+ if self.source_uuid is not None:
+ gguf_writer.add_source_uuid(self.source_uuid)
+ if self.source_repo_url is not None:
+ gguf_writer.add_source_repo_url(self.source_repo_url)
+
+ if self.base_models is not None:
+ gguf_writer.add_base_model_count(len(self.base_models))
+ for key, base_model_entry in enumerate(self.base_models):
+ if "name" in base_model_entry:
+ gguf_writer.add_base_model_name(key, base_model_entry["name"])
+ if "author" in base_model_entry:
+ gguf_writer.add_base_model_author(key, base_model_entry["author"])
+ if "version" in base_model_entry:
+ gguf_writer.add_base_model_version(key, base_model_entry["version"])
+ if "organization" in base_model_entry:
+ gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
+ if "description" in base_model_entry:
+ gguf_writer.add_base_model_description(key, base_model_entry["description"])
+ if "url" in base_model_entry:
+ gguf_writer.add_base_model_url(key, base_model_entry["url"])
+ if "doi" in base_model_entry:
+ gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
+ if "uuid" in base_model_entry:
+ gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
+ if "repo_url" in base_model_entry:
+ gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+ if self.datasets is not None:
+ gguf_writer.add_dataset_count(len(self.datasets))
+ for key, dataset_entry in enumerate(self.datasets):
+ if "name" in dataset_entry:
+ gguf_writer.add_dataset_name(key, dataset_entry["name"])
+ if "author" in dataset_entry:
+ gguf_writer.add_dataset_author(key, dataset_entry["author"])
+ if "version" in dataset_entry:
+ gguf_writer.add_dataset_version(key, dataset_entry["version"])
+ if "organization" in dataset_entry:
+ gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
+ if "description" in dataset_entry:
+ gguf_writer.add_dataset_description(key, dataset_entry["description"])
+ if "url" in dataset_entry:
+ gguf_writer.add_dataset_url(key, dataset_entry["url"])
+ if "doi" in dataset_entry:
+ gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
+ if "uuid" in dataset_entry:
+ gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
+ if "repo_url" in dataset_entry:
+ gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
+
+ if self.tags is not None:
+ gguf_writer.add_tags(self.tags)
+ if self.languages is not None:
+ gguf_writer.add_languages(self.languages)
diff --git a/llama.cpp/gguf-py/gguf/py.typed b/llama.cpp/gguf-py/gguf/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/py.typed
diff --git a/llama.cpp/gguf-py/gguf/quants.py b/llama.cpp/gguf-py/gguf/quants.py
new file mode 100644
index 0000000..31845ea
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/quants.py
@@ -0,0 +1,1318 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Sequence
+from math import log2, ceil
+
+from numpy.typing import DTypeLike
+
+from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
+from .lazy import LazyNumpyTensor
+
+import numpy as np
+
+
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+ if shape[-1] % block_size != 0:
+ raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
+ return (*shape[:-1], shape[-1] // block_size * type_size)
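+# e.g. for Q4_0 (block_size=32, type_size=18) a (4096, 4096) tensor maps to a
+# (4096, 2304) byte shape, since 4096 // 32 * 18 == 2304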
+
+
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+ if shape[-1] % type_size != 0:
+ raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
+ return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
+def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
+ rows = arr.reshape((-1, arr.shape[-1]))
+ osize = 1
+ for dim in oshape:
+ osize *= dim
+ out = np.empty(shape=osize, dtype=otype)
+ # compute over groups of 16 rows (arbitrary, but seems good for performance)
+ n_groups = (rows.shape[0] // 16) or 1
+ np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out)
+ return out.reshape(oshape)
+
+
+# round away from zero
+# ref: https://stackoverflow.com/a/59143326/22827863
+def np_roundf(n: np.ndarray) -> np.ndarray:
+ a = abs(n)
+ floored = np.floor(a)
+ b = floored + np.floor(2 * (a - floored))
+ return np.sign(n) * b
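+# e.g. np_roundf(2.5) -> 3.0 and np_roundf(-2.5) -> -3.0,
+# whereas np.round rounds half to even (np.round(2.5) -> 2.0)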
+
+
+class QuantError(Exception): ...
+
+
+_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
+
+
+def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ if qtype == GGMLQuantizationType.F32:
+ return data.astype(np.float32, copy=False)
+ elif qtype == GGMLQuantizationType.F16:
+ return data.astype(np.float16, copy=False)
+ elif (q := _type_traits.get(qtype)) is not None:
+ return q.quantize(data)
+ else:
+ raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
+
+
+def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ if qtype == GGMLQuantizationType.F32:
+ return data.view(np.float32)
+ elif qtype == GGMLQuantizationType.F16:
+ return data.view(np.float16).astype(np.float32)
+ elif (q := _type_traits.get(qtype)) is not None:
+ return q.dequantize(data)
+ else:
+ raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
+
+
+class __Quant(ABC):
+ qtype: GGMLQuantizationType
+ block_size: int
+ type_size: int
+
+ grid: np.ndarray[Any, np.dtype[np.float32]] | None = None
+ grid_shape: tuple[int, int] = (0, 0)
+ grid_map: tuple[int | float, ...] = ()
+ grid_hex: bytes | None = None
+
+ def __init__(self):
+ return TypeError("Quant conversion classes can't have instances")
+
+ def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
+ cls.qtype = qtype
+ cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
+ cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
+ cls.__quantize_array,
+ meta_noop=(np.uint8, cls.__shape_to_bytes)
+ )
+ cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
+ cls.__dequantize_array,
+ meta_noop=(np.float32, cls.__shape_from_bytes)
+ )
+ assert qtype not in _type_traits
+ _type_traits[qtype] = cls
+
+ @classmethod
+ def init_grid(cls):
+ if cls.grid is not None or cls.grid_hex is None:
+ return
+
+ bits_per_elem = ceil(log2(len(cls.grid_map)))
+ assert bits_per_elem != 0, cls.qtype.name
+ elems_per_byte = 8 // bits_per_elem
+
+ grid = np.frombuffer(cls.grid_hex, dtype=np.uint8)
+ # decode hexadecimal chars from grid
+ grid = grid.reshape((-1, 2))
+ grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2))
+ grid = grid[..., 0] | grid[..., 1]
+ # unpack the grid values
+ grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte))
+ grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1))
+ grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1))
+ grid = np.take_along_axis(grid_map, grid, axis=-1)
+ cls.grid = grid.reshape((1, 1, *cls.grid_shape))
+
+ @classmethod
+ @abstractmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ raise NotImplementedError
+
+ @classmethod
+ def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+ rows = rows.astype(np.float32, copy=False)
+ shape = rows.shape
+ n_blocks = rows.size // cls.block_size
+ blocks = rows.reshape((n_blocks, cls.block_size))
+ blocks = cls.quantize_blocks(blocks)
+ assert blocks.dtype == np.uint8
+ assert blocks.shape[-1] == cls.type_size
+ return blocks.reshape(cls.__shape_to_bytes(shape))
+
+ @classmethod
+ def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+ rows = rows.view(np.uint8)
+ shape = rows.shape
+ n_blocks = rows.size // cls.type_size
+ blocks = rows.reshape((n_blocks, cls.type_size))
+ blocks = cls.dequantize_blocks(blocks)
+ assert blocks.dtype == np.float32
+ assert blocks.shape[-1] == cls.block_size
+ return blocks.reshape(cls.__shape_from_bytes(shape))
+
+ @classmethod
+ def __shape_to_bytes(cls, shape: Sequence[int]):
+ return quant_shape_to_byte_shape(shape, cls.qtype)
+
+ @classmethod
+ def __shape_from_bytes(cls, shape: Sequence[int]):
+ return quant_shape_from_byte_shape(shape, cls.qtype)
+
+ @classmethod
+ def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
+ return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
+
+ @classmethod
+ def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
+ cls.init_grid()
+ return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
+
+ @classmethod
+ def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+ pass
+
+ @classmethod
+ def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+ pass
+
+ @classmethod
+ def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
+ return tensor.shape[-1] % cls.block_size == 0
+
+ @classmethod
+ def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+ if not cls.can_quantize(tensor):
+ raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
+ if isinstance(tensor, LazyNumpyTensor):
+ return cls.__quantize_lazy(tensor)
+ else:
+ return cls.__quantize_array(tensor)
+
+ @classmethod
+ def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+ if isinstance(tensor, LazyNumpyTensor):
+ return cls.__dequantize_lazy(tensor)
+ else:
+ return cls.__dequantize_array(tensor)
+
+
+class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
+ @classmethod
+ # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n = blocks.view(np.uint32)
+ # force nan to quiet
+ n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
+ # round to nearest even
+ n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
+ return n.astype(np.uint16).view(np.uint8)
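+ # i.e. keep the high 16 bits of the f32 bit pattern with round-to-nearest-even,
+ # e.g. 1.0 (0x3F800000) -> 0x3F80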
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
+
+
+class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ imax = abs(blocks).argmax(axis=-1, keepdims=True)
+ max = np.take_along_axis(blocks, imax, axis=-1)
+
+ d = max / -8
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+
+ qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
+ qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
+
+ d = d.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([d, qs], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, qs = np.hsplit(blocks, [2])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8)
+
+ return (d * qs.astype(np.float32))
+
+
+class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ max = blocks.max(axis=-1, keepdims=True)
+ min = blocks.min(axis=-1, keepdims=True)
+
+ d = (max - min) / 15
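+        # asymmetric quantization: [min, max] maps onto [0, 15], and min is
+        # stored alongside d so dequantization is q * d + min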
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+
+ qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
+ qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
+
+ d = d.astype(np.float16).view(np.uint8)
+ m = min.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([d, m, qs], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ m, qs = np.hsplit(rest, [2])
+
+ d = d.view(np.float16).astype(np.float32)
+ m = m.view(np.float16).astype(np.float32)
+
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32)
+
+ return (d * qs) + m
+
+
+class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ imax = abs(blocks).argmax(axis=-1, keepdims=True)
+ max = np.take_along_axis(blocks, imax, axis=-1)
+
+ d = max / -16
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+
+ qs = q.reshape((n_blocks, 2, cls.block_size // 2))
+ qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
+
+ qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
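+        # the fifth bit of each of the 32 quants is gathered into the 4 bytes of qh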
+
+ d = d.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([d, qh, qs], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qh, qs = np.hsplit(rest, [4])
+
+ d = d.view(np.float16).astype(np.float32)
+ qh = qh.view(np.uint32)
+
+        qh = qh.reshape((n_blocks, 1)) >> np.arange(32, dtype=np.uint32).reshape((1, 32))
+ ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qh = (qh & np.uint32(0x01)).astype(np.uint8)
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
+
+ qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16)
+
+ return (d * qs.astype(np.float32))
+
+
+class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ max = blocks.max(axis=-1, keepdims=True)
+ min = blocks.min(axis=-1, keepdims=True)
+
+ d = (max - min) / 31
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+
+ qs = q.reshape((n_blocks, 2, cls.block_size // 2))
+ qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
+
+ qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4)
+
+ d = d.astype(np.float16).view(np.uint8)
+ m = min.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([d, m, qh, qs], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ m, rest = np.hsplit(rest, [2])
+ qh, qs = np.hsplit(rest, [4])
+
+ d = d.view(np.float16).astype(np.float32)
+ m = m.view(np.float16).astype(np.float32)
+ qh = qh.view(np.uint32)
+
+        qh = qh.reshape((n_blocks, 1)) >> np.arange(32, dtype=np.uint32).reshape((1, 32))
+ ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qh = (qh & np.uint32(0x01)).astype(np.uint8)
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1))
+
+ qs = (ql | (qh << np.uint8(4))).astype(np.float32)
+
+ return (d * qs) + m
+
+
+class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
+ @classmethod
+ # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+
+ d = abs(blocks).max(axis=1, keepdims=True) / 127
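+        # absmax / 127 makes the largest-magnitude element round to exactly +/-127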
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np_roundf(blocks * id)
+
+ # (n_blocks, 2)
+ d = d.astype(np.float16).view(np.uint8)
+ # (n_blocks, block_size)
+ qs = qs.astype(np.int8).view(np.uint8)
+
+ return np.concatenate([d, qs], axis=1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ d, x = np.split(blocks, [2], axis=1)
+ d = d.view(np.float16).astype(np.float32)
+ x = x.view(np.int8).astype(np.float32)
+
+ return (x * d)
+
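+# A minimal Q8_0 round-trip sketch (illustrative comment only, not executed):
+#
+#     data = np.random.rand(4, 64).astype(np.float32)
+#     packed = Q8_0.quantize(data)        # uint8, shape (4, 68): 34 bytes per 32-value block
+#     restored = Q8_0.dequantize(packed)  # float32, shape (4, 64), equal up to quantization error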
+
+class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ scales, rest = np.hsplit(blocks, [QK_K // 16])
+ qs, rest = np.hsplit(rest, [QK_K // 4])
+ d, dmin = np.hsplit(rest, [2])
+
+ d = d.view(np.float16).astype(np.float32)
+ dmin = dmin.view(np.float16).astype(np.float32)
+
+ # (n_blocks, 16, 1)
+ dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
+ ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1))
+
+ shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
+
+ qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & np.uint8(3)
+
+ qs = qs.reshape((n_blocks, QK_K // 16, 16)).astype(np.float32)
+
+ qs = dl * qs - ml
+
+ return qs.reshape((n_blocks, -1))
+
+
+class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ hmask, rest = np.hsplit(blocks, [QK_K // 8])
+ qs, rest = np.hsplit(rest, [QK_K // 4])
+ scales, d = np.hsplit(rest, [12])
+
+ d = d.view(np.float16).astype(np.float32)
+
+        # The scales are packed at 6 bits each in this pattern:
+ # 0: IIIIAAAA
+ # 1: JJJJBBBB
+ # 2: KKKKCCCC
+ # 3: LLLLDDDD
+ # 4: MMMMEEEE
+ # 5: NNNNFFFF
+ # 6: OOOOGGGG
+ # 7: PPPPHHHH
+ # 8: MMIIEEAA
+ # 9: NNJJFFBB
+ # 10: OOKKGGCC
+ # 11: PPLLHHDD
+ lscales, hscales = np.hsplit(scales, [8])
+ lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
+ lscales = lscales.reshape((n_blocks, 16))
+ hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1))
+ hscales = hscales.reshape((n_blocks, 16))
+ scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4))
+ scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32)
+
+ dl = (d * scales).reshape((n_blocks, 16, 1))
+
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
+        qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8, 1))
+ ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3)
+ qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1))
+ qh = qh ^ np.uint8(1) # strangely, the offset is zero when the bitmask is 1
+ q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32)
+
+ return (dl * q).reshape((n_blocks, QK_K))
+
+
+class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
+ K_SCALE_SIZE = 12
+
+ @staticmethod
+ def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ n_blocks = scales.shape[0]
+ scales = scales.view(np.uint8)
+ ### Unpacking the following: ###
+ # 0 EEAAAAAA
+ # 1 FFBBBBBB
+ # 2 GGCCCCCC
+ # 3 HHDDDDDD
+ # 4 eeaaaaaa
+ # 5 ffbbbbbb
+ # 6 ggcccccc
+ # 7 hhdddddd
+ # 8 eeeeEEEE
+ # 9 ffffFFFF
+ # 10 ggggGGGG
+ # 11 hhhhHHHH
+ scales = scales.reshape((n_blocks, 3, 4))
+ d, m, m_d = np.split(scales, 3, axis=-2)
+
+ sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1)
+ min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1)
+
+ return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ dmin, rest = np.hsplit(rest, [2])
+ scales, qs = np.hsplit(rest, [cls.K_SCALE_SIZE])
+
+ d = d.view(np.float16).astype(np.float32)
+ dmin = dmin.view(np.float16).astype(np.float32)
+
+ sc, m = Q4_K.get_scale_min(scales)
+
+ d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
+ dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))
+
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32)
+
+ return (d * qs - dm).reshape((n_blocks, QK_K))
+
+
+class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ dmin, rest = np.hsplit(rest, [2])
+ scales, rest = np.hsplit(rest, [Q4_K.K_SCALE_SIZE])
+ qh, qs = np.hsplit(rest, [QK_K // 8])
+
+ d = d.view(np.float16).astype(np.float32)
+ dmin = dmin.view(np.float16).astype(np.float32)
+
+ sc, m = Q4_K.get_scale_min(scales)
+
+ d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
+ dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))
+
+ ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+        qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8, 1))
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
+ qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32))
+ q = (ql | (qh << np.uint8(4))).astype(np.float32)
+
+ return (d * q - dm).reshape((n_blocks, QK_K))
+
+
+class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ ql, rest = np.hsplit(blocks, [QK_K // 2])
+ qh, rest = np.hsplit(rest, [QK_K // 4])
+ scales, d = np.hsplit(rest, [QK_K // 16])
+
+ scales = scales.view(np.int8).astype(np.float32)
+ d = d.view(np.float16).astype(np.float32)
+ d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
+
+ ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32))
+ qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
+ qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32))
+ q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32)
+ q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32)
+
+ return (d * q).reshape((n_blocks, QK_K))
+
+
+class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d = abs(blocks).max(axis=-1, keepdims=True)
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np_roundf(blocks * id)
+ qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)
+
+ qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):]
+ qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
+ qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1))
+ qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1))
+ qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1))
+ qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1))
+ qh = np.sum(qh, axis=-2).reshape((n_blocks, -1))
+ qs = np.concatenate([qs0, qs1, qh], axis=-1)
+ qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243
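+        # five base-3 digits fit in one byte since 3**5 == 243 <= 256; the
+        # ceil(q * 256 / 243) scaling lets dequantization peel off the leading
+        # ternary digit as (q * 3) >> 8 (see dequantize_blocks below)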
+
+ qs = qs.astype(np.uint8)
+ d = d.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([qs, d], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ qs, rest = np.hsplit(blocks, [(QK_K - 4 * QK_K // 64) // 5])
+ qh, d = np.hsplit(rest, [QK_K // 64])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ qs0, qs1 = qs[..., :32], qs[..., 32:]
+ qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
+ qs0 = qs0.reshape((n_blocks, -1))
+ qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1))
+ qs1 = qs1.reshape((n_blocks, -1))
+ qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1))
+ qh = qh.reshape((n_blocks, -1))
+ qs = np.concatenate([qs0, qs1, qh], axis=-1)
+ qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1)
+
+ return (d * qs.astype(np.float32))
+
+
+class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d = abs(blocks).max(axis=-1, keepdims=True)
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np_roundf(blocks * id)
+ qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8)
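+        # qs now holds ternary digits {0, 1, 2}; four of them are packed per byte below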
+
+ qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
+ qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :]
+ qs = qs.reshape((n_blocks, -1))
+
+ d = d.astype(np.float16).view(np.uint8)
+
+ return np.concatenate([qs, d], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ qs, d = np.hsplit(blocks, [QK_K // 4])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1))
+ qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1)
+
+ return (d * qs.astype(np.float32))
+
+
+class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
+ # e2m1 values (doubled)
+ # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+ kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
+
+ @staticmethod
+ # see ggml_e8m0_to_fp32_half in ggml-impl.h
+ def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
+ bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
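+        # e.g. x == 127: an exponent field of 126 gives 2**-1 == 0.5, half of
+        # the e8m0 value 2**(127 - 127) == 1.0 (hence the _half in the name)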
+ return bits.view(np.float32)
+
+ @classmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d = abs(blocks).max(axis=-1, keepdims=True)
+
+ with np.errstate(divide="ignore"):
+ e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+
+ d = cls.e8m0_to_fp32_half(e)
+
+ kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
+
+ errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
+ best = np.argmin(errs, axis=-1, keepdims=True)
+
+ qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
+ qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
+
+ qs = qs.reshape((n_blocks, cls.block_size // 2))
+
+ return np.concatenate([e, qs], axis=-1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ e, qs = np.hsplit(blocks, [1])
+
+ d = cls.e8m0_to_fp32_half(e)
+
+ qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
+ qs = (qs & np.uint8(0x0F)).view(np.int8)
+
+ kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
+ qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
+
+ return (d * qs.astype(np.float32))
+
+
+class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
+ ksigns: bytes = (
+ b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
+ b"\x90\x11\x12\x93\x14\x95\x96\x17\x18\x99\x9a\x1b\x9c\x1d\x1e\x9f"
+ b"\xa0\x21\x22\xa3\x24\xa5\xa6\x27\x28\xa9\xaa\x2b\xac\x2d\x2e\xaf"
+ b"\x30\xb1\xb2\x33\xb4\x35\x36\xb7\xb8\x39\x3a\xbb\x3c\xbd\xbe\x3f"
+ b"\xc0\x41\x42\xc3\x44\xc5\xc6\x47\x48\xc9\xca\x4b\xcc\x4d\x4e\xcf"
+ b"\x50\xd1\xd2\x53\xd4\x55\x56\xd7\xd8\x59\x5a\xdb\x5c\xdd\xde\x5f"
+ b"\x60\xe1\xe2\x63\xe4\x65\x66\xe7\xe8\x69\x6a\xeb\x6c\xed\xee\x6f"
+ b"\xf0\x71\x72\xf3\x74\xf5\xf6\x77\x78\xf9\xfa\x7b\xfc\x7d\x7e\xff"
+ )
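+    # each entry's low 7 bits equal its own index; the top bit doubles as the
+    # sign of the 8th element, chosen so that every entry has an even popcount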
+
+ # iq2xxs_grid, but with each byte of the original packed in 2 bits,
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
+ grid_shape = (256, 8)
+ grid_map = (0x08, 0x19, 0x2b)
+ grid_hex = (
+ b"00000200050008000a00110014002000220028002a0041004400500058006100"
+ b"6400800082008a00a20001010401100115014001840198010002020222028202"
+ b"010404041004210424044004420448046004810484049004a404000502050805"
+ b"200546056905800591050906100640068406a406000805080808140828084108"
+ b"440850085208880804094009020a140a01100410101021104010601084109010"
+ b"951000110811201150115a118011241245120014081420142514491480141815"
+ b"6215001616160118041810184018811800190519a019511a002002200a204420"
+ b"6120802082202921482100220222012404241024402456240025412564259026"
+ b"082820289428442a014004401040184021402440404048405640604081408440"
+ b"9040004120416141804185410142104248425642684200440844204480449944"
+ b"124524450046014804481048404845480049584961498249454a904a00500850"
+ b"1150195020508050885004514251a4519152905492540a550156545600581158"
+ b"195864584059085a046010604060686000615561186260620064056410651265"
+ b"84654268008002800a8041808280048118814081118201840484108415844084"
+ b"608400854685948509864086608602880489118a0490109024904090a1901691"
+ b"8091459200942294449451958198209902a050a085a009a100a218a450a804a9"
+ )
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, qs = np.hsplit(blocks, [2])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ qs = qs.view(np.uint32).reshape(n_blocks, -1, 2)
+
+ db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25)
+ db = db.reshape((n_blocks, -1, 1, 1))
+
+ # get the sign indices and unpack the bits
+ signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
+ ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
+        signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 1, 8))
+ signs = signs & np.uint8(0x01)
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
+ signs = signs.reshape((n_blocks, -1, 4, 8))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 4, 8))
+
+ return (db * grid * signs).reshape((n_blocks, -1))
+
+
+class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS):
+ # iq2xs_grid, but with each byte of the original packed in 2 bits,
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
+ grid_shape = (512, 8)
+ grid_map = (0x08, 0x19, 0x2b)
+ grid_hex = (
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
+ b"49005000520055005800610064008000820085008800910094009900a0000101"
+ b"04010601090110011201150118011a0121012401400142014501480151015401"
+ b"6001680181018401900100020202050208021102140220024102440250025502"
+ b"80028a0201040404060409041004120415041804210424044004420445044804"
+ b"5104540456046004810484049004000502050505080511051405200541054405"
+ b"500561058005010604061006260640064206840600080208050808080a081108"
+ b"14082008250841084408500858088008a008aa08010904091009400981098909"
+ b"000a200a280a960aa00a01100410061009101010121015101810211024104010"
+ b"4210451048105110541060106a10811084109010001102110511081111111411"
+ b"2011411144115011801194119611011204120612101240126012001402140514"
+ b"0814111414142014411444144914501464148014011504151015401500161416"
+ b"49160118041810181218401854188618001905196619511aa91a002002200520"
+ b"08200a201120142020204120442050208020a020012104211021402148216521"
+ b"002222228022a82201240424102429244024002541255225992501261a26a626"
+ b"002808280a28202855288828a22868299029082a202a822a882a8a2a01400440"
+ b"0640094010401240154018402140244040404240454048404a40514054406040"
+ b"6540814084409040004102410541084111411441204141414441504180418541"
+ b"a241014204421042124229424042004402440544084411441444194420444144"
+ b"4444504480449444014504451045244540459a4500460a464446504601480448"
+ b"1048404845485448624800491149444950496949044a00500250055008501150"
+ b"145020502850415044505050805001510451105115514051425100524452aa52"
+ b"0154045410542154405460548154a154005508558055885521566856a1560058"
+ b"14584158505899581a5940594259855a0160046010604060546062608660a960"
+ b"006124624a62926200641664106540654565a46501686a682569066a546a626a"
+ b"00800280058008801180148020802a8041804480508080808280a880aa800181"
+ b"0481068110814081518159810082208280828282a082a8820184048410841284"
+ b"158440846084898400854485a58518866a860088088825885a8880888288a888"
+ b"0689228a808a888a968aa88a0190049010904090569084900091229164915692"
+ b"89920094059444945094589429959095929541965198a6984999159a609a00a0"
+ b"02a008a00aa020a02aa0a0a051a159a1a6a100a202a208a22aa280a2a0a240a4"
+ b"95a465a698a60aa820a822a828a8a0a8a8a804a984a986a928aa2aaa91aaaaaa"
+ )
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qs, scales = np.hsplit(rest, [2 * QK_K // 8])
+
+ d = d.view(np.float16).astype(np.float32)
+ qs = qs.view(np.uint16)
+
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
+ db = db.reshape((n_blocks, -1, 1, 1))
+
+ # get the sign indices and unpack the bits
+ signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128)
+ signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1)
+        signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
+ signs = signs & np.uint8(0x01)
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
+ signs = signs.reshape((n_blocks, -1, 2, 8))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 2, 8))
+
+ return (db * grid * signs).reshape((n_blocks, -1))
+
+
+class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S):
+ # iq2s_grid, but with each byte of the original packed in 2 bits,
+ # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2.
+ grid_shape = (1024, 8)
+ grid_map = (0x08, 0x19, 0x2b)
+ grid_hex = (
+ b"00000200050008000a0011001400160019002000220025002800410044004600"
+ b"490050005200550058006100640066006900800082008500880091009400a000"
+ b"a500aa0001010401060109011001120115011801210124014001420145014801"
+ b"510154015601590160016501680181018401900192019501a101a40100020202"
+ b"050208021102140220022a02410244024602490250025502800285028a029402"
+ b"a202010404040604090410041204150418042104240426042904400442044504"
+ b"48044a0451045404560459046004620465048104840486048904900495049804"
+ b"a104a40400050205050508050a05110514051605190520052505280541054405"
+ b"46054905500552055505580561056405800582058505880591059405a0050106"
+ b"0406060609061006150640064506480651065406600681068406900600080208"
+ b"050808081108140816081908200825082a084108440846084908500852085508"
+ b"580861086408800885089408aa08010904091009120915091809210940094509"
+ b"480951095409600981099009000a110a140a220a280a2a0a500a990a01100410"
+ b"0610091010101210151018102110241026104010421045104810511054105610"
+ b"59106010621065106810811084108610901095109810a110a410001102110511"
+ b"08110a1111111411161119112011221125112811411144114611491150115211"
+ b"5511581161116411801182118511881191119411011204120912101215122112"
+ b"2412401245125112541281128412901200140214051408141114141416141914"
+ b"2014251428144114441446144914501452145514581461146414801482148514"
+ b"881491149414a014011504150615091510151215151518152115241540154215"
+ b"4515481551155415601581158415901500160516081611161416201641164416"
+ b"50168016aa160118041806180918101815181818211840184218451848185118"
+ b"541860188118841800190219051908191119141920194119441950196919a219"
+ b"041a101a401a561a00200220052008201120142016201920202025202a204120"
+ b"4420502052205520642080208a209420aa200121042110211221152121214021"
+ b"4221452151215421602181218421902100220a22222228222a22442250228822"
+ b"8a22a82201240424062409241024152418242124242440244224452448245124"
+ b"5424602481248424902400250525082511251425202541254425502566258025"
+ b"0126042610264026592600280528112814284128442850288a28aa2801290429"
+ b"102995290a2a222a642a882a8a2a014004400640094010401240154018401a40"
+ b"21402440264040404240454048404a4051405440564059406040624065408140"
+ b"8440904095409840a140a4400041024105410841114114411641194120412241"
+ b"2541414144414641494150415241554158416141644180418241854188419141"
+ b"9441a04101420442104212421542184224424042454248425142544260428142"
+ b"844200440244054408440a441144144416441944204422442544284441444444"
+ b"46444944504452445544584461446444804482448544884491449444a0440145"
+ b"0445064509451045124515451845214524454045424545454845514554456045"
+ b"6a4581458445904500460246054608461146144620464146444650468046a546"
+ b"0148044809481048124815481848214824484048424845484848514854486048"
+ b"84489048004902490549084911491449204941494449504980499649014a044a"
+ b"104a404a00500250055008501150145016501950205022502550285041504450"
+ b"4650495050505250555058506150645080508250855088509150945001510451"
+ b"0651095110511251155118512151245140514251455148515151545160518151"
+ b"8451905100520552085211521452205241524452505269528052015404540654"
+ b"0954105412541554185421542454405442544554485451545454605481548454"
+ b"9054005502550555085511551455205541554455505580550156045610562656"
+ b"405600580258055808581158145820584158445850585a588058015904591059"
+ b"4059005a195a855aa85a01600460066010601260156018602160246040604560"
+ b"4860516054606060846090600061026105610861116114612061416144615061"
+ b"806199610462106240625662a162006405640864116414642064416444645064"
+ b"806401650465106540654a656865926500669466016804681068656898680069"
+ b"2a69426aa16a0080028005800880118014801980208025804180448050805280"
+ b"5580588061808080858091809480018104810981108112811581188121812481"
+ b"408142814581488151815481818184819081a981008205820a82118214824182"
+ b"4482508201840484068409841084128415841884218440844284458448845184"
+ b"5484608481848484908400850285058508851185148520854185448550858085"
+ b"8a85018604861086298640860088058811881488418844885088a28801890489"
+ b"40896589228a588a5a8a828aa28a019004900990109012901590189024904090"
+ b"4290459048905190549060908190849090900091059111911491419144915091"
+ b"5a910192049210924092a6920094029405940894119414942094419444945094"
+ b"8094969401950495109540959895a19500964696649601980498109826984098"
+ b"a998009949995299909a00a005a00aa014a022a02aa041a044a050a0a2a0aaa0"
+ b"40a165a102a20aa222a228a22aa282a288a28aa2a8a201a404a410a440a489a4"
+ b"a4a400a519a551a60aa828a8a2a854a986a908aa0aaa20aa22aa28aa88aaaaaa"
+ )
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qs, rest = np.hsplit(rest, [QK_K // 8])
+ signs, rest = np.hsplit(rest, [QK_K // 8])
+ qh, scales = np.hsplit(rest, [QK_K // 32])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
+ db = d * (np.float32(0.5) + scales) * np.float32(0.25)
+ db = db.reshape((n_blocks, -1, 1, 1))
+
+ # unpack the sign bits
+        signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
+ signs = signs & np.uint8(0x01)
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
+ signs = signs.reshape((n_blocks, -1, 2, 8))
+
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4))
+ qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 2, 8))
+
+ return (db * grid * signs).reshape((n_blocks, -1))
+
+
+class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS):
+ grid_shape = (256, 4)
+ grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e)
+ grid_hex = (
+ b"0000020004001100130017002000220031004200730075000101030110011201"
+ b"2101250130013201410154017001000202020402110220022202310233023702"
+ b"5102570275020103070310031203250370031304370444045704730475040105"
+ b"0705320552053506640610071407160743076107011003101010121021102310"
+ b"3010321034104710501000110211111120112211011203121012121221123012"
+ b"7212001302132013311346136613011405145014201524154615711505162217"
+ b"4017002002201120132020202220262031204220012103210521102112212121"
+ b"3021632167217021002202221122172220222222372240225522012310231423"
+ b"7023742335245324032527254125742501270327162745270130103012302130"
+ b"2330503065307230003102312031313144314631013203321032253252327232"
+ b"1133333330344734723400350635223555351436363663363337603704401740"
+ b"3540374053405740744120423742404260426642074345430444514464442545"
+ b"4345704505471047124730471250415070500051065126515551145232527252"
+ b"0253535310542354275472540255315550562457425724604460466064602161"
+ b"6161176264623063366344640565526533660367216703700570077010703270"
+ b"5270267140711272457252720073157333736073217441740075027524753076"
+ )
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qs, scales = np.hsplit(rest, [QK_K // 4])
+
+ d = d.view(np.float16).astype(np.float32)
+ scales = scales.view(np.uint32)
+
+ db = d * (np.float32(0.5) + (scales >> 28).astype(np.float32)) * np.float32(0.5)
+ db = db.reshape((n_blocks, -1, 1, 1))
+
+ # get the sign indices and unpack the bits
+ signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4))
+ ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128))
+ signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1))
+ signs = np.take_along_axis(ksigns, signs, axis=-1)
+        signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 1, 8))
+ signs = signs & np.uint8(0x01)
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
+ signs = signs.reshape((n_blocks, -1, 4, 8))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 4, 8))
+
+ return (db * grid * signs).reshape((n_blocks, -1))
+
+
+class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S):
+ grid_shape = (512, 4)
+ grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f)
+ grid_hex = (
+ b"0000010002000500070010001100120014001600200021002500330040004200"
+ b"4500470051005300600062007100740077000001010102010401100111011501"
+ b"2001230127013101350144016101650172010002010205020702100213021602"
+ b"2102250230023402420245024702510253027002730203031103150320032203"
+ b"3103330336034403500352036703710375030004130417042104240432044004"
+ b"4304510470040205040520052205260533054105450547056605730506061106"
+ b"1306310652067106000702070407200722072607330750075407001001100210"
+ b"0410101011101310151017102010221031103410361054105610611072100011"
+ b"0111031106111011141121113011331141115011521170117611001212121512"
+ b"1712201224123212401243125512601272120113041307131013131321132713"
+ b"3013341341136213701303140514121414143114331442144614501454140115"
+ b"1015131521153015321551152016241627164416461601170317101712172117"
+ b"3517411762177017002001200320052007201020122014201620212023202720"
+ b"3020322041204320452050205220672070207320752000210221102113211721"
+ b"2221252131213421422151210122042207222122232230223722412253225722"
+ b"7122742200230223052311232223242331233323422350236623012407242024"
+ b"2324322435244124722475240425112522253725402553257025002602260726"
+ b"2126552661260527112726273027432750270230113013301530173022303130"
+ b"3330353042304430473051306330713001310331053114312131233140316031"
+ b"7231763100321232203232323432503201331033143321332333273330334133"
+ b"4333473355337333033411341634223431345234603464340135103512352535"
+ b"3235443556357335163641360137033720372237353700400440124020402440"
+ b"2740324041405040704002410741114113412241304135414341514155410142"
+ b"0342104215422142334240425742624270420443114313432043224331433543"
+ b"0044024424443744404471440545074521456245134634466046104715473047"
+ b"4347514702501050145022504050445047505250665074500151035105511251"
+ b"2151325172510052115223523052365253520253075310532753445351536553"
+ b"7353015404542054325446541255265551555355425602570457225711601360"
+ b"1560316033606060006120612761646112623462426255626262706200631463"
+ b"2163406325644364626400650365346560650566406611671367007004700770"
+ b"2070227036704070547062700271117124714371457101720472107216722172"
+ b"3072517202733273357353730174057413742074507422754275027631760077"
+ )
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qs, rest = np.hsplit(rest, [QK_K // 4])
+ qh, rest = np.hsplit(rest, [QK_K // 32])
+ signs, scales = np.hsplit(rest, [QK_K // 8])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
+ scales = (scales & 0x0F).reshape((n_blocks, -1))
+ db = d * (1 + 2 * scales)
+ db = db.reshape((n_blocks, -1, 1, 1))
+
+ # unpack the sign bits
+        signs = signs.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8))
+ signs = signs & np.uint8(0x01)
+ signs = np.where(signs == 0, np.float32(1), np.float32(-1))
+ signs = signs.reshape((n_blocks, -1, 4, 8))
+
+        qh = qh.reshape((n_blocks, -1, 1)) >> np.arange(8, dtype=np.uint8)
+ qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1))
+ qs = qs.astype(np.uint16) | (qh << 8)
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 4, 8))
+
+ return (db * grid * signs).reshape((n_blocks, -1))
+
+
+class IQ1_S(__Quant, qtype=GGMLQuantizationType.IQ1_S):
+ # iq1s_grid, with each byte packed into 2 bits
+ # -1, 0, 1 <=> 0, 1, 2
+ grid_shape = (2048, 8)
+ grid_map = (-1, 0, 1)
+ grid_hex = (
+ b"00000200050008000a00110015002000220028002a0045005100540056006500"
+ b"8000820088008a009500a000a200a800aa000401050111011401160119011a01"
+ b"2501410146014901520155015a0161016401660168018501910194019601a501"
+ b"0002020208020a0215022002220228022a024502510259026402690280028202"
+ b"88028a02910295029902a002a202a802aa021104140416042504410449045504"
+ b"5a046404650491049904a5040105040505050605150518051a05290540054505"
+ b"4a0550055105540555055605590560056205650568056a058105910595059805"
+ b"9a05a105a405a505a605a9051406190641064406500652065506580660066106"
+ b"6606690685069106940699060008020808080a0815082008220828082a084508"
+ b"5108560865088008820888088a089508a008a208a808aa080509110914091909"
+ b"2409250941095009510955096109640969099109940996099909a509000a020a"
+ b"080a0a0a150a200a220a280a2a0a450a510a590a610a650a800a820a850a880a"
+ b"8a0a950aa00aa20aa80aaa0a1010111014101910241025104110441050105510"
+ b"58106110641065106910911094109610a110a510011104110611091110111211"
+ b"1511181121112411291145114a11501151115211541155115611591160116511"
+ b"841192119511a111a41111121412161225124012461249125212551258125a12"
+ b"641266128512911294129612a512011406140914141415141814191421142614"
+ b"41144514461448144a1451145414551456145914621465146814841489149014"
+ b"94149514981499149a14a114a414a514a914021505150a151115141515151615"
+ b"191520152215251528152a154115441545154615511552155415551556155915"
+ b"5a1561156415651566156915801582158415851588158a159015911594159515"
+ b"961599159a15a015a215a51501160416051606161516161618161a1621162616"
+ b"401642164416451648164a165116551656165816591661166416651668166916"
+ b"6a1686168a1692169516a416a916111816182518411844184618491850185518"
+ b"58185a1860186118641866186918851891189418a5181019121915191a192119"
+ b"25194219441945194819511954195519561959195a19601965196a1989199119"
+ b"921995199819a119a619a919091a161a241a261a441a461a491a501a521a551a"
+ b"581a611a661a691a851a911a961a9a1a0020022008200a201520202022202520"
+ b"28202a20452051205920612065208020822088208a209520a020a220a520a820"
+ b"aa2005211121142119212521422144214921552158215a216121642165216621"
+ b"8521902196219921a521012208220a22112215222022222228222a2245225122"
+ b"562259226522812288228a2291229522a022a222a822aa220524142416241924"
+ b"252444244524462449245224552458245a2466248524912494249924a124a524"
+ b"0925152521252925402545254825512554255525592562256525682589259025"
+ b"9425952598259a25a125a425a625a92505261026122619262526412649265526"
+ b"6026612669268426862690269a260028022808280a2815282028222828282a28"
+ b"45285128542865288028822888288a28a028a228a828aa280929112914291929"
+ b"2529462949295229552961296429662969298529902996299929a429a529002a"
+ b"022a082a0a2a202a222a282a2a2a452a512a562a592a652a802a822a882a8a2a"
+ b"952aa02aa22aa82aaa2a054011401640254049405240554058405a4061406440"
+ b"664094409940a140a6400041014104410641094112411541164118411a412141"
+ b"26412941454148414a41514154415541564159415a41654168416a4181418441"
+ b"8641904192419541a041a141a241054211421442164225424142524255425a42"
+ b"6442694289429442a5420144154419442944454448444a445144544455445644"
+ b"61446244654468446a44814486448944904492449544a044a144a94401450245"
+ b"05450a4511451445154516451945204525452a45414544454545464549455045"
+ b"5145544555455645584559456145644565456645694582458445854588459145"
+ b"94459545964599459a45a545a845aa450146054609461446154618461a462146"
+ b"2446294640464246454648465046514652465546564659466246654668468146"
+ b"85468a4694469546a146a446a6460548114815481a4825484248494850485548"
+ b"5848614864486648694885489148944896489948a5480149054906490a491049"
+ b"144915491849214924492649404945494a495149524954495549564959496049"
+ b"6249654966496a49864989499249954996499849a149a449a649a949164a444a"
+ b"464a494a554a584a5a4a644a694a944aa54a0150045005500650095012501550"
+ b"1a50215024502950405045504850515054505550565059506550685086508950"
+ b"95509850a050a150a650a9500551085109510a51115114511551165118511951"
+ b"20512551265128512a5141514451455146514951505151515251545155515651"
+ b"585159515a51615164516551665169518251855191519451955196519951a051"
+ b"a551aa5101520652125215521a5221522452425245524a525152545255525652"
+ b"595262526552855290529252955299529a52a452045405541154145415541654"
+ b"185419542154255428542a54415444544554465449544a545054515454545554"
+ b"5654585459545a54615462546454655466546954805488548a54915494549554"
+ b"96549954a154a454a554aa540155025504550555065509551055115512551455"
+ b"1555165519551a55215524552555265529554055415542554455455546554855"
+ b"4955505551555255545555555655585559555a55605561556455655566556855"
+ b"69556a5581558455855589558a559055915594559555965598559955a155a455"
+ b"a555a655a9550056015602560456065608560956115614561556185619562056"
+ b"2156225624562556265628562956415645564656485649564a56505651565256"
+ b"545655565656585659565a566156645665566956825685568656885689568a56"
+ b"915695569a56a256a556a656a856a95604580558065809581058155818582158"
+ b"2a58455848584a58515854585558565858585958605862586458655882588958"
+ b"9058925895589858a158a9580159025905590a59115914591559165919592559"
+ b"41594459455946594959505951595259545955595659585959595a5961596459"
+ b"655966596959815985598959915994599559965998599959a559045a085a155a"
+ b"1a5a205a255a265a295a455a485a495a515a555a565a585a595a625a655a685a"
+ b"6a5a815a8a5a925a955a965a985a9a5aa15a0560146016601960256044605060"
+ b"5560566058605a60616064606660696081609660a56001610461066109611261"
+ b"15612161226126612961456149615161556156615961656166616a6184618a61"
+ b"92619561a161a661a96111621662196240624162466255625662586260628562"
+ b"91629662a56211641264156416641a6421642664296440644264456448644a64"
+ b"516454645564566459645a646064626465648464856489649064926494649564"
+ b"966498649a64a164a464a964056508650a651165156516651965446545654665"
+ b"496550655165546555655665596561656465656566656965866589658a659165"
+ b"9565966599659a65a265a565a665a86502660966156620662666286629664066"
+ b"456648664a66516654665566566658665a666066656668668066826685668a66"
+ b"9466966698669966a066a466a666aa661668196825684168526855685a686168"
+ b"6968856891689868a66801690469106915692169246926692969406941694569"
+ b"4669486951695469556956695969606965696a69826984698a699569a169a469"
+ b"a569a969116a166a186a416a446a496a506a556a586a5a6a646a656a696a866a"
+ b"946a986a9a6aa66a0080028008800a802080228028802a804580508051805480"
+ b"5680598065808080828088808a809580a080a280a880aa800581118114811681"
+ b"1981258141814481498150815281558156815881598164816681698185818981"
+ b"948196819981a5810082028208820a8215822082228228822a82518254825982"
+ b"65828082828288828a829582a082a282a882aa82148419844184448451845584"
+ b"5a846184648469849484998401850985128515851a8526852985408541854585"
+ b"4885518554855585568559855a856585668568856a8581858485868589859085"
+ b"928595859885a68511861686198625864186448649864a865086558659865a86"
+ b"618666866a86858691869a86a4860088028808880a8815882088228828882a88"
+ b"41884588518854885988658869888088828888888a889588a088a288a888aa88"
+ b"05890689118914891689258941894489468949895089528955895a8961896489"
+ b"858996899989a589008a028a088a0a8a158a208a228a288a2a8a458a518a548a"
+ b"568a808a828a888a8a8a958aa08aa28aa88aaa8a059011901690189019902590"
+ b"419046904990559058905a9069906a9085909190949096909990a59001910491"
+ b"069109911091159118911a912191249126912991409145915091519154915591"
+ b"569159916291659184918691929195919891a191a491a691a991059211921492"
+ b"19922592449246924992509252925592589266926992859294929692a9920194"
+ b"04940694109415941894269440944a9451945494559456945894599460946194"
+ b"62946594849486949294949495949894a194a9940095059508950a9510951195"
+ b"14951595169519952195259529952a9541954495459546954995509551955295"
+ b"549555955695589559955a956195649565956695699581958595889591959295"
+ b"94959595969599959a95a095a295a595a895aa95019604961096159619962096"
+ b"2696299645964896499651965296559656965996659668968296849689968a96"
+ b"929694969596a496a696a9960598169819982598419846985098529855985698"
+ b"5a98649865988598919896989998a59804990699099910991299159918991a99"
+ b"209921992499269940994299459948994a995199549955995699599962996599"
+ b"66996a99819984999099929995999a99a199a699059a159a259a449a469a499a"
+ b"509a559a589a619a859a919a949a959a969a00a002a008a00aa015a020a022a0"
+ b"28a02aa045a051a054a056a059a080a082a088a08aa095a0a0a0a2a0a8a0aaa0"
+ b"05a109a111a114a116a119a11aa146a149a151a155a158a15aa161a164a185a1"
+ b"90a192a196a199a102a208a20aa210a219a222a228a22aa245a251a256a259a2"
+ b"65a280a282a288a28aa295a2a0a2a2a2a8a2aaa219a425a441a444a450a454a4"
+ b"55a458a45aa461a465a466a468a469a485a406a509a510a512a515a518a526a5"
+ b"29a542a545a551a554a555a556a559a565a56aa581a584a585a586a589a592a5"
+ b"95a598a505a611a616a61aa621a625a644a646a64aa652a655a656a658a660a6"
+ b"62a686a690a695a696a699a6a1a6a4a6a6a600a802a808a80aa820a822a828a8"
+ b"2aa851a854a856a859a880a882a888a88aa895a8a0a8a2a8a8a8aaa805a914a9"
+ b"19a921a925a941a950a955a95aa961a966a969a990a996a900aa02aa08aa0aaa"
+ b"20aa22aa28aa2aaa51aa54aa56aa80aa82aa88aa8aaa95aaa0aaa2aaa8aaaaaa"
+ )
+
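+    # grid values are shifted by +/- this delta according to a per-group sign
+    # bit in qh (see dequantize_blocks below)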
+ delta = np.float32(0.125)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ qs, qh = np.hsplit(rest, [QK_K // 8])
+
+ d = d.view(np.float16).astype(np.float32)
+ qh = qh.view(np.uint16)
+
+ dl = d * (2 * ((qh >> 12) & 7) + 1)
+ dl = dl.reshape((n_blocks, -1, 1, 1))
+ delta = np.where((qh & np.uint16(0x8000)) == 0, cls.delta, -cls.delta)
+ delta = delta.reshape((n_blocks, -1, 1, 1))
+
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
+ qs = qs.astype(np.uint16) | ((qh & 7) << 8).reshape((n_blocks, -1))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 4, 8))
+
+ return (dl * (grid + delta)).reshape((n_blocks, -1))
+
+
+class IQ1_M(__Quant, qtype=GGMLQuantizationType.IQ1_M):
+ grid_shape = IQ1_S.grid_shape
+ grid_map = IQ1_S.grid_map
+ grid_hex = IQ1_S.grid_hex
+
+ delta = IQ1_S.delta
+
+ # Okay *this* type is weird. It's the only one which stores the f16 scales in multiple parts.
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ qs, rest = np.hsplit(blocks, [QK_K // 8])
+ qh, scales = np.hsplit(rest, [QK_K // 16])
+
+ # The f16 scale is packed across multiple bytes
+ scales = scales.view(np.uint16)
+ d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4))
+ d = d[..., 0] | d[..., 1] | d[..., 2] | d[..., 3]
+ d = d.view(np.float16).astype(np.float32).reshape((n_blocks, 1))
+
+ scales = scales.reshape(n_blocks, -1, 1) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4))
+ scales = (scales & 0x07).reshape((n_blocks, -1))
+ dl = d * (2 * scales + 1)
+ dl = dl.reshape((n_blocks, -1, 2, 1, 1))
+
+ qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
+ qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape((n_blocks, -1))
+
+        delta = np.where((qh & 0x08) == 0, cls.delta, -cls.delta)
+ delta = delta.reshape((n_blocks, -1, 2, 2, 1))
+
+ assert cls.grid is not None
+ grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2)
+ grid = grid.reshape((n_blocks, -1, 2, 2, 8))
+
+ return (dl * (grid + delta)).reshape((n_blocks, -1))
+
+
+class IQ4_NL(__Quant, qtype=GGMLQuantizationType.IQ4_NL):
+ kvalues = (-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113)
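+    # non-uniform 4-bit codebook: dequantization is a table lookup into kvalues, scaled by d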
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, qs = np.hsplit(blocks, [2])
+
+ d = d.view(np.float16).astype(np.float32)
+
+ qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+
+ qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 1))
+
+ kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
+ qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1))
+
+ return (d * qs)
+
+
+class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS):
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n_blocks = blocks.shape[0]
+
+ d, rest = np.hsplit(blocks, [2])
+ scales_h, rest = np.hsplit(rest, [2])
+ scales_l, qs = np.hsplit(rest, [QK_K // 64])
+
+ d = d.view(np.float16).astype(np.float32)
+ scales_h = scales_h.view(np.uint16)
+
+ scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2))
+ scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array([2 * i for i in range(QK_K // 32)], dtype=np.uint16).reshape((1, -1, 1))
+ scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F)
+ scales_h = scales_h.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x03)
+
+ scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32)
+ dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1))
+
+ qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
+ qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F)
+
+ kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1))
+ qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32))
+
+ return (dl * qs).reshape((n_blocks, -1))
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py b/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py
new file mode 100755
index 0000000..86bf878
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import argparse
+import os
+import sys
+from tqdm import tqdm
+from pathlib import Path
+
+import numpy as np
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+import gguf
+
+logger = logging.getLogger("gguf-convert-endian")
+
+
+def byteswap_noop(tensor, block_offs):
+    # used for quant types whose blocks contain no multi-byte fields
+    # (e.g. MXFP4: a single-byte e8m0 scale plus packed 4-bit values)
+ pass
+
+
+def byteswap_q4_0(tensor, block_offs):
+    # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 bytes packing 32 4-bit quantized values.
+
+ # Byte-Swap f16 sized delta field
+ delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+ delta.byteswap(inplace=True)
+
+
+def byteswap_q8_0(tensor, block_offs):
+ # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations.
+
+ # Byte-Swap f16 sized delta field
+ delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+ delta.byteswap(inplace=True)
+
+
+def byteswap_q4_k(tensor, block_offs):
+ # Each block_q4_k consists of 2 f16 values followed by 140 int8 values.
+
+ # Byte-Swap f16 sized fields
+ delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+ delta.byteswap(inplace=True)
+
+ delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16)
+ delta.byteswap(inplace=True)
+
+
+def byteswap_q6_k(tensor, block_offs):
+ # Each block_q6_k consists of 208 int8 values followed by 1 f16 value.
+
+ # Byte-Swap f16 sized field
+ delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16)
+ delta.byteswap(inplace=True)
+
+
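+# quant types whose blocks need per-block byteswaps of their multi-byte (f16)
+# fields; any other type must be a plain float type, byteswapped whole below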
+byteswap_tensors = {
+ gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0,
+ gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0,
+ gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k,
+ gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k,
+ gguf.GGMLQuantizationType.MXFP4: byteswap_noop,
+}
+
+
+def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
+ file_endian = reader.endianess.name
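+    # reader.byte_order is 'S' when the file byte order is swapped relative to the host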
+ if reader.byte_order == 'S':
+ host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
+ else:
+ host_endian = file_endian
+ order = host_endian if args.order == "native" else args.order.upper()
+ logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian")
+ if file_endian == order:
+ logger.info(f"* File is already {order} endian. Nothing to do.")
+ sys.exit(0)
+ logger.info("* Checking tensors for conversion compatibility")
+ for tensor in reader.tensors:
+ if tensor.tensor_type not in byteswap_tensors and \
+ tensor.tensor_type not in (
+ gguf.GGMLQuantizationType.F32,
+ gguf.GGMLQuantizationType.F16,
+ gguf.GGMLQuantizationType.BF16,
+ ):
+ raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
+ logger.info(f"* Preparing to convert from {file_endian} to {order}")
+ if args.dry_run:
+ return
+ logger.warning("*** Warning *** Warning *** Warning **")
+ logger.warning("* This conversion process may damage the file. Ensure you have a backup.")
+ if order != host_endian:
+ logger.warning("* Requested endian differs from host, you will not be able to load the model on this machine.")
+ logger.warning("* The file will be modified immediately, so if conversion fails or is interrupted")
+ logger.warning("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:")
+ response = input("YES, I am sure> ")
+ if response != "YES":
+ logger.warning("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+ logger.info(f"* Converting fields ({len(reader.fields)})")
+ for idx, field in enumerate(reader.fields.values()):
+ logger.info(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}")
+ for part in field.parts:
+ part.byteswap(inplace=True)
+ logger.info(f"* Converting tensors ({len(reader.tensors)})")
+
+ for idx, tensor in enumerate(pbar := tqdm(reader.tensors, desc="Converting tensor")):
+ log_message = (
+ f"Converting tensor {repr(tensor.name)}, "
+ f"type={tensor.tensor_type.name}, "
+ f"elements={tensor.n_elements} "
+ )
+
+ # Byte-swap each part of the tensor's field
+ for part in tensor.field.parts:
+ part.byteswap(inplace=True)
+
+ # Byte-swap tensor data if necessary
+ if tensor.tensor_type in byteswap_tensors:
+            # flatten to 1-D so block offsets index the raw data directly
+            oldshape = tensor.data.shape
+            tensor.data.resize(tensor.data.size)
+
+ block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1]
+ byteswap_func = byteswap_tensors[tensor.tensor_type]
+
+ n_blocks = len(tensor.data) // block_size
+ for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
+ block_offs = block_num * block_size
+
+ byteswap_func(tensor, block_offs)
+
+ if block_num % 100000 == 0:
+                    inner_pbar.set_description(f"Byte-swapping Blocks [{100 * (n_blocks - block_num) // n_blocks}% remaining]")
+
+ # restore old shape in case it's ever used
+ tensor.data.resize(oldshape)
+ elif tensor.tensor_type == gguf.GGMLQuantizationType.BF16:
+            # Special case for BF16:
+            # it is 2-byte data, but the default view loads it as 1-byte data,
+            # so switch to the correct view before byteswapping.
+ tensor.data.view(dtype=np.uint16).byteswap(inplace=True)
+ else:
+ # Handle other tensor types
+ tensor.data.byteswap(inplace=True)
+
+ pbar.set_description(log_message)
+
+ logger.info("* Completion")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Convert GGUF file byte order")
+ parser.add_argument(
+ "model", type=str,
+ help="GGUF format model filename",
+ )
+ parser.add_argument(
+ "order", type=str, choices=['big', 'little', 'native'],
+ help="Requested byte order",
+ )
+ parser.add_argument(
+ "--dry-run", action="store_true",
+ help="Don't actually change anything",
+ )
+ parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+
+ args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ logger.info(f'* Loading: {args.model}')
+ reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+')
+ convert_byteorder(reader, args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py b/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py
new file mode 100755
index 0000000..8177dff
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py
@@ -0,0 +1,477 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import argparse
+import os
+import re
+import sys
+from pathlib import Path
+from typing import Any
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402
+
+logger = logging.getLogger("gguf-dump")
+
+
+def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
+ file_endian = reader.endianess.name
+ if reader.byte_order == 'S':
+ host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE'
+ else:
+ host_endian = file_endian
+ return (host_endian, file_endian)
+
+
+# For more information about what field.parts and field.data represent,
+# please see the comments in the modify_gguf.py example.
+def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+ host_endian, file_endian = get_file_host_endian(reader)
+ print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') # noqa: NP100
+ print(f'* Dumping {len(reader.fields)} key/value pair(s)') # noqa: NP100
+ for n, field in enumerate(reader.fields.values(), 1):
+ if not field.types:
+ pretty_type = 'N/A'
+ elif field.types[0] == GGUFValueType.ARRAY:
+ nest_count = len(field.types) - 1
+ pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
+ else:
+ pretty_type = str(field.types[-1].name)
+
+ log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}'
+ if field.types:
+ curr_type = field.types[0]
+ if curr_type == GGUFValueType.STRING:
+ content = field.contents()
+ if len(content) > 60:
+ content = content[:57] + '...'
+ log_message += ' = {0}'.format(repr(content))
+ elif curr_type in reader.gguf_scalar_to_np:
+ log_message += ' = {0}'.format(field.contents())
+ else:
+ content = repr(field.contents(slice(6)))
+ if len(field.data) > 6:
+ content = content[:-1] + ', ...]'
+ log_message += ' = {0}'.format(content)
+ print(log_message) # noqa: NP100
+ if args.no_tensors:
+ return
+ print(f'* Dumping {len(reader.tensors)} tensor(s)') # noqa: NP100
+ for n, tensor in enumerate(reader.tensors, 1):
+ prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)))
+ print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') # noqa: NP100
+
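+# A line of the key/value dump looks roughly like this (illustrative):
+#
+#       1: STRING     |        1 | general.architecture = 'llama'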
+
+def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
+ import json
+ host_endian, file_endian = get_file_host_endian(reader)
+ metadata: dict[str, Any] = {}
+ tensors: dict[str, Any] = {}
+ result = {
+ "filename": args.model,
+ "endian": file_endian,
+ "metadata": metadata,
+ "tensors": tensors,
+ }
+ for idx, field in enumerate(reader.fields.values()):
+ curr: dict[str, Any] = {
+ "index": idx,
+ "type": field.types[0].name if field.types else 'UNKNOWN',
+ "offset": field.offset,
+ }
+ metadata[field.name] = curr
+ if field.types[:1] == [GGUFValueType.ARRAY]:
+ curr["array_types"] = [t.name for t in field.types][1:]
+ if not args.json_array:
+ continue
+ curr["value"] = field.contents()
+ else:
+ curr["value"] = field.contents()
+ if not args.no_tensors:
+ for idx, tensor in enumerate(reader.tensors):
+ tensors[tensor.name] = {
+ "index": idx,
+ "shape": tensor.shape.tolist(),
+ "type": tensor.tensor_type.name,
+ "offset": tensor.field.offset,
+ }
+ json.dump(result, sys.stdout)
+
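+# Abridged sketch of the JSON shape (values illustrative):
+#
+#   {"filename": "model.gguf", "endian": "LITTLE",
+#    "metadata": {"general.architecture": {"index": 0, "type": "STRING", "offset": 24, "value": "llama"}},
+#    "tensors": {"token_embd.weight": {"index": 0, "shape": [4096, 32000], "type": "F16", "offset": 0}}}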
+
+def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]):
+ # JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957
+
+ # Alignment Utility Function
+ def strAlign(padding: int, alignMode: str | None, strVal: str):
+ if alignMode == 'center':
+ return strVal.center(padding)
+ elif alignMode == 'right':
+ return strVal.rjust(padding - 1) + ' '
+ elif alignMode == 'left':
+ return ' ' + strVal.ljust(padding - 1)
+ else: # default left
+ return ' ' + strVal.ljust(padding - 1)
+
+ def dashAlign(padding: int, alignMode: str | None):
+ if alignMode == 'center':
+ return ':' + '-' * (padding - 2) + ':'
+ elif alignMode == 'right':
+ return '-' * (padding - 1) + ':'
+ elif alignMode == 'left':
+ return ':' + '-' * (padding - 1)
+ else: # default left
+            return '-' * padding
+
+ # Calculate Padding For Each Column Based On Header and Data Length
+ rowsPadding = {}
+ for index, columnEntry in enumerate(header_map):
+ padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2
+ headerPadCount = len(columnEntry['header_name']) + 2
+ rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount
+
+ # Render Markdown Header
+ rows = []
+ rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map)))
+ rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map)))
+
+ # Render Tabular Data
+ for item in data:
+ rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map)))
+
+ # Convert Tabular String Rows Into String
+ tableString = ""
+ for row in rows:
+ tableString += f'|{row}|\n'
+
+ return tableString
+
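+# Usage sketch (illustrative data): a two-column table with mixed alignment:
+#
+#   header_map = [{'key_name': 'name', 'header_name': 'Name', 'align': 'left'},
+#                 {'key_name': 'count', 'header_name': 'Count', 'align': 'right'}]
+#   table = markdown_table_with_alignment_support(header_map, [{'name': 'output', 'count': 32000}])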
+
+def element_count_rounded_notation(count: int) -> str:
+    if count > 1e15:
+        # Quadrillions
+        scaled_amount = count * 1e-15
+        scale_suffix = "Q"
+    elif count > 1e12:
+        # Trillions
+        scaled_amount = count * 1e-12
+        scale_suffix = "T"
+    elif count > 1e9:
+        # Billions
+        scaled_amount = count * 1e-9
+        scale_suffix = "B"
+    elif count > 1e6:
+        # Millions
+        scaled_amount = count * 1e-6
+        scale_suffix = "M"
+    elif count > 1e3:
+        # Thousands
+        scaled_amount = count * 1e-3
+        scale_suffix = "K"
+    else:
+        # Under a thousand
+        scaled_amount = count
+        scale_suffix = ""
+ return f"{'~' if count > 1e3 else ''}{round(scaled_amount)}{scale_suffix}"
+
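+# For example: element_count_rounded_notation(7241732096) -> '~7B' and
+# element_count_rounded_notation(512) -> '512' (no '~' below a thousand).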
+
+def translate_tensor_name(name):
+ words = name.split(".")
+
+ # Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names
+ abbreviation_dictionary = {
+ 'token_embd': 'Token embedding',
+ 'pos_embd': 'Position embedding',
+ 'output_norm': 'Output normalization',
+ 'output': 'Output',
+ 'attn_norm': 'Attention normalization',
+ 'attn_norm_2': 'Attention normalization',
+ 'attn_qkv': 'Attention query-key-value',
+ 'attn_q': 'Attention query',
+ 'attn_k': 'Attention key',
+ 'attn_v': 'Attention value',
+ 'attn_output': 'Attention output',
+ 'ffn_norm': 'Feed-forward network normalization',
+ 'ffn_up': 'Feed-forward network "up"',
+ 'ffn_gate': 'Feed-forward network "gate"',
+ 'ffn_down': 'Feed-forward network "down"',
+ 'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models',
+ 'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models',
+ 'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models',
+ 'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models',
+ 'ssm_in': 'State space model input projections',
+ 'ssm_conv1d': 'State space model rolling/shift',
+ 'ssm_x': 'State space model selective parametrization',
+ 'ssm_a': 'State space model state compression',
+ 'ssm_d': 'State space model skip connection',
+ 'ssm_dt': 'State space model time step',
+ 'ssm_out': 'State space model output projection',
+ 'blk': 'Block',
+ 'enc': 'Encoder',
+ 'dec': 'Decoder',
+ }
+
+ expanded_words = []
+ for word in words:
+ word_norm = word.strip().lower()
+ if word_norm in abbreviation_dictionary:
+ expanded_words.append(abbreviation_dictionary[word_norm].title())
+ else:
+ expanded_words.append(word.title())
+
+ return ' '.join(expanded_words)
+
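+# For example, translate_tensor_name('blk.0.attn_q.weight') returns
+# 'Block 0 Attention Query Weight'.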
+
+def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+ host_endian, file_endian = get_file_host_endian(reader)
+ markdown_content = ""
+ markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n'
+ markdown_content += f'- Endian: {file_endian} endian\n'
+ markdown_content += '\n'
+ markdown_content += '## Key Value Metadata Store\n\n'
+ markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n'
+ markdown_content += '\n'
+ total_model_bytes = 0
+ total_model_elements = 0
+
+ kv_dump_table: list[dict[str, str | int]] = []
+ for n, field in enumerate(reader.fields.values(), 1):
+ if not field.types:
+ pretty_type = 'N/A'
+ elif field.types[0] == GGUFValueType.ARRAY:
+ nest_count = len(field.types) - 1
+ pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
+ else:
+ pretty_type = str(field.types[-1].name)
+
+ def escape_markdown_inline_code(value_string):
+            # Find the longest contiguous run of backticks in the string, then
+            # wrap the string in one more backtick than that so it stays escaped
+ max_backticks = max((len(match.group(0)) for match in re.finditer(r'`+', value_string)), default=0)
+ inline_code_marker = '`' * (max_backticks + 1)
+
+ # If the string starts or ends with a backtick, add a space at the beginning and end
+ if value_string.startswith('`') or value_string.endswith('`'):
+ value_string = f" {value_string} "
+
+ return f"{inline_code_marker}{value_string}{inline_code_marker}"
+
+ total_elements = len(field.data)
+ value = ""
+ if len(field.types) == 1:
+ curr_type = field.types[0]
+ if curr_type == GGUFValueType.STRING:
+ truncate_length = 60
+ value_string = str(bytes(field.parts[-1]), encoding='utf-8')
+ if len(value_string) > truncate_length:
+ head = escape_markdown_inline_code(value_string[:truncate_length // 2])
+ tail = escape_markdown_inline_code(value_string[-truncate_length // 2:])
+ value = "{head}...{tail}".format(head=head, tail=tail)
+ else:
+ value = escape_markdown_inline_code(value_string)
+ elif curr_type in reader.gguf_scalar_to_np:
+ value = str(field.parts[-1][0])
+ else:
+ if field.types[0] == GGUFValueType.ARRAY:
+ curr_type = field.types[1]
+ array_elements = []
+
+ if curr_type == GGUFValueType.STRING:
+ render_element = min(5, total_elements)
+ for element_pos in range(render_element):
+ truncate_length = 30
+ value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8')
+ if len(value_string) > truncate_length:
+ head = escape_markdown_inline_code(value_string[:truncate_length // 2])
+ tail = escape_markdown_inline_code(value_string[-truncate_length // 2:])
+ value = "{head}...{tail}".format(head=head, tail=tail)
+ else:
+ value = escape_markdown_inline_code(value_string)
+ array_elements.append(value)
+
+ elif curr_type in reader.gguf_scalar_to_np:
+ render_element = min(7, total_elements)
+ for element_pos in range(render_element):
+ array_elements.append(str(field.parts[-1 - (total_elements - element_pos - 1)][0]))
+
+ value = f'[ {", ".join(array_elements).strip()}{", ..." if total_elements > len(array_elements) else ""} ]'
+
+ kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value})
+
+ kv_dump_table_header_map = [
+ {'key_name':'n', 'header_name':'POS', 'align':'right'},
+ {'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'},
+ {'key_name':'total_elements', 'header_name':'Count', 'align':'right'},
+ {'key_name':'field_name', 'header_name':'Key', 'align':'left'},
+ {'key_name':'value', 'header_name':'Value', 'align':'left'},
+ ]
+
+ markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table)
+
+ markdown_content += "\n"
+
+ if not args.no_tensors:
+ # Group tensors by their prefix and maintain order
+ tensor_prefix_order: list[str] = []
+ tensor_name_to_key: dict[str, int] = {}
+ tensor_groups: dict[str, list[ReaderTensor]] = {}
+ total_elements = sum(tensor.n_elements for tensor in reader.tensors)
+
+ # Parsing Tensors Record
+ for key, tensor in enumerate(reader.tensors):
+ tensor_components = tensor.name.split('.')
+
+ # Classify Tensor Group
+ tensor_group_name = "base"
+ if tensor_components[0] == 'blk':
+ tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}"
+ elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk':
+ tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}"
+ elif tensor_components[0] in ['enc', 'dec']:
+ tensor_group_name = f"{tensor_components[0]}"
+
+ # Check if new Tensor Group
+ if tensor_group_name not in tensor_groups:
+ tensor_groups[tensor_group_name] = []
+ tensor_prefix_order.append(tensor_group_name)
+
+ # Record Tensor and Tensor Position
+ tensor_groups[tensor_group_name].append(tensor)
+ tensor_name_to_key[tensor.name] = key
+
+ # Tensors Mapping Dump
+ markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n'
+ markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n'
+ markdown_content += '\n'
+
+ for group in tensor_prefix_order:
+ tensors = tensor_groups[group]
+ group_elements = sum(tensor.n_elements for tensor in tensors)
+ markdown_content += f"- [{translate_tensor_name(group)} Tensor Group - {element_count_rounded_notation(group_elements)} Elements](#{group.replace('.', '_')})\n"
+
+ markdown_content += "\n"
+
+ markdown_content += "### Tensor Data Offset\n"
+ markdown_content += '\n'
+ markdown_content += 'This table contains the offset and data segment relative to start of file\n'
+ markdown_content += '\n'
+
+ tensor_mapping_table: list[dict[str, str | int]] = []
+ for key, tensor in enumerate(reader.tensors):
+ data_offset_pretty = '{0:#16x}'.format(tensor.data_offset)
+ data_size_pretty = '{0:#16x}'.format(tensor.n_bytes)
+ tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty})
+
+ tensors_mapping_table_header_map = [
+ {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
+ {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'},
+ {'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'},
+ {'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'},
+ ]
+
+ markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table)
+ markdown_content += "\n"
+
+ for group in tensor_prefix_order:
+ tensors = tensor_groups[group]
+ group_elements = sum(tensor.n_elements for tensor in tensors)
+ group_percentage = group_elements / total_elements * 100
+ total_group_bytes = 0
+ total_group_elements = 0
+ markdown_content += f"### <a name=\"{group.replace('.', '_')}\">{translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements</a>\n\n"
+
+ # Precalculate column sizing for visual consistency
+ prettify_element_est_count_size: int = 1
+ prettify_element_count_size: int = 1
+ prettify_dimension_max_widths: dict[int, int] = {}
+ for tensor in tensors:
+ prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements))))
+ prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements)))
+ for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))):
+ prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size)))
+
+ # Generate Tensor Layer Table Content
+ tensor_dump_table: list[dict[str, str | int]] = []
+ for tensor in tensors:
+ human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)"))
+ pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))))
+ element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})"
+ element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}"
+ type_name_string = f"{tensor.tensor_type.name}"
+ if tensor.n_elements > 0:
+ bpw = (tensor.n_bytes * 8) / tensor.n_elements
+ else:
+ bpw = float('nan')
+ tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string, "bpw": f"{bpw:.4f}"})
+ total_group_bytes += tensor.n_bytes
+ total_group_elements += tensor.n_elements
+
+ tensor_dump_table_header_map = [
+ {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'},
+ {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'},
+ {'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'},
+ {'key_name':'element_count', 'header_name':'Elements', 'align':'left'},
+ {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'},
+ {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'},
+ {'key_name':'bpw', 'header_name':'BPW', 'align':'right'},
+ ]
+
+ markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table)
+
+ markdown_content += "\n"
+ markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n"
+ markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n"
+ if total_group_elements > 0:
+ total_group_bpw = (total_group_bytes * 8) / total_group_elements
+ markdown_content += f"- Bits per Weight (BPW) for {group}: {total_group_bpw:.4f} bits\n"
+ else:
+ markdown_content += f"- Bits per Weight (BPW) for {group}: undefined (no elements)\n"
+ markdown_content += "\n\n"
+ total_model_bytes += total_group_bytes
+ total_model_elements += total_group_elements
+
+ if total_model_elements > 0:
+ total_model_bpw = (total_model_bytes * 8) / total_model_elements
+ markdown_content += f"Total BPW for {os.path.basename(args.model)}: {total_model_bpw:.4f} bits"
+ else:
+ markdown_content += f"Total BPW for {os.path.basename(args.model)}: undefined (no elements)"
+ print(markdown_content) # noqa: NP100
+
+
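+# Example invocations (hypothetical path):
+#
+#   python gguf_dump.py model.gguf                 # human-readable dump
+#   python gguf_dump.py --markdown model.gguf      # markdown report
+#   python gguf_dump.py --json --json-array model.gguf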
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
+ parser.add_argument("model", type=str, help="GGUF format model filename")
+ parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata")
+ parser.add_argument("--json", action="store_true", help="Produce JSON output")
+ parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)")
+ parser.add_argument("--data-offset", action="store_true", help="Start of data offset")
+ parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field")
+ parser.add_argument("--markdown", action="store_true", help="Produce markdown output")
+ parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+
+ args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ if not args.json and not args.markdown and not args.data_offset and not args.data_alignment:
+ logger.info(f'* Loading: {args.model}')
+
+ reader = GGUFReader(args.model, 'r')
+
+ if args.json:
+ dump_metadata_json(reader, args)
+ elif args.markdown:
+ dump_markdown_metadata(reader, args)
+ elif args.data_offset:
+ print(reader.data_offset) # noqa: NP100
+ elif args.data_alignment:
+ print(reader.alignment) # noqa: NP100
+ else:
+ dump_metadata(reader, args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py b/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py
new file mode 100755
index 0000000..293316a
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py
@@ -0,0 +1,1621 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import argparse
+import os
+import sys
+import numpy
+import enum
+from pathlib import Path
+from typing import Any, Optional, Tuple, Type
+import warnings
+
+import numpy as np
+from PySide6.QtWidgets import (
+ QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
+ QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget,
+ QTableWidgetItem, QComboBox, QMessageBox, QTabWidget,
+ QTextEdit, QFormLayout,
+ QHeaderView, QDialog, QDialogButtonBox
+)
+from PySide6.QtCore import Qt
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+import gguf
+from gguf import GGUFReader, GGUFWriter, GGUFValueType, ReaderField
+from gguf.constants import TokenType, RopeScalingType, PoolingType, GGMLQuantizationType
+
+logger = logging.getLogger("gguf-editor-gui")
+
+# Map of key names to enum types for automatic enum interpretation
+KEY_TO_ENUM_TYPE = {
+ gguf.Keys.Tokenizer.TOKEN_TYPE: TokenType,
+ gguf.Keys.Rope.SCALING_TYPE: RopeScalingType,
+ gguf.Keys.LLM.POOLING_TYPE: PoolingType,
+ gguf.Keys.General.FILE_TYPE: GGMLQuantizationType,
+}
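+# For example, an INT32 stored under tokenizer.ggml.token_type is displayed as
+# "NORMAL (1)" rather than a bare 1 (assuming TokenType.NORMAL == 1).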
+
+# Define the tokenizer keys that should be edited together
+TOKENIZER_LINKED_KEYS = [
+ gguf.Keys.Tokenizer.LIST,
+ gguf.Keys.Tokenizer.TOKEN_TYPE,
+ gguf.Keys.Tokenizer.SCORES
+]
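+# These three arrays are parallel (indexed by token id), so they must stay the
+# same length; the editor edits them together to keep them in sync.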
+
+
+class TokenizerEditorDialog(QDialog):
+ def __init__(self, tokens, token_types, scores, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle("Edit Tokenizer Data")
+ self.resize(900, 600)
+
+ self.tokens = tokens.copy() if tokens else []
+ self.token_types = token_types.copy() if token_types else []
+ self.scores = scores.copy() if scores else []
+
+ # Ensure all arrays have the same length
+ max_len = max(len(self.tokens), len(self.token_types), len(self.scores))
+ if len(self.tokens) < max_len:
+ self.tokens.extend([""] * (max_len - len(self.tokens)))
+ if len(self.token_types) < max_len:
+ self.token_types.extend([0] * (max_len - len(self.token_types)))
+ if len(self.scores) < max_len:
+ self.scores.extend([0.0] * (max_len - len(self.scores)))
+
+ layout = QVBoxLayout(self)
+
+ # Add filter controls
+ filter_layout = QHBoxLayout()
+ filter_layout.addWidget(QLabel("Filter:"))
+ self.filter_edit = QLineEdit()
+ self.filter_edit.setPlaceholderText("Type to filter tokens...")
+ self.filter_edit.textChanged.connect(self.apply_filter)
+ filter_layout.addWidget(self.filter_edit)
+
+ # Add page controls
+ self.page_size = 100 # Show 100 items per page
+ self.current_page = 0
+ self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size)
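+        # Ceiling division: e.g. 250 tokens at 100 per page -> 3 pages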
+
+ self.page_label = QLabel(f"Page 1 of {self.total_pages}")
+ filter_layout.addWidget(self.page_label)
+
+ prev_page = QPushButton("Previous")
+ prev_page.clicked.connect(self.previous_page)
+ filter_layout.addWidget(prev_page)
+
+ next_page = QPushButton("Next")
+ next_page.clicked.connect(self.next_page)
+ filter_layout.addWidget(next_page)
+
+ layout.addLayout(filter_layout)
+
+ # Tokenizer data table
+ self.tokens_table = QTableWidget()
+ self.tokens_table.setColumnCount(4)
+ self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"])
+ self.tokens_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+ self.tokens_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+ self.tokens_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+ self.tokens_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+
+ layout.addWidget(self.tokens_table)
+
+ # Controls
+ controls_layout = QHBoxLayout()
+
+ add_button = QPushButton("Add Token")
+ add_button.clicked.connect(self.add_token)
+ controls_layout.addWidget(add_button)
+
+ remove_button = QPushButton("Remove Selected")
+ remove_button.clicked.connect(self.remove_selected)
+ controls_layout.addWidget(remove_button)
+
+ controls_layout.addStretch()
+
+ layout.addLayout(controls_layout)
+
+ # Buttons
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addWidget(buttons)
+
+ # Initialize the filtered values
+ self.filtered_indices = list(range(len(self.tokens)))
+
+ # Load data for the first page
+ self.load_page()
+
+ def apply_filter(self):
+ """Filter the tokens based on the search text."""
+ filter_text = self.filter_edit.text().lower()
+
+ if not filter_text:
+ # No filter, show all values
+ self.filtered_indices = list(range(len(self.tokens)))
+ else:
+ # Apply filter
+ self.filtered_indices = []
+ for i, token in enumerate(self.tokens):
+ if filter_text in str(token).lower():
+ self.filtered_indices.append(i)
+
+ # Reset to first page and reload
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+ self.current_page = 0
+ self.page_label.setText(f"Page 1 of {self.total_pages}")
+ self.load_page()
+
+ def previous_page(self):
+ """Go to the previous page of results."""
+ if self.current_page > 0:
+ self.current_page -= 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+ self.load_page()
+
+ def next_page(self):
+ """Go to the next page of results."""
+ if self.current_page < self.total_pages - 1:
+ self.current_page += 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+ self.load_page()
+
+ def load_page(self):
+ """Load the current page of tokenizer data."""
+ self.tokens_table.setRowCount(0) # Clear the table
+
+ # Calculate start and end indices for the current page
+ start_idx = self.current_page * self.page_size
+ end_idx = min(start_idx + self.page_size, len(self.filtered_indices))
+
+ # Pre-allocate rows for better performance
+ self.tokens_table.setRowCount(end_idx - start_idx)
+
+ for row, i in enumerate(range(start_idx, end_idx)):
+ orig_idx = self.filtered_indices[i]
+
+ # Index
+ index_item = QTableWidgetItem(str(orig_idx))
+ index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index
+ index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tokens_table.setItem(row, 0, index_item)
+
+ # Token
+ token_item = QTableWidgetItem(str(self.tokens[orig_idx]))
+ self.tokens_table.setItem(row, 1, token_item)
+
+ # Token Type
+ token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0
+ try:
+ enum_val = TokenType(token_type)
+ display_text = f"{enum_val.name} ({token_type})"
+ except (ValueError, KeyError):
+ display_text = f"Unknown ({token_type})"
+
+ type_item = QTableWidgetItem(display_text)
+ type_item.setData(Qt.ItemDataRole.UserRole, token_type)
+
+            # The type cell is not directly editable; it is changed through the
+            # double-click handler, which opens an enum picker dialog
+            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tokens_table.setItem(row, 2, type_item)
+
+ # Score
+ score = self.scores[orig_idx] if orig_idx < len(self.scores) else 0.0
+ score_item = QTableWidgetItem(str(score))
+ self.tokens_table.setItem(row, 3, score_item)
+
+        # Connect the double-click handler for token type cells; disconnect any
+        # previous connection first so reloading a page does not stack handlers
+        try:
+            self.tokens_table.cellDoubleClicked.disconnect(self.handle_cell_double_click)
+        except (TypeError, RuntimeError):
+            pass
+        self.tokens_table.cellDoubleClicked.connect(self.handle_cell_double_click)
+
+ def handle_cell_double_click(self, row, column):
+ """Handle double-click on a cell, specifically for token type editing."""
+ if column == 2: # Token Type column
+ orig_item = self.tokens_table.item(row, 0)
+ if orig_item:
+ orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+ self.edit_token_type(row, orig_idx)
+
+ def edit_token_type(self, row, orig_idx):
+ """Edit a token type using a dialog with a dropdown of all enum options."""
+ current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0
+
+ # Create a dialog with enum options
+ dialog = QDialog(self)
+ dialog.setWindowTitle("Select Token Type")
+ layout = QVBoxLayout(dialog)
+
+ combo = QComboBox()
+ for enum_val in TokenType:
+ combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+ # Set current value
+ try:
+ if isinstance(current_value, int):
+ enum_val = TokenType(current_value)
+ combo.setCurrentText(f"{enum_val.name} ({current_value})")
+ except (ValueError, KeyError):
+ pass
+
+ layout.addWidget(combo)
+
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(dialog.accept)
+ buttons.rejected.connect(dialog.reject)
+ layout.addWidget(buttons)
+
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ # Get the selected value
+ new_value = combo.currentData()
+ enum_val = TokenType(new_value)
+ display_text = f"{enum_val.name} ({new_value})"
+
+ # Update the display
+ type_item = self.tokens_table.item(row, 2)
+ if type_item:
+ type_item.setText(display_text)
+ type_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+ # Update the actual value
+ self.token_types[orig_idx] = new_value
+
+ def add_token(self):
+ """Add a new token to the end of the list."""
+ # Add to the end of the arrays
+ self.tokens.append("")
+ self.token_types.append(0) # Default to normal token
+ self.scores.append(0.0)
+
+ orig_idx = len(self.tokens) - 1
+
+ # Add to filtered indices if it matches the current filter
+ filter_text = self.filter_edit.text().lower()
+        if not filter_text:  # a new empty token can only match an empty filter
+ self.filtered_indices.append(orig_idx)
+
+ # Update pagination
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+
+ # Go to the last page to show the new item
+ self.current_page = self.total_pages - 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+ # Reload the page
+ self.load_page()
+
+ def remove_selected(self):
+ """Remove selected tokens from all arrays."""
+ selected_rows = []
+ for item in self.tokens_table.selectedItems():
+ row = item.row()
+ if row not in selected_rows:
+ selected_rows.append(row)
+
+ if not selected_rows:
+ return
+
+ # Get original indices in descending order to avoid index shifting
+ orig_indices = []
+ for row in selected_rows:
+ orig_item = self.tokens_table.item(row, 0)
+ if orig_item:
+ orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
+ orig_indices.sort(reverse=True)
+
+ # Remove from all arrays
+ for idx in orig_indices:
+ if idx < len(self.tokens):
+ del self.tokens[idx]
+ if idx < len(self.token_types):
+ del self.token_types[idx]
+ if idx < len(self.scores):
+ del self.scores[idx]
+
+ # Rebuild filtered_indices
+ self.filtered_indices = []
+ filter_text = self.filter_edit.text().lower()
+
+ for i, token in enumerate(self.tokens):
+ if not filter_text or filter_text in str(token).lower():
+ self.filtered_indices.append(i)
+
+ # Update pagination
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+ self.current_page = min(self.current_page, self.total_pages - 1)
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+ # Reload the page
+ self.load_page()
+
+ def get_data(self):
+ """Return the edited tokenizer data."""
+ return self.tokens, self.token_types, self.scores
+
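+# Usage sketch (assumed caller behaviour): construct the dialog with the three
+# parallel arrays, exec() it, then read the edited arrays back:
+#
+#   dlg = TokenizerEditorDialog(tokens, token_types, scores, parent)
+#   if dlg.exec() == QDialog.DialogCode.Accepted:
+#       tokens, token_types, scores = dlg.get_data()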
+
+class ArrayEditorDialog(QDialog):
+ def __init__(self, array_values, element_type, key=None, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle("Edit Array Values")
+ self.resize(700, 500)
+
+ self.array_values = array_values
+ self.element_type = element_type
+ self.key = key
+
+ # Get enum type for this array if applicable
+ self.enum_type = None
+ if key in KEY_TO_ENUM_TYPE and element_type == GGUFValueType.INT32:
+ self.enum_type = KEY_TO_ENUM_TYPE[key]
+
+ layout = QVBoxLayout(self)
+
+ # Add enum type information if applicable
+ if self.enum_type is not None:
+ enum_info_layout = QHBoxLayout()
+ enum_label = QLabel(f"Editing {self.enum_type.__name__} values:")
+ enum_info_layout.addWidget(enum_label)
+
+ # Add a legend for the enum values
+ enum_values = ", ".join([f"{e.name}={e.value}" for e in self.enum_type])
+ enum_values_label = QLabel(f"Available values: {enum_values}")
+ enum_values_label.setWordWrap(True)
+ enum_info_layout.addWidget(enum_values_label, 1)
+
+ layout.addLayout(enum_info_layout)
+
+ # Add search/filter controls
+ filter_layout = QHBoxLayout()
+ filter_layout.addWidget(QLabel("Filter:"))
+ self.filter_edit = QLineEdit()
+ self.filter_edit.setPlaceholderText("Type to filter values...")
+ self.filter_edit.textChanged.connect(self.apply_filter)
+ filter_layout.addWidget(self.filter_edit)
+
+ # Add page controls for large arrays
+ self.page_size = 100 # Show 100 items per page
+ self.current_page = 0
+ self.total_pages = max(1, (len(array_values) + self.page_size - 1) // self.page_size)
+
+ self.page_label = QLabel(f"Page 1 of {self.total_pages}")
+ filter_layout.addWidget(self.page_label)
+
+ prev_page = QPushButton("Previous")
+ prev_page.clicked.connect(self.previous_page)
+ filter_layout.addWidget(prev_page)
+
+ next_page = QPushButton("Next")
+ next_page.clicked.connect(self.next_page)
+ filter_layout.addWidget(next_page)
+
+ layout.addLayout(filter_layout)
+
+ # Array items table
+ self.items_table = QTableWidget()
+
+ # Set up columns based on whether we have an enum type
+ if self.enum_type is not None:
+ self.items_table.setColumnCount(3)
+ self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"])
+ self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+ self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+ self.items_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+ else:
+ self.items_table.setColumnCount(2)
+ self.items_table.setHorizontalHeaderLabels(["Index", "Value"])
+ self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+ self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+
+ layout.addWidget(self.items_table)
+
+ # Controls
+ controls_layout = QHBoxLayout()
+
+ add_button = QPushButton("Add Item")
+ add_button.clicked.connect(self.add_item)
+ controls_layout.addWidget(add_button)
+
+ remove_button = QPushButton("Remove Selected")
+ remove_button.clicked.connect(self.remove_selected)
+ controls_layout.addWidget(remove_button)
+
+ # Add bulk edit button for enum arrays
+ if self.enum_type is not None:
+ bulk_edit_button = QPushButton("Bulk Edit Selected")
+ bulk_edit_button.clicked.connect(self.bulk_edit_selected)
+ controls_layout.addWidget(bulk_edit_button)
+
+ controls_layout.addStretch()
+
+ layout.addLayout(controls_layout)
+
+ # Buttons
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addWidget(buttons)
+
+ # Initialize the filtered values
+ self.filtered_indices = list(range(len(self.array_values)))
+
+ # Load array values for the first page
+ self.load_page()
+
+ def apply_filter(self):
+ """Filter the array values based on the search text."""
+ filter_text = self.filter_edit.text().lower()
+
+ if not filter_text:
+ # No filter, show all values
+ self.filtered_indices = list(range(len(self.array_values)))
+ else:
+ # Apply filter
+ self.filtered_indices = []
+ for i, value in enumerate(self.array_values):
+ # For enum values, search in both name and value
+ if self.enum_type is not None and isinstance(value, int):
+ try:
+ enum_val = self.enum_type(value)
+ display_text = f"{enum_val.name} ({value})".lower()
+ if filter_text in display_text:
+ self.filtered_indices.append(i)
+ except (ValueError, KeyError):
+ # If not a valid enum value, just check the raw value
+ if filter_text in str(value).lower():
+ self.filtered_indices.append(i)
+ else:
+ # For non-enum values, just check the string representation
+ if filter_text in str(value).lower():
+ self.filtered_indices.append(i)
+
+ # Reset to first page and reload
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+ self.current_page = 0
+ self.page_label.setText(f"Page 1 of {self.total_pages}")
+ self.load_page()
+
+ def previous_page(self):
+ """Go to the previous page of results."""
+ if self.current_page > 0:
+ self.current_page -= 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+ self.load_page()
+
+ def next_page(self):
+ """Go to the next page of results."""
+ if self.current_page < self.total_pages - 1:
+ self.current_page += 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+ self.load_page()
+
+ def load_page(self):
+ """Load the current page of array values."""
+ self.items_table.setRowCount(0) # Clear the table
+
+ # Calculate start and end indices for the current page
+ start_idx = self.current_page * self.page_size
+ end_idx = min(start_idx + self.page_size, len(self.filtered_indices))
+
+ # Pre-allocate rows for better performance
+ self.items_table.setRowCount(end_idx - start_idx)
+
+ for row, i in enumerate(range(start_idx, end_idx)):
+ orig_idx = self.filtered_indices[i]
+ value = self.array_values[orig_idx]
+
+ # Index
+ index_item = QTableWidgetItem(str(orig_idx))
+ index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index
+ index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.items_table.setItem(row, 0, index_item)
+
+ # Value
+ if self.enum_type is not None:
+ # Display enum value and name
+ try:
+ if isinstance(value, (int, numpy.signedinteger)):
+ enum_val = self.enum_type(value)
+ display_text = f"{enum_val.name} ({value})"
+ else:
+ display_text = str(value)
+ except (ValueError, KeyError):
+ display_text = f"Unknown ({value})"
+
+ # Store the enum value in the item
+ value_item = QTableWidgetItem(display_text)
+ value_item.setData(Qt.ItemDataRole.UserRole, value)
+ value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.items_table.setItem(row, 1, value_item)
+
+ # Add an edit button in a separate column
+ edit_button = QPushButton("Edit")
+ edit_button.setProperty("row", row)
+ edit_button.clicked.connect(self.edit_array_enum_value)
+
+ # Create a widget to hold the button
+ button_widget = QWidget()
+ button_layout = QHBoxLayout(button_widget)
+ button_layout.setContentsMargins(2, 2, 2, 2)
+ button_layout.addWidget(edit_button)
+ button_layout.addStretch()
+
+ self.items_table.setCellWidget(row, 2, button_widget)
+ else:
+ value_item = QTableWidgetItem(str(value))
+ self.items_table.setItem(row, 1, value_item)
+
+ def edit_array_enum_value(self):
+ """Handle editing an enum value in the array editor."""
+ button = self.sender()
+ row = button.property("row")
+
+ # Get the original index from the table item
+ orig_item = self.items_table.item(row, 0)
+ new_item = self.items_table.item(row, 1)
+ if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type):
+ orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+ new_value = new_item.data(Qt.ItemDataRole.UserRole)
+ # Update the stored value in the array
+ if isinstance(new_value, (int, float, str, bool)):
+ self.array_values[orig_idx] = new_value
+
+ def bulk_edit_selected(self):
+ """Edit multiple enum values at once."""
+ if not self.enum_type:
+ return
+
+ selected_rows = set()
+ for item in self.items_table.selectedItems():
+ selected_rows.add(item.row())
+
+ if not selected_rows:
+ QMessageBox.information(self, "No Selection", "Please select at least one row to edit.")
+ return
+
+ # Create a dialog with enum options
+ dialog = QDialog(self)
+ dialog.setWindowTitle(f"Bulk Edit {self.enum_type.__name__} Values")
+ layout = QVBoxLayout(dialog)
+
+ layout.addWidget(QLabel(f"Set {len(selected_rows)} selected items to:"))
+
+ combo = QComboBox()
+ for enum_val in self.enum_type:
+ combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+ layout.addWidget(combo)
+
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(dialog.accept)
+ buttons.rejected.connect(dialog.reject)
+ layout.addWidget(buttons)
+
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ # Get the selected value
+ new_value = combo.currentData()
+ enum_val = self.enum_type(new_value)
+ display_text = f"{enum_val.name} ({new_value})"
+
+ # Update all selected rows
+ for row in selected_rows:
+ orig_item = self.items_table.item(row, 0)
+ new_item = self.items_table.item(row, 1)
+ if orig_item and new_item:
+ orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+ self.array_values[orig_idx] = new_value
+
+ # Update the display
+ new_item.setText(display_text)
+ new_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+ def add_item(self):
+ # Add to the end of the array
+ orig_idx = len(self.array_values)
+
+ # Add default value based on type
+ if self.enum_type is not None:
+ # Default to first enum value
+ default_value = list(self.enum_type)[0].value
+ self.array_values.append(default_value)
+ else:
+ if self.element_type == GGUFValueType.STRING:
+ self.array_values.append("")
+ else:
+ self.array_values.append(0)
+
+ # Add to filtered indices if it matches the current filter
+ self.filtered_indices.append(orig_idx)
+
+ # Update pagination
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+
+ # Go to the last page to show the new item
+ self.current_page = self.total_pages - 1
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+ # Reload the page
+ self.load_page()
+
+ def remove_selected(self):
+ selected_rows = []
+ for item in self.items_table.selectedItems():
+ row = item.row()
+ if row not in selected_rows:
+ selected_rows.append(row)
+
+ if not selected_rows:
+ return
+
+ # Get original indices in descending order to avoid index shifting
+        orig_indices = []
+ for row in selected_rows:
+ orig_item = self.items_table.item(row, 0)
+ if orig_item:
+ orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
+ orig_indices.sort(reverse=True)
+
+ # Remove from array_values
+ for idx in orig_indices:
+ del self.array_values[idx]
+
+ # Rebuild filtered_indices
+ self.filtered_indices = []
+ filter_text = self.filter_edit.text().lower()
+
+ for i, value in enumerate(self.array_values):
+ if not filter_text:
+ self.filtered_indices.append(i)
+ else:
+ # Apply filter
+ if self.enum_type is not None and isinstance(value, int):
+ try:
+ enum_val = self.enum_type(value)
+ display_text = f"{enum_val.name} ({value})".lower()
+ if filter_text in display_text:
+ self.filtered_indices.append(i)
+ except (ValueError, KeyError):
+ if filter_text in str(value).lower():
+ self.filtered_indices.append(i)
+ else:
+ if filter_text in str(value).lower():
+ self.filtered_indices.append(i)
+
+ # Update pagination
+ self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+ self.current_page = min(self.current_page, self.total_pages - 1)
+ self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+ # Reload the page
+ self.load_page()
+
+ def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]):
+ """Edit an enum value using a dialog with a dropdown of all enum options."""
+ # Get the original index from the table item
+ orig_item = self.items_table.item(row, 0)
+ if orig_item:
+ orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+ else:
+ return
+ current_value = self.array_values[orig_idx]
+
+ # Create a dialog with enum options
+ dialog = QDialog(self)
+ dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
+ layout = QVBoxLayout(dialog)
+
+ # Add description
+ description = QLabel(f"Select a {enum_type.__name__} value:")
+ layout.addWidget(description)
+
+ # Use a combo box for quick selection
+ combo = QComboBox()
+ for enum_val in enum_type:
+ combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+ # Set current value
+ try:
+ if isinstance(current_value, int):
+ enum_val = enum_type(current_value)
+ combo.setCurrentText(f"{enum_val.name} ({current_value})")
+ except (ValueError, KeyError):
+ pass
+
+ layout.addWidget(combo)
+
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(dialog.accept)
+ buttons.rejected.connect(dialog.reject)
+ layout.addWidget(buttons)
+
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ # Update the value display and stored data
+ new_value = combo.currentData()
+ enum_val = enum_type(new_value)
+ display_text = f"{enum_val.name} ({new_value})"
+
+ new_item = self.items_table.item(row, 1)
+ if new_item:
+ new_item.setText(display_text)
+ new_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+ # Update the actual array value
+ self.array_values[orig_idx] = new_value
+ return True
+ return False
+
+ def get_array_values(self):
+ # The array_values list is kept up-to-date as edits are made
+ return self.array_values
+
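+# Usage sketch (assumed caller behaviour):
+#
+#   dlg = ArrayEditorDialog(values, GGUFValueType.INT32, key=key, parent=self)
+#   if dlg.exec() == QDialog.DialogCode.Accepted:
+#       values = dlg.get_array_values()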
+
+class AddMetadataDialog(QDialog):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle("Add Metadata")
+ self.resize(400, 200)
+
+ layout = QVBoxLayout(self)
+
+ form_layout = QFormLayout()
+
+ self.key_edit = QLineEdit()
+ form_layout.addRow("Key:", self.key_edit)
+
+ self.type_combo = QComboBox()
+ for value_type in GGUFValueType:
+ if value_type != GGUFValueType.ARRAY: # Skip array type for simplicity
+ self.type_combo.addItem(value_type.name, value_type)
+ form_layout.addRow("Type:", self.type_combo)
+
+ self.value_edit = QTextEdit()
+ form_layout.addRow("Value:", self.value_edit)
+
+ layout.addLayout(form_layout)
+
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(self.accept)
+ buttons.rejected.connect(self.reject)
+ layout.addWidget(buttons)
+
+ def get_data(self) -> Tuple[str, GGUFValueType, Any]:
+ key = self.key_edit.text()
+ value_type = self.type_combo.currentData()
+ value_text = self.value_edit.toPlainText()
+
+ # Convert value based on type
+ if value_type == GGUFValueType.UINT8:
+ value = np.uint8(int(value_text))
+ elif value_type == GGUFValueType.INT8:
+ value = np.int8(int(value_text))
+ elif value_type == GGUFValueType.UINT16:
+ value = np.uint16(int(value_text))
+ elif value_type == GGUFValueType.INT16:
+ value = np.int16(int(value_text))
+ elif value_type == GGUFValueType.UINT32:
+ value = np.uint32(int(value_text))
+ elif value_type == GGUFValueType.INT32:
+ value = np.int32(int(value_text))
+ elif value_type == GGUFValueType.FLOAT32:
+ value = np.float32(float(value_text))
+ elif value_type == GGUFValueType.BOOL:
+ value = value_text.lower() in ('true', 'yes', '1')
+ elif value_type == GGUFValueType.STRING:
+ value = value_text
+ else:
+ value = value_text
+
+ return key, value_type, value
+
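+# e.g. a dialog filled in with key 'general.name', type STRING and value
+# 'My Model' returns ('general.name', GGUFValueType.STRING, 'My Model').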
+
+class GGUFEditorWindow(QMainWindow):
+ def __init__(self):
+ super().__init__()
+
+ self.setWindowTitle("GGUF Editor")
+ self.resize(1000, 800)
+
+ self.current_file = None
+ self.reader = None
+ self.modified = False
+ self.metadata_changes = {} # Store changes to apply when saving
+ self.metadata_to_remove = set() # Store keys to remove when saving
+ self.on_metadata_changed_is_connected = False
+
+ self.setup_ui()
+
+ def setup_ui(self):
+ central_widget = QWidget()
+ self.setCentralWidget(central_widget)
+
+ main_layout = QVBoxLayout(central_widget)
+
+ # File controls
+ file_layout = QHBoxLayout()
+
+ self.file_path_edit = QLineEdit()
+ self.file_path_edit.setReadOnly(True)
+ file_layout.addWidget(self.file_path_edit)
+
+ open_button = QPushButton("Open GGUF")
+ open_button.clicked.connect(self.open_file)
+ file_layout.addWidget(open_button)
+
+ save_button = QPushButton("Save As...")
+ save_button.clicked.connect(self.save_file)
+ file_layout.addWidget(save_button)
+
+ main_layout.addLayout(file_layout)
+
+ # Tabs for different views
+ self.tabs = QTabWidget()
+
+ # Metadata tab
+ self.metadata_tab = QWidget()
+ metadata_layout = QVBoxLayout(self.metadata_tab)
+
+ # Metadata table
+ self.metadata_table = QTableWidget()
+ self.metadata_table.setColumnCount(4)
+ self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"])
+ self.metadata_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
+ self.metadata_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
+ self.metadata_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch)
+ self.metadata_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+ metadata_layout.addWidget(self.metadata_table)
+
+ # Metadata controls
+ metadata_controls = QHBoxLayout()
+
+ add_metadata_button = QPushButton("Add Metadata")
+ add_metadata_button.clicked.connect(self.add_metadata)
+ metadata_controls.addWidget(add_metadata_button)
+
+ metadata_controls.addStretch()
+
+ metadata_layout.addLayout(metadata_controls)
+
+ # Tensors tab
+ self.tensors_tab = QWidget()
+ tensors_layout = QVBoxLayout(self.tensors_tab)
+
+ self.tensors_table = QTableWidget()
+ self.tensors_table.setColumnCount(5)
+ self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"])
+ self.tensors_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
+ self.tensors_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
+ self.tensors_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+ self.tensors_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+ self.tensors_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeMode.ResizeToContents)
+ tensors_layout.addWidget(self.tensors_table)
+
+ # Add tabs to tab widget
+ self.tabs.addTab(self.metadata_tab, "Metadata")
+ self.tabs.addTab(self.tensors_tab, "Tensors")
+
+ main_layout.addWidget(self.tabs)
+
+ # Status bar
+ self.statusBar().showMessage("Ready")
+
+ def load_file(self, file_path):
+ """Load a GGUF file by path"""
+ try:
+ self.statusBar().showMessage(f"Loading {file_path}...")
+ QApplication.processEvents()
+
+ self.reader = GGUFReader(file_path, 'r')
+ self.current_file = file_path
+ self.file_path_edit.setText(file_path)
+
+ self.load_metadata()
+ self.load_tensors()
+
+ self.metadata_changes = {}
+ self.metadata_to_remove = set()
+ self.modified = False
+
+ self.statusBar().showMessage(f"Loaded {file_path}")
+ return True
+ except Exception as e:
+ QMessageBox.critical(self, "Error", f"Failed to open file: {str(e)}")
+ self.statusBar().showMessage("Error loading file")
+ return False
+
+ def open_file(self):
+ file_path, _ = QFileDialog.getOpenFileName(
+ self, "Open GGUF File", "", "GGUF Files (*.gguf);;All Files (*)"
+ )
+
+ if not file_path:
+ return
+
+ self.load_file(file_path)
+
+ def load_metadata(self):
+ self.metadata_table.setRowCount(0)
+
+ if not self.reader:
+ return
+
+ # Disconnect to prevent triggering during loading
+ if self.on_metadata_changed_is_connected:
+ with warnings.catch_warnings():
+ warnings.filterwarnings('ignore')
+ self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
+ self.on_metadata_changed_is_connected = False
+
+ for i, (key, field) in enumerate(self.reader.fields.items()):
+ self.metadata_table.insertRow(i)
+
+ # Key
+ key_item = QTableWidgetItem(key)
+ key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.metadata_table.setItem(i, 0, key_item)
+
+ # Type
+ if not field.types:
+ type_str = "N/A"
+ elif field.types[0] == GGUFValueType.ARRAY:
+ nest_count = len(field.types) - 1
+ element_type = field.types[-1].name
+ # Check if this is an enum array
+ enum_type = self.get_enum_for_key(key)
+ if enum_type is not None and field.types[-1] == GGUFValueType.INT32:
+ element_type = enum_type.__name__
+ type_str = '[' * nest_count + element_type + ']' * nest_count
+ else:
+ type_str = str(field.types[0].name)
+ # Check if this is an enum field
+ enum_type = self.get_enum_for_key(key)
+ if enum_type is not None and field.types[0] == GGUFValueType.INT32:
+ type_str = enum_type.__name__
+
+ type_item = QTableWidgetItem(type_str)
+ type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.metadata_table.setItem(i, 1, type_item)
+
+ # Value
+ value_str = self.format_field_value(field)
+ value_item = QTableWidgetItem(value_str)
+
+ # Make only simple values editable
+ if len(field.types) == 1 and field.types[0] != GGUFValueType.ARRAY:
+ value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable)
+ else:
+ value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+
+ self.metadata_table.setItem(i, 2, value_item)
+
+ # Actions
+ actions_widget = QWidget()
+ actions_layout = QHBoxLayout(actions_widget)
+ actions_layout.setContentsMargins(2, 2, 2, 2)
+
+ # Add Edit button for arrays and enum fields
+ if field.types and field.types[0] == GGUFValueType.ARRAY:
+ edit_button = QPushButton("Edit")
+ edit_button.setProperty("row", i)
+ edit_button.setProperty("key", key)
+ edit_button.clicked.connect(self.edit_array_metadata)
+ actions_layout.addWidget(edit_button)
+
+ # Add special label for tokenizer linked fields
+ if key in TOKENIZER_LINKED_KEYS:
+ edit_button.setText("Edit Tokenizer")
+ edit_button.setToolTip("Edit all tokenizer data together")
+ elif len(field.types) == 1 and self.get_enum_for_key(key) is not None:
+ edit_button = QPushButton("Edit")
+ edit_button.setProperty("row", i)
+ edit_button.setProperty("key", key)
+ edit_button.clicked.connect(self.edit_metadata_enum)
+ actions_layout.addWidget(edit_button)
+
+ remove_button = QPushButton("Remove")
+ remove_button.setProperty("row", i)
+ remove_button.setProperty("key", key)
+ remove_button.clicked.connect(self.remove_metadata)
+ actions_layout.addWidget(remove_button)
+
+ self.metadata_table.setCellWidget(i, 3, actions_widget)
+
+ # Reconnect after loading
+ self.metadata_table.itemChanged.connect(self.on_metadata_changed)
+ self.on_metadata_changed_is_connected = True
+
+ def extract_array_values(self, field: ReaderField) -> list:
+ """Extract all values from an array field."""
+ if not field.types or field.types[0] != GGUFValueType.ARRAY:
+ return []
+
+ curr_type = field.types[1]
+ array_values = []
+ total_elements = len(field.data)
+
+ if curr_type == GGUFValueType.STRING:
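+            # Each string element occupies two parts (length, bytes), hence the
+            # stride of 2 when walking back from the last part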
+ for element_pos in range(total_elements):
+ value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8')
+ array_values.append(value_string)
+ elif self.reader and curr_type in self.reader.gguf_scalar_to_np:
+ for element_pos in range(total_elements):
+ array_values.append(field.parts[-1 - (total_elements - element_pos - 1)][0])
+
+ return array_values
+
+ def get_enum_for_key(self, key: str) -> Optional[Type[enum.Enum]]:
+ """Get the enum type for a given key if it exists."""
+ return KEY_TO_ENUM_TYPE.get(key)
+
+ def format_enum_value(self, value: Any, enum_type: Type[enum.Enum]) -> str:
+ """Format a value as an enum if possible."""
+ try:
+ if isinstance(value, (int, str)):
+ enum_value = enum_type(value)
+ return f"{enum_value.name} ({value})"
+ except (ValueError, KeyError):
+ pass
+ return str(value)
+
+ def format_field_value(self, field: ReaderField) -> str:
+ if not field.types:
+ return "N/A"
+
+ if len(field.types) == 1:
+ curr_type = field.types[0]
+ if curr_type == GGUFValueType.STRING:
+ return str(bytes(field.parts[-1]), encoding='utf-8')
+ elif self.reader and curr_type in self.reader.gguf_scalar_to_np:
+ value = field.parts[-1][0]
+ # Check if this field has an enum type
+ enum_type = self.get_enum_for_key(field.name)
+ if enum_type is not None:
+ return self.format_enum_value(value, enum_type)
+ return str(value)
+
+ if field.types[0] == GGUFValueType.ARRAY:
+ array_values = self.extract_array_values(field)
+ render_element = min(5, len(array_values))
+
+ # Get enum type for this array if applicable
+ enum_type = self.get_enum_for_key(field.name)
+
+ if enum_type is not None:
+ array_elements = []
+ for i in range(render_element):
+ array_elements.append(self.format_enum_value(array_values[i], enum_type))
+ else:
+ array_elements = [str(array_values[i]) for i in range(render_element)]
+
+ return f"[ {', '.join(array_elements).strip()}{', ...' if len(array_values) > len(array_elements) else ''} ]"
+
+ return "Complex value"
+
+ def load_tensors(self):
+ self.tensors_table.setRowCount(0)
+
+ if not self.reader:
+ return
+
+ for i, tensor in enumerate(self.reader.tensors):
+ self.tensors_table.insertRow(i)
+
+ # Name
+ name_item = QTableWidgetItem(tensor.name)
+ name_item.setFlags(name_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tensors_table.setItem(i, 0, name_item)
+
+ # Type
+ type_item = QTableWidgetItem(tensor.tensor_type.name)
+ type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tensors_table.setItem(i, 1, type_item)
+
+ # Shape
+ shape_str = " × ".join(str(d) for d in tensor.shape)
+ shape_item = QTableWidgetItem(shape_str)
+ shape_item.setFlags(shape_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tensors_table.setItem(i, 2, shape_item)
+
+ # Elements
+ elements_item = QTableWidgetItem(str(tensor.n_elements))
+ elements_item.setFlags(elements_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tensors_table.setItem(i, 3, elements_item)
+
+ # Size
+ size_item = QTableWidgetItem(f"{tensor.n_bytes:,}")
+ size_item.setFlags(size_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.tensors_table.setItem(i, 4, size_item)
+
+ def on_metadata_changed(self, item):
+ if item.column() != 2: # Only handle value column changes
+ return
+
+ row = item.row()
+ orig_item = self.metadata_table.item(row, 0)
+ key = None
+ if orig_item:
+ key = orig_item.text()
+ new_value = item.text()
+
+ field = None
+ if self.reader and key:
+ field = self.reader.get_field(key)
+ if not field or not field.types or not key:
+ return
+
+ value_type = field.types[0]
+
+ # Check if this is an enum field
+ enum_type = self.get_enum_for_key(key)
+ if enum_type is not None and value_type == GGUFValueType.INT32:
+ # Try to parse the enum value from the text
+ try:
+ # Check if it's a name
+ try:
+ enum_val = enum_type[new_value]
+ converted_value = enum_val.value
+ except (KeyError, AttributeError):
+ # Check if it's a number or "NAME (value)" format
+ if '(' in new_value and ')' in new_value:
+ # Extract the value from "NAME (value)" format
+ value_part = new_value.split('(')[1].split(')')[0].strip()
+ converted_value = int(value_part)
+ else:
+ # Try to convert directly to int
+ converted_value = int(new_value)
+
+ # Validate that it's a valid enum value
+ enum_type(converted_value)
+
+ # Store the change
+ self.metadata_changes[key] = (value_type, converted_value)
+ self.modified = True
+
+ # Update display with formatted enum value
+ formatted_value = self.format_enum_value(converted_value, enum_type)
+ item.setText(formatted_value)
+
+ self.statusBar().showMessage(f"Changed {key} to {formatted_value}")
+ return
+ except (ValueError, KeyError) as e:
+ QMessageBox.warning(
+ self,
+ f"Invalid Enum Value ({e})",
+ f"'{new_value}' is not a valid {enum_type.__name__} value.\n"
+ f"Valid values are: {', '.join(v.name for v in enum_type)}")
+
+ # Revert to original value
+ original_value = self.format_field_value(field)
+ item.setText(original_value)
+ return
+
+ try:
+ # Convert the string value to the appropriate type
+ if value_type == GGUFValueType.UINT8:
+ converted_value = np.uint8(int(new_value))
+ elif value_type == GGUFValueType.INT8:
+ converted_value = np.int8(int(new_value))
+ elif value_type == GGUFValueType.UINT16:
+ converted_value = np.uint16(int(new_value))
+ elif value_type == GGUFValueType.INT16:
+ converted_value = np.int16(int(new_value))
+ elif value_type == GGUFValueType.UINT32:
+ converted_value = np.uint32(int(new_value))
+ elif value_type == GGUFValueType.INT32:
+ converted_value = np.int32(int(new_value))
+ elif value_type == GGUFValueType.FLOAT32:
+ converted_value = np.float32(float(new_value))
+ elif value_type == GGUFValueType.BOOL:
+ converted_value = new_value.lower() in ('true', 'yes', '1')
+ elif value_type == GGUFValueType.STRING:
+ converted_value = new_value
+ else:
+ # Unsupported type for editing
+ return
+
+ # Store the change
+ self.metadata_changes[key] = (value_type, converted_value)
+ self.modified = True
+
+ self.statusBar().showMessage(f"Changed {key} to {new_value}")
+ except ValueError:
+ QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}")
+
+ # Revert to original value
+ original_value = self.format_field_value(field)
+ item.setText(original_value)
+
+ def remove_metadata(self):
+ button = self.sender()
+ key = button.property("key")
+ row = button.property("row")
+
+ reply = QMessageBox.question(
+ self, "Confirm Removal",
+ f"Are you sure you want to remove the metadata key '{key}'?",
+ QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No
+ )
+
+ if reply == QMessageBox.StandardButton.Yes:
+ self.metadata_table.removeRow(row)
+ self.metadata_to_remove.add(key)
+
+ # If we previously had changes for this key, remove them
+ if key in self.metadata_changes:
+ del self.metadata_changes[key]
+
+ self.modified = True
+ self.statusBar().showMessage(f"Marked {key} for removal")
+
+ def edit_metadata_enum(self):
+ """Edit an enum metadata field."""
+ button = self.sender()
+ key = button.property("key")
+ row = button.property("row")
+
+ field = None
+ if self.reader:
+ field = self.reader.get_field(key)
+ if not field or not field.types:
+ return
+
+ enum_type = self.get_enum_for_key(key)
+ if enum_type is None:
+ return
+
+ # Get current value
+ current_value = field.contents()
+
+ # Create a dialog with enum options
+ dialog = QDialog(self)
+ dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
+ layout = QVBoxLayout(dialog)
+
+ combo = QComboBox()
+ for enum_val in enum_type:
+ combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+ # Set current value
+ try:
+ if isinstance(current_value, (int, str)):
+ enum_val = enum_type(current_value)
+ combo.setCurrentText(f"{enum_val.name} ({current_value})")
+ except (ValueError, KeyError):
+ pass
+
+ layout.addWidget(combo)
+
+ buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ buttons.accepted.connect(dialog.accept)
+ buttons.rejected.connect(dialog.reject)
+ layout.addWidget(buttons)
+
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ # Get the selected value
+ new_value = combo.currentData()
+ enum_val = enum_type(new_value)
+
+ # Store the change
+ self.metadata_changes[key] = (field.types[0], new_value)
+ self.modified = True
+
+ # Update display
+ display_text = f"{enum_val.name} ({new_value})"
+ target_item = self.metadata_table.item(row, 2)
+ if target_item:
+ target_item.setText(display_text)
+
+ self.statusBar().showMessage(f"Changed {key} to {display_text}")
+
+ def edit_array_metadata(self):
+ button = self.sender()
+ key = button.property("key")
+ row = button.property("row")
+
+ # Check if this is one of the linked tokenizer keys
+ if key in TOKENIZER_LINKED_KEYS:
+ self.edit_tokenizer_metadata(key)
+ return
+
+ field = None
+ if self.reader:
+ field = self.reader.get_field(key)
+ if not field or not field.types or field.types[0] != GGUFValueType.ARRAY:
+ return
+
+ # Get array element type
+ element_type = field.types[1]
+
+ # Extract array values
+ array_values = self.extract_array_values(field)
+
+ # Open array editor dialog
+ dialog = ArrayEditorDialog(array_values, element_type, key, self)
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ new_values = dialog.get_array_values()
+
+ # Store the change
+ self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values))
+ self.modified = True
+
+ # Update display
+ enum_type = self.get_enum_for_key(key)
+ if enum_type is not None and element_type == GGUFValueType.INT32:
+ value_str = f"[ {', '.join(self.format_enum_value(v, enum_type) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]"
+ else:
+ value_str = f"[ {', '.join(str(v) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]"
+ target_item = self.metadata_table.item(row, 2)
+ if target_item:
+ target_item.setText(value_str)
+
+ self.statusBar().showMessage(f"Updated array values for {key}")
+
+ def edit_tokenizer_metadata(self, trigger_key):
+ """Edit the linked tokenizer metadata arrays together."""
+ if not self.reader:
+ return
+
+ # Get all three fields
+ tokens_field = self.reader.get_field(gguf.Keys.Tokenizer.LIST)
+ token_types_field = self.reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+ scores_field = self.reader.get_field(gguf.Keys.Tokenizer.SCORES)
+
+ # Extract values from each field
+ tokens = self.extract_array_values(tokens_field) if tokens_field else []
+ token_types = self.extract_array_values(token_types_field) if token_types_field else []
+ scores = self.extract_array_values(scores_field) if scores_field else []
+
+ # Apply any pending changes
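+ # (array edits are stored as (GGUFValueType.ARRAY, (element_type, values)),
+ # as in edit_array_metadata() above, so the inner tuple is unpacked here and
+ # the type information discarded)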
+ if gguf.Keys.Tokenizer.LIST in self.metadata_changes:
+ _, (_, tokens) = self.metadata_changes[gguf.Keys.Tokenizer.LIST]
+ if gguf.Keys.Tokenizer.TOKEN_TYPE in self.metadata_changes:
+ _, (_, token_types) = self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE]
+ if gguf.Keys.Tokenizer.SCORES in self.metadata_changes:
+ _, (_, scores) = self.metadata_changes[gguf.Keys.Tokenizer.SCORES]
+
+ # Open the tokenizer editor dialog
+ dialog = TokenizerEditorDialog(tokens, token_types, scores, self)
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ new_tokens, new_token_types, new_scores = dialog.get_data()
+
+ # Store changes for all three arrays
+ if tokens_field:
+ self.metadata_changes[gguf.Keys.Tokenizer.LIST] = (
+ GGUFValueType.ARRAY,
+ (tokens_field.types[1], new_tokens)
+ )
+
+ if token_types_field:
+ self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] = (
+ GGUFValueType.ARRAY,
+ (token_types_field.types[1], new_token_types)
+ )
+
+ if scores_field:
+ self.metadata_changes[gguf.Keys.Tokenizer.SCORES] = (
+ GGUFValueType.ARRAY,
+ (scores_field.types[1], new_scores)
+ )
+
+ self.modified = True
+
+ # Update display for all three fields
+ self.update_tokenizer_display(gguf.Keys.Tokenizer.LIST, new_tokens)
+ self.update_tokenizer_display(gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types)
+ self.update_tokenizer_display(gguf.Keys.Tokenizer.SCORES, new_scores)
+
+ self.statusBar().showMessage("Updated tokenizer data")
+
+ def update_tokenizer_display(self, key, values):
+ """Update the display of a tokenizer field in the metadata table."""
+ for row in range(self.metadata_table.rowCount()):
+ key_item = self.metadata_table.item(row, 0)
+ if key_item and key_item.text() == key:
+ value_str = f"[ {', '.join(str(v) for v in values[:5])}{', ...' if len(values) > 5 else ''} ]"
+ value_item = self.metadata_table.item(row, 2)
+ if value_item:
+ value_item.setText(value_str)
+ break
+
+ def add_metadata(self):
+ dialog = AddMetadataDialog(self)
+ if dialog.exec() == QDialog.DialogCode.Accepted:
+ key, value_type, value = dialog.get_data()
+
+ if not key:
+ QMessageBox.warning(self, "Invalid Key", "Key cannot be empty")
+ return
+
+ # Check if key already exists
+ for row in range(self.metadata_table.rowCount()):
+ orig_item = self.metadata_table.item(row, 0)
+ if orig_item and orig_item.text() == key:
+ QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists")
+ return
+
+ # Add to table
+ row = self.metadata_table.rowCount()
+ self.metadata_table.insertRow(row)
+
+ # Key
+ key_item = QTableWidgetItem(key)
+ key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.metadata_table.setItem(row, 0, key_item)
+
+ # Type
+ type_item = QTableWidgetItem(value_type.name)
+ type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+ self.metadata_table.setItem(row, 1, type_item)
+
+ # Value
+ value_item = QTableWidgetItem(str(value))
+ value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable)
+ self.metadata_table.setItem(row, 2, value_item)
+
+ # Actions
+ actions_widget = QWidget()
+ actions_layout = QHBoxLayout(actions_widget)
+ actions_layout.setContentsMargins(2, 2, 2, 2)
+
+ remove_button = QPushButton("Remove")
+ remove_button.setProperty("row", row)
+ remove_button.setProperty("key", key)
+ remove_button.clicked.connect(self.remove_metadata)
+ actions_layout.addWidget(remove_button)
+
+ self.metadata_table.setCellWidget(row, 3, actions_widget)
+
+ # Store the change
+ self.metadata_changes[key] = (value_type, value)
+ self.modified = True
+
+ self.statusBar().showMessage(f"Added new metadata key {key}")
+
+ def save_file(self):
+ if not self.reader:
+ QMessageBox.warning(self, "No File Open", "Please open a GGUF file first")
+ return
+
+ if not self.modified and not self.metadata_changes and not self.metadata_to_remove:
+ QMessageBox.information(self, "No Changes", "No changes to save")
+ return
+
+ file_path, _ = QFileDialog.getSaveFileName(
+ self, "Save GGUF File As", "", "GGUF Files (*.gguf);;All Files (*)"
+ )
+
+ if not file_path:
+ return
+
+ try:
+ self.statusBar().showMessage(f"Saving to {file_path}...")
+ QApplication.processEvents()
+
+ # Get architecture and endianness from the original file
+ arch = 'unknown'
+ field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE)
+ if field:
+ arch = field.contents()
+
+ # Create writer
+ writer = GGUFWriter(file_path, arch=arch, endianess=self.reader.endianess)
+
+ # Get alignment if present
+ alignment = None
+ field = self.reader.get_field(gguf.Keys.General.ALIGNMENT)
+ if field:
+ alignment = field.contents()
+ if alignment is not None:
+ writer.data_alignment = alignment
+
+ # Copy metadata with changes
+ for field in self.reader.fields.values():
+ # Skip virtual fields and fields written by GGUFWriter
+ if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
+ continue
+
+ # Skip fields marked for removal
+ if field.name in self.metadata_to_remove:
+ continue
+
+ # Apply changes if any
+ sub_type = None
+ if field.name in self.metadata_changes:
+ value_type, value = self.metadata_changes[field.name]
+ if value_type == GGUFValueType.ARRAY:
+ # Handle array values
+ sub_type, value = value
+ else:
+ # Copy original value
+ value = field.contents()
+ value_type = field.types[0]
+ if value_type == GGUFValueType.ARRAY:
+ sub_type = field.types[-1]
+
+ if value is not None:
+ writer.add_key_value(field.name, value, value_type, sub_type=sub_type)
+
+ # Add new metadata
+ for key, (value_type, value) in self.metadata_changes.items():
+ # Skip if the key already existed (we handled it above)
+ if self.reader.get_field(key) is not None:
+ continue
+
+ sub_type = None
+ if value_type == GGUFValueType.ARRAY:
+ # Handle array values
+ sub_type, value = value
+
+ writer.add_key_value(key, value, value_type, sub_type=sub_type)
+
+ # Add tensors (including data)
+ for tensor in self.reader.tensors:
+ writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)
+
+ # Write header and metadata
+ writer.open_output_file(Path(file_path))
+ writer.write_header_to_file()
+ writer.write_kv_data_to_file()
+
+ # Write tensor data using the optimized method
+ writer.write_tensors_to_file(progress=False)
+
+ writer.close()
+
+ self.statusBar().showMessage(f"Saved to {file_path}")
+
+ # Ask if user wants to open the new file
+ reply = QMessageBox.question(
+ self, "Open Saved File",
+ "Would you like to open the newly saved file?",
+ QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes
+ )
+
+ if reply == QMessageBox.StandardButton.Yes:
+ self.reader = GGUFReader(file_path, 'r')
+ self.current_file = file_path
+ self.file_path_edit.setText(file_path)
+
+ self.load_metadata()
+ self.load_tensors()
+
+ self.metadata_changes = {}
+ self.metadata_to_remove = set()
+ self.modified = False
+
+ except Exception as e:
+ QMessageBox.critical(self, "Error", f"Failed to save file: {str(e)}")
+ self.statusBar().showMessage("Error saving file")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="GUI GGUF Editor")
+ parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup")
+ parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ app = QApplication(sys.argv)
+ window = GGUFEditorWindow()
+ window.show()
+
+ # Load model if specified
+ if args.model_path:
+ if os.path.isfile(args.model_path) and args.model_path.endswith('.gguf'):
+ window.load_file(args.model_path)
+ else:
+ logger.error(f"Invalid model path: {args.model_path}")
+ QMessageBox.warning(
+ window,
+ "Invalid Model Path",
+ f"The specified file does not exist or is not a GGUF file: {args.model_path}")
+
+ sys.exit(app.exec())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_hash.py b/llama.cpp/gguf-py/gguf/scripts/gguf_hash.py
new file mode 100755
index 0000000..3ef9899
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_hash.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import uuid
+import hashlib
+
+import logging
+import argparse
+import os
+import sys
+from pathlib import Path
+
+from tqdm import tqdm
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from gguf import GGUFReader # noqa: E402
+
+
+logger = logging.getLogger("gguf-hash")
+
+# UUID_NAMESPACE_LLAMA_CPP = uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
+UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5')
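+# (The constant above is just the precomputed result of the uuid5() call in
+# the preceding comment; a quick sanity check, should you want one:
+#     uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp') == UUID_NAMESPACE_LLAMA_CPP
+# )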
+
+
+# For more information about what field.parts and field.data represent,
+# please see the comments in the modify_gguf.py example.
+def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None:
+ sha1 = hashlib.sha1()
+ sha256 = hashlib.sha256()
+ uuidv5_sha1 = hashlib.sha1()
+ uuidv5_sha1.update(UUID_NAMESPACE_LLAMA_CPP.bytes)
+
+ # Total Weight Calculation For Progress Bar
+ total_weights = 0
+ for tensor in reader.tensors:
+
+ # We don't need these
+ if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+ continue
+
+ # Calculate Tensor Volume
+ sum_weights_in_tensor = 1
+ for dim in tensor.shape:
+ sum_weights_in_tensor *= dim
+ total_weights += sum_weights_in_tensor
+
+ # Hash Progress Bar
+ bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar)
+
+ # Hashing Process
+ for tensor in reader.tensors:
+
+ # We don't need these
+ if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+ continue
+
+ # Progressbar
+ sum_weights_in_tensor = 1
+ for dim in tensor.shape:
+ sum_weights_in_tensor *= dim
+ bar.update(sum_weights_in_tensor)
+
+ if not no_layer:
+
+ sha1_layer = hashlib.sha1()
+ sha1_layer.update(tensor.data.data)
+ print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
+
+ sha256_layer = hashlib.sha256()
+ sha256_layer.update(tensor.data.data)
+ print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
+
+ sha1.update(tensor.data.data)
+ sha256.update(tensor.data.data)
+ uuidv5_sha1.update(tensor.data.data)
+
+ # Flush Hash Progress Bar
+ bar.close()
+
+ # Display Hash Output
+ print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100
+ print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100
+ print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
+ parser.add_argument("model", type=str, help="GGUF format model filename")
+ parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash")
+ parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+ parser.add_argument("--progressbar", action="store_true", help="enable progressbar")
+ args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+ reader = GGUFReader(args.model, 'r')
+ gguf_hash(reader, args.model, not args.progressbar, args.no_layer)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py b/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py
new file mode 100755
index 0000000..c67436b
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import argparse
+import os
+import sys
+import json
+from pathlib import Path
+
+from tqdm import tqdm
+from typing import Any, Sequence, NamedTuple
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+import gguf
+
+logger = logging.getLogger("gguf-new-metadata")
+
+
+class MetadataDetails(NamedTuple):
+ type: gguf.GGUFValueType
+ value: Any
+ description: str = ''
+ sub_type: gguf.GGUFValueType | None = None
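+
+ # A typical instance (cf. main() below), e.g.:
+ #   MetadataDetails(gguf.GGUFValueType.STRING, 'my new name')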
+
+
+def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
+ field = reader.get_field(key)
+
+ return field.contents() if field else None
+
+
+def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
+ token_ids = [index for index, value in enumerate(token_list) if value == token]
+
+ if len(token_ids) == 0:
+ raise LookupError(f'Unable to find "{token}" in token list!')
+
+ return token_ids
+
+
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
+ for field in reader.fields.values():
+ # Suppress virtual fields and fields written by GGUFWriter
+ if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
+ logger.debug(f'Suppressing {field.name}')
+ continue
+
+ # Skip old chat templates if we have new ones
+ if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
+ logger.debug(f'Skipping {field.name}')
+ continue
+
+ if field.name in remove_metadata:
+ logger.debug(f'Removing {field.name}')
+ continue
+
+ val_type = field.types[0]
+ sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
+ old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
+ val = new_metadata.get(field.name, old_val)
+
+ if field.name in new_metadata:
+ logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
+ del new_metadata[field.name]
+ elif val.value is not None:
+ logger.debug(f'Copying {field.name}')
+
+ if val.value is not None:
+ writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
+
+ if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
+ logger.debug('Adding chat template(s)')
+ writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
+ del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
+
+ for key, val in new_metadata.items():
+ logger.debug(f'Adding {key}: "{val.value}" {val.description}')
+ writer.add_key_value(key, val.value, val.type)
+
+ total_bytes = 0
+
+ for tensor in reader.tensors:
+ total_bytes += tensor.n_bytes
+ writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
+
+ bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+ writer.write_header_to_file()
+ writer.write_kv_data_to_file()
+ writer.write_ti_data_to_file()
+
+ for tensor in reader.tensors:
+ writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
+ bar.update(tensor.n_bytes)
+
+ writer.close()
+
+
+def main() -> None:
+ tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
+ token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))
+
+ parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
+ parser.add_argument("input", type=Path, help="GGUF format model input filename")
+ parser.add_argument("output", type=Path, help="GGUF format model output filename")
+ parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"')
+ parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."')
+ parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
+ parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
+ parser.add_argument("--chat-template-file", type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja')
+ parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"')
+ parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
+ parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
+ parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
+ parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
+ parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
+ args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ new_metadata = {}
+ remove_metadata = args.remove_metadata or []
+
+ if args.general_name:
+ new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)
+
+ if args.general_description:
+ new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)
+
+ if args.chat_template:
+ new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)
+
+ if args.chat_template_config:
+ with open(args.chat_template_config, 'r', encoding='utf-8') as fp:
+ config = json.load(fp)
+ template = config.get('chat_template')
+ if template:
+ new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
+
+ if args.chat_template_file:
+ with open(args.chat_template_file, 'r', encoding='utf-8') as fp:
+ template = fp.read()
+ new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)
+
+ if args.pre_tokenizer:
+ new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer)
+
+ if remove_metadata:
+ logger.warning('*** Warning *** Warning *** Warning **')
+ logger.warning('* Most metadata is required for a fully functional GGUF file,')
+ logger.warning('* removing crucial metadata may result in a corrupt output file!')
+
+ if not args.force:
+ logger.warning('* Enter exactly YES if you are positive you want to proceed:')
+ response = input('YES, I am sure> ')
+ if response != 'YES':
+ logger.info("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+
+ logger.info(f'* Loading: {args.input}')
+ reader = gguf.GGUFReader(args.input, 'r')
+
+ arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
+
+ token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []
+
+ for name, token in args.special_token or []:
+ if name not in token_names:
+ logger.warning(f'Unknown special token "{name}", ignoring...')
+ else:
+ ids = find_token(token_list, token)
+ new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')
+
+ if len(ids) > 1:
+ logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
+ logger.warning(', '.join(str(i) for i in ids))
+
+ for name, id_string in args.special_token_by_id or []:
+ if name not in token_names:
+ logger.warning(f'Unknown special token "{name}", ignoring...')
+ elif not id_string.isdecimal():
+ raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
+ else:
+ id_int = int(id_string)
+
+ if id_int >= 0 and id_int < len(token_list):
+ new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
+ else:
+ raise LookupError(f'Token ID {id_int} is not within token list!')
+
+ if os.path.isfile(args.output) and not args.force:
+ logger.warning('*** Warning *** Warning *** Warning **')
+ logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
+ logger.warning('* Enter exactly YES if you are positive you want to proceed:')
+ response = input('YES, I am sure> ')
+ if response != 'YES':
+ logger.info("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+
+ logger.info(f'* Writing: {args.output}')
+ writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)
+
+ alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
+ if alignment is not None:
+ logger.debug(f'Setting custom alignment: {alignment}')
+ writer.data_alignment = alignment
+
+ copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llama.cpp/gguf-py/gguf/scripts/gguf_set_metadata.py b/llama.cpp/gguf-py/gguf/scripts/gguf_set_metadata.py
new file mode 100755
index 0000000..f5809c3
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/scripts/gguf_set_metadata.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python3
+import logging
+import argparse
+import os
+import sys
+from pathlib import Path
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from gguf import GGUFReader # noqa: E402
+
+logger = logging.getLogger("gguf-set-metadata")
+
+
+def minimal_example(filename: str) -> None:
+ reader = GGUFReader(filename, 'r+')
+ field = reader.get_field('tokenizer.ggml.bos_token_id')
+ if field is None:
+ return
+ part_index = field.data[0]
+ field.parts[part_index][0] = 2 # Set tokenizer.ggml.bos_token_id to 2
+ #
+ # So what's this field.data thing? It's helpful because field.parts contains
+ # _every_ part of the GGUF field. For example, tokenizer.ggml.bos_token_id consists
+ # of:
+ #
+ # Part index 0: Key length (27)
+ # Part index 1: Key data ("tokenizer.ggml.bos_token_id")
+ # Part index 2: Field type (4, the id for GGUFValueType.UINT32)
+ # Part index 3: Field value
+ #
+ # Note also that each part is an NDArray slice, so even a part that
+ # is only a single value like the key length will be an NDArray of
+ # the key length type (numpy.uint32).
+ #
+ # The .data attribute in the Field is a list of relevant part indexes
+ # and doesn't contain internal GGUF details like the key length part.
+ # In this case, .data will be [3] - just the part index of the
+ # field value itself.
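+ #
+ # Reading a scalar back therefore uses the same indexing that
+ # set_metadata() below relies on:
+ #
+ #   current_value = field.parts[field.data[0]][0]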
+
+
+def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+ field = reader.get_field(args.key)
+ if field is None:
+ logger.error(f'! Field {repr(args.key)} not found')
+ sys.exit(1)
+ # Note that field.types is a list of types. This is because the GGUF
+ # format supports arrays. For example, an array of UINT32 would
+ # look like [GGUFValueType.ARRAY, GGUFValueType.UINT32]
+ handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None
+ if handler is None:
+ logger.error(f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}')
+ sys.exit(1)
+ current_value = field.parts[field.data[0]][0]
+ new_value = handler(args.value)
+ logger.info(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}')
+ if current_value == new_value:
+ logger.info(f'- Key {repr(args.key)} already set to requested value {current_value}')
+ sys.exit(0)
+ if args.dry_run:
+ sys.exit(0)
+ if not args.force:
+ logger.warning('*** Warning *** Warning *** Warning **')
+ logger.warning('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.')
+ logger.warning('* Enter exactly YES if you are positive you want to proceed:')
+ response = input('YES, I am sure> ')
+ if response != 'YES':
+ logger.info("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+ field.parts[field.data[0]][0] = new_value
+ logger.info('* Field changed. Successful completion.')
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata")
+ parser.add_argument("model", type=str, help="GGUF format model filename")
+ parser.add_argument("key", type=str, help="Metadata key to set")
+ parser.add_argument("value", type=str, help="Metadata value to set")
+ parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything")
+ parser.add_argument("--force", action="store_true", help="Change the field without confirmation")
+ parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+
+ args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ logger.info(f'* Loading: {args.model}')
+ reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+')
+ set_metadata(reader, args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/llama.cpp/gguf-py/gguf/tensor_mapping.py b/llama.cpp/gguf-py/gguf/tensor_mapping.py
new file mode 100644
index 0000000..4364790
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/tensor_mapping.py
@@ -0,0 +1,1948 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
+
+
+class TensorNameMap:
+ mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
+ # Token embeddings
+ MODEL_TENSOR.TOKEN_EMBD: (
+ "gpt_neox.embed_in", # gptneox
+ "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
+ "transformer.word_embeddings", # falcon
+ "word_embeddings", # bloom
+ "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
+ "embed_tokens", # embeddinggemma
+ "tok_embeddings", # llama-pth
+ "embeddings.word_embeddings", # bert nomic-bert
+ "embeddings.tok_embeddings", # modern-bert
+ "language_model.embedding.word_embeddings", # persimmon
+ "wte", # gpt2
+ "transformer.embd.wte", # phi2
+ "model.tok_embeddings", # internlm2
+ "model.embedding", # mamba-qbert
+ "backbone.embedding", # mamba
+ "backbone.embeddings", # mamba-hf
+ "transformer.in_out_embed", # Grok
+ "embedding.word_embeddings", # chatglm
+ "transformer.token_embeddings", # openelm
+ "shared", # t5
+ "rwkv.embeddings", # rwkv6
+ "model.embeddings", # rwkv7
+ "model.word_embeddings", # bailingmoe
+ "language_model.model.embed_tokens", # llama4
+ "encoder", # neobert
+ "model.transformer.wte", # llada
+ "embed_tokens", # qwen3-embedding
+ ),
+
+ # Token type embeddings
+ MODEL_TENSOR.TOKEN_TYPES: (
+ "embeddings.token_type_embeddings", # bert nomic-bert
+ ),
+
+ # Normalization of token embeddings
+ MODEL_TENSOR.TOKEN_EMBD_NORM: (
+ "word_embeddings_layernorm", # bloom
+ "embeddings.LayerNorm", # bert
+ "embeddings.norm", # modern-bert
+ "emb_ln", # nomic-bert
+ "transformer.norm", # openelm
+ "rwkv.blocks.0.pre_ln", # rwkv
+ "rwkv.blocks.0.pre_ln", # rwkv6
+ "model.pre_ln", # rwkv7
+ "model.layers.0.pre_norm", # rwkv7
+ "backbone.norm", # wavtokenizer
+ "model.embedding_norm", # lfm2
+ ),
+
+ # Position embeddings
+ MODEL_TENSOR.POS_EMBD: (
+ "transformer.wpe", # gpt2
+ "embeddings.position_embeddings", # bert
+ "wpe", # gpt2
+ ),
+
+ # Output
+ MODEL_TENSOR.OUTPUT: (
+ "embed_out", # gptneox
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2
+ "output", # llama-pth bloom internlm2
+ "word_embeddings_for_head", # persimmon
+ "lm_head.linear", # phi2
+ "output_layer", # chatglm
+ "head", # rwkv
+ "head.out", # wavtokenizer
+ "lm_head", # llama4
+ "model.transformer.ff_out", # llada
+ "head.decoder", # modern-bert
+ ),
+ MODEL_TENSOR.DENSE_2_OUT: (
+ "dense_2_out", # embeddinggemma
+ ),
+ MODEL_TENSOR.DENSE_3_OUT: (
+ "dense_3_out", # embeddinggemma
+ ),
+ # Output norm
+ MODEL_TENSOR.OUTPUT_NORM: (
+ "gpt_neox.final_layer_norm", # gptneox
+ "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
+ "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2
+ "norm", # llama-pth
+ "transformer.norm_f", # mpt dbrx
+ "ln_f", # refact bloom qwen gpt2
+ "language_model.encoder.final_layernorm", # persimmon
+ "model.final_layernorm", # persimmon
+ "lm_head.ln", # phi2
+ "model.norm_f", # mamba-qbert
+ "backbone.norm_f", # mamba
+ "transformer.rms_norm", # Grok
+ "encoder.final_layernorm", # chatglm
+ "transformer.norm", # openelm
+ "model.norm", # nemotron
+ "rwkv.ln_out", # rwkv6
+ "model.ln_out", # rwkv7
+ "backbone.final_layer_norm", # wavtokenizer
+ "model.norm", # llama4
+ "model.transformer.ln_f", # llada
+ "final_norm", # modern-bert
+ "model.norm", # cogvlm
+ ),
+
+ # Rope frequencies
+ MODEL_TENSOR.ROPE_FREQS: (
+ "rope.freqs", # llama-pth
+ "rotary_pos_emb.inv_freq", # chatglm
+ ),
+
+ MODEL_TENSOR.ROPE_FACTORS_LONG: (),
+ MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
+
+ MODEL_TENSOR.CONV1D: (
+ "backbone.embed", # roberta
+ ),
+
+ MODEL_TENSOR.V_MM_EMBEDDING: (
+ "model.embed_vision.embedding", # gemma3n
+ ),
+ MODEL_TENSOR.V_MM_HARD_EMB_NORM: (
+ "model.embed_vision.hard_embedding_norm", # gemma3n
+ ),
+ MODEL_TENSOR.V_MM_INP_PROJ: (
+ "model.embed_vision.embedding_projection", # gemma3n
+ ),
+ MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
+ "model.embed_vision.soft_embedding_norm", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_CONV_STEM: (
+ "model.vision_tower.timm_model.conv_stem.conv", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_CONV_STEM_NORM: (
+ "model.vision_tower.timm_model.conv_stem.bn", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_MSFA_EXP: (
+ "model.vision_tower.timm_model.msfa.ffn.pw_exp.conv", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_MSFA_EXP_NORM: (
+ "model.vision_tower.timm_model.msfa.ffn.pw_exp.bn", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_MSFA_PROJ: (
+ "model.vision_tower.timm_model.msfa.ffn.pw_proj.conv", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_MSFA_PROJ_NORM: (
+ "model.vision_tower.timm_model.msfa.ffn.pw_proj.bn", # gemma3n
+ ),
+ MODEL_TENSOR.V_ENC_MSFA_NORM: (
+ "model.vision_tower.timm_model.msfa.norm", # gemma3n
+ ),
+ }
+
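+ # Per-block (per-layer) mappings follow. Each name contains a "{bid}"
+ # placeholder for the block index, substituted via plain str.format when the
+ # map is built (a sketch — the actual expansion lives in this class's
+ # __init__, outside this hunk):
+ #
+ #   "model.layers.{bid}.input_layernorm".format(bid=3)
+ #   # -> "model.layers.3.input_layernorm"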
+ block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
+ # Attention norm
+ MODEL_TENSOR.ATTN_NORM: (
+ "gpt_neox.layers.{bid}.input_layernorm", # gptneox
+ "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
+ "transformer.blocks.{bid}.norm_1", # mpt
+ "transformer.h.{bid}.input_layernorm", # falcon7b
+ "h.{bid}.input_layernorm", # bloom
+ "transformer.h.{bid}.ln_mlp", # falcon40b
+ "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe granite-hybrid
+ "layers.{bid}.attention_norm", # llama-pth
+ "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
+ "model.layers.{bid}.ln1", # yi
+ "h.{bid}.ln_1", # gpt2
+ "transformer.h.{bid}.ln", # phi2
+ "model.layers.layers.{bid}.norm", # plamo
+ "model.layers.layers.{bid}.pre_mixer_norm", # plamo2
+ "model.layers.{bid}.attention_norm", # internlm2
+ "model.layers.{bid}.norm", # mamba-qbert
+ "backbone.layers.{bid}.norm", # mamba
+ "transformer.decoder_layer.{bid}.rms_norm", # Grok
+ "model.layers.{bid}.pre_attn_norm", # grok-2
+ "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
+ "encoder.layers.{bid}.input_layernorm", # chatglm
+ "transformer.layers.{bid}.attn_norm", # openelm
+ "rwkv.blocks.{bid}.ln1", # rwkv6
+ "model.layers.{bid}.ln1", # rwkv7
+ "model.layers.{bid}.input_layernorm", # llama4
+ "layers.{bid}.input_layernorm", # embeddinggemma
+ "transformer_encoder.{bid}.attention_norm", # neobert
+ "layers.{bid}.attn_norm", # modern-bert
+ "model.layers.{bid}.operator_norm", # lfm2
+ "model.transformer.blocks.{bid}.attn_norm", # llada
+ "layers.{bid}.input_layernorm", # qwen3-embedding
+ "model.layers.{bid}.attention_layernorm", # apertus
+ "model.layers.{bid}.pre_attention_layernorm", # kormo
+ ),
+
+ # Attention norm 2
+ MODEL_TENSOR.ATTN_NORM_2: (
+ "transformer.h.{bid}.ln_attn", # falcon40b
+ "encoder.layer.{bid}.layer_norm_1", # jina-v2-code
+ "rwkv.blocks.{bid}.ln2", # rwkv6
+ "model.layers.{bid}.ln2", # rwkv7
+ "model.layers.{bid}.post_attention_layernorm", # cogvlm
+ ),
+
+ # Attention query-key-value
+ MODEL_TENSOR.ATTN_QKV: (
+ "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
+ "transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
+ "transformer.blocks.{bid}.attn.Wqkv", # mpt
+ "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
+ "transformer.h.{bid}.self_attention.query_key_value", # falcon
+ "h.{bid}.self_attention.query_key_value", # bloom
+ "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
+ "model.layers.{bid}.self_attn.query_key_value", # persimmon
+ "model.layers.{bid}.attention.query_key_value", # bailingmoe2
+ "h.{bid}.attn.c_attn", # gpt2
+ "transformer.h.{bid}.mixer.Wqkv", # phi2
+ "encoder.layers.{bid}.attn.Wqkv", # nomic-bert
+ "encoder.layers.{bid}.mixer.Wqkv", # jina
+ "model.layers.{bid}.self_attn.qkv_proj", # phi3
+ "model.layers.layers.{bid}.mixer.qkv_proj", # plamo2
+ "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
+ "transformer.layers.{bid}.attn.qkv_proj", # openelm
+ "transformer_encoder.{bid}.qkv", # neobert
+ "layers.{bid}.attn.Wqkv", # modern-bert
+ "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
+ "model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5
+ ),
+
+ # Attention query
+ MODEL_TENSOR.ATTN_Q: (
+ "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.q_proj", # embeddinggemma
+ "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
+ "layers.{bid}.attention.wq", # llama-pth
+ "encoder.layer.{bid}.attention.self.query", # bert
+ "transformer.layer.{bid}.attention.q_lin", # distillbert
+ "transformer.h.{bid}.attn.q_proj", # gpt-j
+ "model.layers.layers.{bid}.self_attn.q_proj", # plamo
+ "model.layers.{bid}.attention.wq", # internlm2
+ "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
+ "transformer.h.{bid}.attn.attention.q_proj", # exaone
+ "model.layers.{bid}.self_attn.q_proj", # llama4
+ "model.transformer.blocks.{bid}.q_proj", # llada
+ "layers.{bid}.self_attn.q_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.q_proj", # nemotron-h
+ ),
+
+ # Attention key
+ MODEL_TENSOR.ATTN_K: (
+ "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.k_proj", # embeddinggemma
+ "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
+ "layers.{bid}.attention.wk", # llama-pth
+ "encoder.layer.{bid}.attention.self.key", # bert
+ "transformer.layer.{bid}.attention.k_lin", # distillbert
+ "transformer.h.{bid}.attn.k_proj", # gpt-j
+ "transformer.h.{bid}.attn.k", # refact
+ "model.layers.layers.{bid}.self_attn.k_proj", # plamo
+ "model.layers.{bid}.attention.wk", # internlm2
+ "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
+ "transformer.h.{bid}.attn.attention.k_proj", # exaone
+ "model.layers.{bid}.self_attn.k_proj", # llama4
+ "model.transformer.blocks.{bid}.k_proj", # llada
+ "layers.{bid}.self_attn.k_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.k_proj", # nemotron-h
+ ),
+
+ # Attention value
+ MODEL_TENSOR.ATTN_V: (
+ "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.v_proj", # embeddinggemma
+ "layers.{bid}.attention.wv", # llama-pth
+ "encoder.layer.{bid}.attention.self.value", # bert
+ "transformer.layer.{bid}.attention.v_lin", # distillbert
+ "transformer.h.{bid}.attn.v_proj", # gpt-j
+ "transformer.h.{bid}.attn.v", # refact
+ "model.layers.layers.{bid}.self_attn.v_proj", # plamo
+ "model.layers.{bid}.attention.wv", # internlm2
+ "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
+ "transformer.h.{bid}.attn.attention.v_proj", # exaone
+ "model.layers.{bid}.self_attn.v_proj", # llama4
+ "model.transformer.blocks.{bid}.v_proj", # llada
+ "layers.{bid}.self_attn.v_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.v_proj", # nemotron-h
+ ),
+
+ # Attention output
+ MODEL_TENSOR.ATTN_OUT: (
+ "gpt_neox.layers.{bid}.attention.dense", # gptneox
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
+ "transformer.blocks.{bid}.attn.out_proj", # mpt
+ "transformer.h.{bid}.self_attention.dense", # falcon
+ "h.{bid}.self_attention.dense", # bloom
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
+ "layers.{bid}.self_attn.o_proj", # embeddinggemma
+ "model.layers.{bid}.self_attn.out_proj", # lfm2
+ "model.layers.{bid}.self_attn.linear_attn", # deci
+ "layers.{bid}.attention.wo", # llama-pth
+ "encoder.layer.{bid}.attention.output.dense", # bert
+ "layers.{bid}.attn.Wo", # modern-bert
+ "transformer.layer.{bid}.attention.out_lin", # distillbert
+ "transformer.h.{bid}.attn.out_proj", # gpt-j
+ "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
+ "model.layers.{bid}.self_attn.dense", # persimmon
+ "model.layers.{bid}.attention.dense", # bailingmoe2
+ "h.{bid}.attn.c_proj", # gpt2
+ "transformer.h.{bid}.mixer.out_proj", # phi2
+ "model.layers.layers.{bid}.self_attn.o_proj", # plamo
+ "model.layers.layers.{bid}.mixer.o_proj", # plamo2
+ "model.layers.{bid}.attention.wo", # internlm2
+ "encoder.layers.{bid}.attn.out_proj", # nomic-bert
+ "encoder.layers.{bid}.mixer.out_proj", # jina
+ "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
+ "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
+ "encoder.layers.{bid}.self_attention.dense", # chatglm
+ "transformer.layers.{bid}.attn.out_proj", # openelm
+ "transformer.h.{bid}.attn.attention.out_proj", # exaone
+ "model.layers.{bid}.self_attn.o_proj", # llama4
+ "transformer_encoder.{bid}.wo", # neobert
+ "model.transformer.blocks.{bid}.attn_out", # llada
+ "layers.{bid}.self_attn.o_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.o_proj", # nemotron-h
+ "model.layers.{bid}.self_attn.language_expert_dense", # cogvlm
+ ),
+
+ # Attention output norm
+ MODEL_TENSOR.ATTN_OUT_NORM: (
+ "encoder.layer.{bid}.attention.output.LayerNorm", # bert
+ "transformer.layer.{bid}.sa_layer_norm", # distillbert
+ "encoder.layers.{bid}.norm1", # nomic-bert
+ "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
+ "model.layers.{bid}.post_attn_norm", # grok-2
+ "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
+ ),
+
+ MODEL_TENSOR.ATTN_POST_NORM: (
+ "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge
+ "layers.{bid}.post_attention_layernorm", # embeddinggemma
+ "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414
+ "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2
+ ),
+
+ # Rotary embeddings
+ MODEL_TENSOR.ATTN_ROT_EMBD: (
+ "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
+ "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
+ "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
+ "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
+ ),
+
+ MODEL_TENSOR.ATTN_SINKS: (
+ "model.layers.{bid}.self_attn.sinks", # openai-moe
+ "model.layers.{bid}.self_attn.attention_sink_bias", # mimov2
+ ),
+
+ MODEL_TENSOR.ATTN_GATE: (
+ "model.layers.{bid}.self_attn.gate_proj", # afmoe
+ "model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
+ "model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
+ ),
+
+ # Feed-forward norm
+ MODEL_TENSOR.FFN_NORM: (
+ "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
+ "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
+ "h.{bid}.post_attention_layernorm", # bloom
+ "transformer.blocks.{bid}.norm_2", # mpt
+ "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
+ "layers.{bid}.ffn_norm", # llama-pth
+ "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
+ "model.layers.{bid}.ln2", # yi
+ "h.{bid}.ln_2", # gpt2
+ "model.layers.{bid}.ffn_norm", # internlm2
+ "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+ "model.layers.{bid}.pre_moe_norm", # grok-2
+ "encoder.layers.{bid}.post_attention_layernorm", # chatglm
+ "transformer.layers.{bid}.ffn_norm", # openelm
+ "model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
+ "model.layers.{bid}.pre_moe_layernorm", # mini-jamba
+ "model.layers.{bid}.post_attention_layernorm", # llama4
+ "transformer_encoder.{bid}.ffn_norm", # neobert
+ "model.layers.layers.{bid}.pre_mlp_norm", # plamo2
+ "model.transformer.blocks.{bid}.ff_norm", # llada
+ "layers.{bid}.post_attention_layernorm", # qwen3-embedding
+ "model.layers.{bid}.feedforward_layernorm", # apertus
+ "model.layers.{bid}.pre_mlp_layernorm", # kormo
+ "layers.{bid}.mlp_norm" # modern-bert
+ ),
+
+ # Pre feed-forward norm
+ MODEL_TENSOR.FFN_PRE_NORM: (
+ "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+ "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
+ "model.layers.{bid}.pre_ff_layernorm.weight",
+ "model.layers.{bid}.pre_mlp_layernorm", # afmoe
+ ),
+
+ # Post feed-forward norm
+ MODEL_TENSOR.FFN_POST_NORM: (
+ "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
+ "layers.{bid}.post_feedforward_layernorm", # embeddinggemma
+ "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
+ "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
+ "model.layers.{bid}.feed_forward.up_proj",
+ "model.layers.{bid}.post_moe_norm", # grok-2
+ ),
+
+ MODEL_TENSOR.FFN_GATE_INP: (
+ "layers.{bid}.feed_forward.gate", # mixtral
+ "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
+ "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
+ "transformer.decoder_layer.{bid}.router", # Grok
+ "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+ "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
+ "model.layers.{bid}.feed_forward.router", # llama4 jamba
+ "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
+ "model.layers.{bid}.mlp.router", # openai-moe
+ "model.layers.{bid}.mlp.gate.wg", # hunyuan
+ "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
+ "model.layers.{bid}.feed_forward.gate", # lfm2moe
+ "model.layers.{bid}.mlp.router.gate", # afmoe
+ "layers.{bid}.gate", # mistral-large
+ "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe
+ "model.layers.{bid}.moe.gate", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
+ "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
+ ),
+
+ MODEL_TENSOR.FFN_EXP_PROBS_B: (
+ "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
+ "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
+ "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
+ "model.layers.{bid}.mlp.expert_bias", # afmoe
+ "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
+ "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
+ "backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe
+ "model.layers.{bid}.mlp.e_score_correction", # exaone-moe
+ "model.layers.{bid}.block_sparse_moe.gate.e_score_correction", # kimi
+ "model.layers.{bid}.moe.router_bias", # step3.5 expert selection bias
+ ),
+
+ # Feed-forward up
+ MODEL_TENSOR.FFN_UP: (
+ "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
+ "transformer.h.{bid}.mlp.c_fc", # gpt2 jais
+ "transformer.blocks.{bid}.ffn.up_proj", # mpt
+ "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
+ "h.{bid}.mlp.dense_h_to_4h", # bloom
+ "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
+ "layers.{bid}.mlp.up_proj", # embeddinggemma
+ "layers.{bid}.feed_forward.w3", # llama-pth
+ "encoder.layer.{bid}.intermediate.dense", # bert
+ "layers.{bid}.mlp.Wi", # modern-bert
+ "transformer.layer.{bid}.ffn.lin1", # distillbert
+ "transformer.h.{bid}.mlp.fc_in", # gpt-j
+ "transformer.h.{bid}.mlp.linear_3", # refact
+ "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
+ "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
+ "transformer.h.{bid}.mlp.w1", # qwen
+ "h.{bid}.mlp.c_fc", # gpt2
+ "transformer.h.{bid}.mlp.fc1", # phi2
+ "model.layers.{bid}.mlp.fc1", # phi2
+ "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414
+ "model.layers.layers.{bid}.mlp.up_proj", # plamo
+ "model.layers.layers.{bid}.mlp.gate_up_proj", # plamo2
+ "model.layers.{bid}.feed_forward.w3", # internlm2
+ "encoder.layers.{bid}.mlp.fc11", # nomic-bert
+ "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
+ "model.layers.{bid}.mlp.c_fc", # starcoder2
+ "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used)
+ "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU)
+ "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU)
+ "model.layers.{bid}.residual_mlp.w3", # arctic
+ "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
+ "transformer.h.{bid}.mlp.c_fc_1", # exaone
+ "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
+ "transformer_encoder.{bid}.ffn.w12", # neobert
+ "model.layers.{bid}.block_sparse_moe.up", # smallthinker
+ "model.transformer.blocks.{bid}.up_proj", # llada
+ "layers.{bid}.mlp.up_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.up_proj", # nemotron-h
+ "model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.FFN_UP_EXP: (
+ "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
+ "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged)
+ "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
+ "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
+ "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
+ "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
+ "model.layers.{bid}.moe.up_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_UP_SHEXP: (
+ "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
+ "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
+ "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+ "model.layers.{bid}.feed_forward.down_proj",
+ "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
+ "layers.{bid}.shared_experts.w3", # mistral-large
+ "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
+ "model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi
+ "model.layers.{bid}.share_expert.up_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_UP_CHEXP: (
+ "model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
+ ),
+
+ # AWQ-activation gate
+ MODEL_TENSOR.FFN_ACT: (
+ "transformer.blocks.{bid}.ffn.act", # mpt
+ ),
+
+ # Feed-forward gate
+ MODEL_TENSOR.FFN_GATE: (
+ "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
+ "layers.{bid}.mlp.gate_proj", # embeddinggemma
+ "layers.{bid}.feed_forward.w1", # llama-pth
+ "transformer.h.{bid}.mlp.w2", # qwen
+ "transformer.h.{bid}.mlp.c_fc2", # jais
+ "model.layers.layers.{bid}.mlp.gate_proj", # plamo
+ "model.layers.{bid}.feed_forward.w1", # internlm2
+ "encoder.layers.{bid}.mlp.fc12", # nomic-bert
+ "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
+ "transformer.h.{bid}.mlp.linear_1", # refact
+ "model.layers.{bid}.residual_mlp.w1", # arctic
+ "transformer.h.{bid}.mlp.c_fc_0", # exaone
+ "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
+ "model.transformer.blocks.{bid}.ff_proj", # llada
+ "layers.{bid}.mlp.gate_proj", # qwen3-embedding
+ "model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.FFN_GATE_EXP: (
+ "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
+ "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe
+ "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
+ "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
+ "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker
+ "model.layers.{bid}.moe.gate_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_GATE_SHEXP: (
+ "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
+ "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
+ "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
+ "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
+ "layers.{bid}.shared_experts.w1", # mistral-large
+ "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi
+ "model.layers.{bid}.share_expert.gate_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_GATE_CHEXP: (
+ "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
+ ),
+
+ # Feed-forward down
+ MODEL_TENSOR.FFN_DOWN: (
+ "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+ "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
+ "transformer.blocks.{bid}.ffn.down_proj", # mpt
+ "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
+ "h.{bid}.mlp.dense_4h_to_h", # bloom
+ "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
+ "layers.{bid}.mlp.down_proj", # embeddinggemma
+ "layers.{bid}.feed_forward.w2", # llama-pth
+ "encoder.layer.{bid}.output.dense", # bert
+ "layers.{bid}.mlp.Wo", # modern-bert
+ "transformer.layer.{bid}.ffn.lin2", # distillbert
+ "transformer.h.{bid}.mlp.fc_out", # gpt-j
+ "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+ "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+ "h.{bid}.mlp.c_proj", # gpt2
+ "transformer.h.{bid}.mlp.fc2", # phi2
+ "model.layers.{bid}.mlp.fc2", # phi2
+ "model.layers.layers.{bid}.mlp.down_proj", # plamo
+ "model.layers.{bid}.feed_forward.w2", # internlm2
+ "encoder.layers.{bid}.mlp.fc2", # nomic-bert
+ "model.layers.{bid}.mlp.c_proj", # starcoder2
+ "encoder.layer.{bid}.mlp.wo", # jina-bert-v2
+ "transformer.layers.{bid}.ffn.proj_2", # openelm
+ "model.layers.{bid}.residual_mlp.w2", # arctic
+ "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
+ "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
+ "model.layers.h.{bid}.mlp.c_proj", # exaone
+ "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
+ "transformer_encoder.{bid}.ffn.w3", # neobert
+ "model.layers.{bid}.block_sparse_moe.down", # smallthinker
+ "model.transformer.blocks.{bid}.ff_out", # llada
+ "layers.{bid}.mlp.down_proj", # qwen3-embedding
+ "backbone.layers.{bid}.mixer.down_proj", # nemotron-h
+ "model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.FFN_DOWN_EXP: (
+ "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
+ "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
+ "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
+ "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged)
+ "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
+ "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
+ "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
+ "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
+ "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
+ "model.layers.{bid}.moe.down_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_DOWN_SHEXP: (
+ "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
+ "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
+ "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
+ "model.layers.{bid}.shared_mlp.output_linear", # granitemoe
+ "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
+ "layers.{bid}.shared_experts.w2", # mistral-large
+ "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
+ "model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi
+ "model.layers.{bid}.share_expert.down_proj", # step3.5
+ ),
+
+ MODEL_TENSOR.FFN_DOWN_CHEXP: (
+ "model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
+ ),
+
+ MODEL_TENSOR.ATTN_Q_NORM: (
+ "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
+ "model.layers.{bid}.self_attn.q_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
+ "model.layers.{bid}.attention.query_layernorm", # bailingmoe2
+ "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
+ "layers.{bid}.self_attn.q_norm", # embeddinggemma
+ "transformer.blocks.{bid}.attn.q_ln", # sea-lion
+ "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
+ "transformer.layers.{bid}.attn.q_norm", # openelm
+ "model.layers.layers.{bid}.mixer.q", # plamo2
+ "model.layers.layers.{bid}.mixer.q_norm", # plamo3
+ "layers.{bid}.self_attn.q_norm", # qwen3-embedding
+ "model.layers.{bid}.attention.query_layernorm", # apertus
+ ),
+
+ MODEL_TENSOR.ATTN_K_NORM: (
+ "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
+ "model.layers.{bid}.self_attn.k_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
+ "model.layers.{bid}.attention.key_layernorm", # bailingmoe2
+ "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
+ "layers.{bid}.self_attn.k_norm", # embeddinggemma
+ "transformer.blocks.{bid}.attn.k_ln", # sea-lion
+ "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
+ "transformer.layers.{bid}.attn.k_norm", # openelm
+ "model.layers.layers.{bid}.mixer.k", # plamo2
+ "model.layers.layers.{bid}.mixer.k_norm", # plamo3
+ "layers.{bid}.self_attn.k_norm", # qwen3-embedding
+ "model.layers.{bid}.attention.key_layernorm", # apertus
+ ),
+
+ MODEL_TENSOR.ROPE_FREQS: (
+ "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
+ ),
+
+ MODEL_TENSOR.LAYER_OUT_NORM: (
+ "encoder.layer.{bid}.output.LayerNorm", # bert
+ "transformer.layer.{bid}.output_layer_norm", # distillbert
+ "encoder.layers.{bid}.norm2", # nomic-bert
+ "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
+ "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
+ "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
+ "model.layers.{bid}.final_layernorm", # bailingmoe2
+ ),
+
+ MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
+ "model.embed_tokens_per_layer", # gemma3n
+ ),
+
+ MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
+ "model.per_layer_model_projection", # gemma3n
+ ),
+
+ MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
+ "model.per_layer_projection_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_PROJ: (
+ "model.altup_projections", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_UNEMBD_PROJ: (
+ "model.altup_unembed_projections", # gemma3n
+ ),
+
+ MODEL_TENSOR.PER_LAYER_INP_GATE: (
+ "model.layers.{bid}.per_layer_input_gate", # gemma3n
+ ),
+
+ MODEL_TENSOR.PER_LAYER_PROJ: (
+ "model.layers.{bid}.per_layer_projection", # gemma3n
+ ),
+
+ MODEL_TENSOR.PER_LAYER_POST_NORM: (
+ "model.layers.{bid}.post_per_layer_input_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_CORRECT_COEF: (
+ "model.layers.{bid}.altup.correction_coefs", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_CORRECT_SCALE: (
+ "model.layers.{bid}.altup.correct_output_scale", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_PREDICT_COEF: (
+ "model.layers.{bid}.altup.prediction_coefs", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_ROUTER: (
+ "model.layers.{bid}.altup.modality_router", # gemma3n
+ ),
+
+ MODEL_TENSOR.ALTUP_ROUTER_NORM: (
+ "model.layers.{bid}.altup.router_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.LAUREL_L: (
+ "model.layers.{bid}.laurel.linear_left", # gemma3n
+ ),
+
+ MODEL_TENSOR.LAUREL_R: (
+ "model.layers.{bid}.laurel.linear_right", # gemma3n
+ ),
+
+ MODEL_TENSOR.LAUREL_POST_NORM: (
+ "model.layers.{bid}.laurel.post_laurel_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.SSM_IN: (
+ "model.layers.{bid}.in_proj", # mamba-hf
+ "backbone.layers.{bid}.mixer.in_proj", # mamba
+ "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
+ "model.layers.layers.{bid}.mixer.in_proj", # plamo2
+ "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
+ ),
+
+ MODEL_TENSOR.SSM_CONV1D: (
+ "model.layers.{bid}.conv1d", # mamba-hf
+ "backbone.layers.{bid}.mixer.conv1d", # mamba
+ "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
+ "model.layers.layers.{bid}.mixer.conv1d", # plamo2
+ "model.layers.{bid}.linear_attn.conv1d", # qwen3next
+ ),
+
+ MODEL_TENSOR.SSM_X: (
+ "model.layers.{bid}.x_proj", # mamba-hf
+ "backbone.layers.{bid}.mixer.x_proj", # mamba
+ "model.layers.{bid}.mamba.x_proj", # jamba
+ "model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2
+ ),
+
+ MODEL_TENSOR.SSM_DT: (
+ "model.layers.{bid}.dt_proj", # mamba-hf
+ "backbone.layers.{bid}.mixer.dt_proj", # mamba
+ "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
+ "model.layers.layers.{bid}.mixer.dt_proj", # plamo2
+ "model.layers.{bid}.linear_attn.dt_proj", # qwen3next
+ "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe
+ "model.layers.{bid}.self_attn.dt_proj", # kimi
+ ),
+
+ MODEL_TENSOR.SSM_DT_NORM: (
+ "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2
+ "model.layers.{bid}.mamba.dt_layernorm", # jamba
+ ),
+
+ MODEL_TENSOR.SSM_A: (
+ "model.layers.{bid}.A_log", # mamba-hf
+ "backbone.layers.{bid}.mixer.A_log", # mamba
+ "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
+ "model.layers.layers.{bid}.mixer.A_log", # plamo2
+ "model.layers.{bid}.linear_attn.A_log", # qwen3next
+ "model.layers.{bid}.self_attn.A_log", # kimi
+ ),
+
+ MODEL_TENSOR.SSM_B_NORM: (
+ "model.layers.{bid}.mamba.b_layernorm", # jamba
+ "model.layers.{bid}.mamba.B_layernorm", # mini-jamba
+ "model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2
+ ),
+
+ MODEL_TENSOR.SSM_C_NORM: (
+ "model.layers.{bid}.mamba.c_layernorm", # jamba
+ "model.layers.{bid}.mamba.C_layernorm", # mini-jamba
+ "model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2
+ ),
+
+ MODEL_TENSOR.SSM_D: (
+ "model.layers.{bid}.D", # mamba-hf
+ "backbone.layers.{bid}.mixer.D", # mamba
+ "model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid
+ "model.layers.layers.{bid}.mixer.D", # plamo2
+ ),
+
+ MODEL_TENSOR.SSM_NORM: (
+ "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
+ "model.layers.{bid}.linear_attn.norm", # qwen3next
+ "backbone.layers.{bid}.mixer.norm", # mamba2
+ "model.layers.{bid}.self_attn.o_norm", # kimi
+ ),
+
+ MODEL_TENSOR.SSM_OUT: (
+ "model.layers.{bid}.out_proj", # mamba-hf
+ "backbone.layers.{bid}.mixer.out_proj", # mamba
+ "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
+ "model.layers.{bid}.linear_attn.out_proj", # qwen3next
+ "model.layers.layers.{bid}.mixer.out_proj", # plamo2
+ ),
+
+ MODEL_TENSOR.SSM_ALPHA: (
+ "model.layers.{bid}.linear_attn.in_proj_a", # qwen3.5
+ ),
+
+ MODEL_TENSOR.SSM_BETA_ALPHA: (
+ "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
+ ),
+
+ # Kimi Linear KDA (using SSM_ prefix for consistency)
+ MODEL_TENSOR.SSM_CONV1D_Q: (
+ "model.layers.{bid}.self_attn.q_conv1d",
+ ),
+ MODEL_TENSOR.SSM_CONV1D_K: (
+ "model.layers.{bid}.self_attn.k_conv1d",
+ ),
+ MODEL_TENSOR.SSM_CONV1D_V: (
+ "model.layers.{bid}.self_attn.v_conv1d",
+ ),
+ MODEL_TENSOR.SSM_F_A: (
+ "model.layers.{bid}.self_attn.f_a_proj",
+ ),
+ MODEL_TENSOR.SSM_F_B: (
+ "model.layers.{bid}.self_attn.f_b_proj",
+ ),
+ MODEL_TENSOR.SSM_BETA: (
+ "model.layers.{bid}.linear_attn.in_proj_b", # qwen3.5
+ "model.layers.{bid}.self_attn.b_proj", # Kimi Linear
+ ),
+ MODEL_TENSOR.SSM_G_A: (
+ "model.layers.{bid}.self_attn.g_a_proj",
+ ),
+ MODEL_TENSOR.SSM_G_B: (
+ "model.layers.{bid}.self_attn.g_b_proj",
+ ),
+ MODEL_TENSOR.TIME_MIX_W0: (
+ "model.layers.{bid}.attention.w0", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_W1: (
+ "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
+ "model.layers.{bid}.attention.w1", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_W2: (
+ "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
+ "model.layers.{bid}.attention.w2", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_A0: (
+ "model.layers.{bid}.attention.a0", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_A1: (
+ "model.layers.{bid}.attention.a1", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_A2: (
+ "model.layers.{bid}.attention.a2", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_V0: (
+ "model.layers.{bid}.attention.v0", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_V1: (
+ "model.layers.{bid}.attention.v1", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_V2: (
+ "model.layers.{bid}.attention.v2", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_G1: (
+ "model.layers.{bid}.attention.g1", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_G2: (
+ "model.layers.{bid}.attention.g2", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_K_K: (
+ "model.layers.{bid}.attention.k_k", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_K_A: (
+ "model.layers.{bid}.attention.k_a", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_R_K: (
+ "model.layers.{bid}.attention.r_k", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_X: (
+ "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_K: (
+ "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_V: (
+ "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_R: (
+ "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_G: (
+ "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LERP_W: (
+ "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6
+ "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_FIRST: (
+ "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY: (
+ "rwkv.blocks.{bid}.attention.time_decay", # rwkv6
+ "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY_W1: (
+ "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6
+ "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_DECAY_W2: (
+ "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6
+ "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_KEY: (
+ "rwkv.blocks.{bid}.attention.key", # rwkv6
+ "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
+ "model.layers.{bid}.attention.key", # rwkv7
+ "model.layers.{bid}.attention.k_proj", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_VALUE: (
+ "rwkv.blocks.{bid}.attention.value", # rwkv6
+ "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
+ "model.layers.{bid}.attention.value", # rwkv7
+ "model.layers.{bid}.attention.v_proj", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
+ "rwkv.blocks.{bid}.attention.receptance", # rwkv6
+ "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
+ "model.layers.{bid}.attention.receptance", # rwkv7
+ "model.layers.{bid}.attention.r_proj", # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_GATE: (
+ "rwkv.blocks.{bid}.attention.gate", # rwkv6
+ "model.layers.{bid}.self_attn.gate", # rwkv6qwen2
+ ),
+
+ MODEL_TENSOR.TIME_MIX_LN: (
+ "rwkv.blocks.{bid}.attention.ln_x", # rwkv6
+ "model.layers.{bid}.attention.ln_x" # rwkv7
+ ),
+
+ MODEL_TENSOR.TIME_MIX_OUTPUT: (
+ "rwkv.blocks.{bid}.attention.output", # rwkv6
+ "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
+ "model.layers.{bid}.attention.output", # rwkv7
+ "model.layers.{bid}.attention.o_proj", # rwkv7
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
+ "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6
+ "model.layers.{bid}.feed_forward.x_k", # rwkv7
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
+ "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_KEY: (
+ "rwkv.blocks.{bid}.feed_forward.key", # rwkv6
+ "model.layers.{bid}.feed_forward.key", # rwkv7
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
+ "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6
+ ),
+
+ MODEL_TENSOR.CHANNEL_MIX_VALUE: (
+ "rwkv.blocks.{bid}.feed_forward.value", # rwkv6
+ "model.layers.{bid}.feed_forward.value", # rwkv7
+ ),
+
+ MODEL_TENSOR.ATTN_Q_A: (
+ "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
+ "layers.{bid}.attention.wq_a", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_Q_B: (
+ "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
+ "layers.{bid}.attention.wq_b", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_KV_A_MQA: (
+ "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
+ "layers.{bid}.attention.wkv_a_with_mqa", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_KV_B: (
+ "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
+ ),
+
+ MODEL_TENSOR.ATTN_K_B: (
+ "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+ "layers.{bid}.attention.k_b_proj", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_V_B: (
+ "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+ "layers.{bid}.attention.v_b_proj", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_Q_A_NORM: (
+ "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
+ "layers.{bid}.attention.q_a_norm", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_KV_A_NORM: (
+ "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
+ "layers.{bid}.attention.kv_a_norm", # mistral-large
+ ),
+
+ MODEL_TENSOR.ATTN_SUB_NORM: (
+ "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
+ ),
+
+ MODEL_TENSOR.FFN_SUB_NORM: (
+ "model.layers.{bid}.mlp.ffn_layernorm", # bitnet
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_NORM: (
+ "decoder.block.{bid}.layer.0.layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_Q: (
+ "decoder.block.{bid}.layer.0.SelfAttention.q", # t5
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_K: (
+ "decoder.block.{bid}.layer.0.SelfAttention.k", # t5
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_V: (
+ "decoder.block.{bid}.layer.0.SelfAttention.v", # t5
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_OUT: (
+ "decoder.block.{bid}.layer.0.SelfAttention.o", # t5
+ ),
+
+ MODEL_TENSOR.DEC_ATTN_REL_B: (
+ "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
+ "decoder.block.{bid}.layer.1.layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
+ "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_K: (
+ "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_V: (
+ "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
+ "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5
+ ),
+
+ MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
+ "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5
+ ),
+
+ MODEL_TENSOR.DEC_FFN_NORM: (
+ "decoder.block.{bid}.layer.2.layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.DEC_FFN_GATE: (
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5
+ ),
+
+ MODEL_TENSOR.DEC_FFN_UP: (
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5
+ "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5
+ ),
+
+ MODEL_TENSOR.DEC_FFN_DOWN: (
+ "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5
+ ),
+
+ MODEL_TENSOR.DEC_OUTPUT_NORM: (
+ "decoder.final_layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_NORM: (
+ "encoder.block.{bid}.layer.0.layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_Q: (
+ "encoder.block.{bid}.layer.0.SelfAttention.q", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_K: (
+ "encoder.block.{bid}.layer.0.SelfAttention.k", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_V: (
+ "encoder.block.{bid}.layer.0.SelfAttention.v", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_OUT: (
+ "encoder.block.{bid}.layer.0.SelfAttention.o", # t5
+ ),
+
+ MODEL_TENSOR.ENC_ATTN_REL_B: (
+ "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5
+ ),
+
+ MODEL_TENSOR.ENC_FFN_NORM: (
+ "encoder.block.{bid}.layer.1.layer_norm", # t5
+ ),
+
+ MODEL_TENSOR.ENC_FFN_GATE: (
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5
+ ),
+
+ MODEL_TENSOR.ENC_FFN_UP: (
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5
+ "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5
+ ),
+
+ MODEL_TENSOR.ENC_FFN_DOWN: (
+ "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
+ ),
+
+ MODEL_TENSOR.VISEXP_UP: (
+ "model.layers.{bid}.mlp.vision_mlp.up_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.VISEXP_GATE: (
+ "model.layers.{bid}.mlp.vision_mlp.gate_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.VISEXP_DOWN: (
+ "model.layers.{bid}.mlp.vision_mlp.down_proj", # cogvlm
+ ),
+
+ MODEL_TENSOR.VISEXP_ATTN_OUT: (
+ "model.layers.{bid}.self_attn.vision_expert_dense", # cogvlm
+ ),
+
+ MODEL_TENSOR.VISEXP_ATTN_QKV: (
+ "model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
+ ),
+
+ ############################################################################
+ # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
+ MODEL_TENSOR.ENC_OUTPUT_NORM: (
+ "encoder.final_layer_norm", # t5
+ "layer_norm", # neobert
+ ),
+
+ MODEL_TENSOR.CLS: (
+ "classifier", # jina
+ "classifier.dense", # roberta
+ "pre_classifier", # distillbert
+ "dense", # neobert
+ "head.dense", # modern-bert
+ ),
+
+ MODEL_TENSOR.CLS_OUT: (
+ "classifier.out_proj", # roberta
+ ),
+ #############################################################################
+
+ MODEL_TENSOR.CONVNEXT_DW: (
+ "backbone.convnext.{bid}.dwconv", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.CONVNEXT_NORM: (
+ "backbone.convnext.{bid}.norm", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.CONVNEXT_PW1: (
+ "backbone.convnext.{bid}.pwconv1", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.CONVNEXT_PW2: (
+ "backbone.convnext.{bid}.pwconv2", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.CONVNEXT_GAMMA: (
+ "backbone.convnext.{bid}.gamma", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_CONV1: (
+ "backbone.posnet.{bid}.conv1", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_CONV2: (
+ "backbone.posnet.{bid}.conv2", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_NORM: (
+ "backbone.posnet.{bid}.norm", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_NORM1: (
+ "backbone.posnet.{bid}.norm1", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_NORM2: (
+ "backbone.posnet.{bid}.norm2", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_ATTN_NORM: (
+ "backbone.posnet.{bid}.norm", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_ATTN_Q: (
+ "backbone.posnet.{bid}.q", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_ATTN_K: (
+ "backbone.posnet.{bid}.k", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_ATTN_V: (
+ "backbone.posnet.{bid}.v", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.POSNET_ATTN_OUT: (
+ "backbone.posnet.{bid}.proj_out", # wavtokenizer
+ ),
+
+ MODEL_TENSOR.SHORTCONV_CONV: (
+ "model.layers.{bid}.conv.conv",
+ ),
+
+ MODEL_TENSOR.SHORTCONV_INPROJ: (
+ "model.layers.{bid}.conv.in_proj",
+ ),
+
+ MODEL_TENSOR.SHORTCONV_OUTPROJ: (
+ "model.layers.{bid}.conv.out_proj",
+ ),
+
+ #############################################################################
+ ## Vision encoder
+
+ MODEL_TENSOR.V_MMPROJ: (
+ "multi_modal_projector.linear_{bid}",
+ "mm_projector.proj.linear_{bid}", # Kimi-K2.5
+ "visual.merger.mlp.{bid}", # qwen2vl
+ "merger.mlp.{bid}",
+ ),
+
+ MODEL_TENSOR.V_MMPROJ_FC: (
+ "model.connector.modality_projection.proj", # SmolVLM
+ "model.vision.linear_proj.linear_proj", # cogvlm
+ "visual.merger.proj", # glm4v
+ ),
+
+ MODEL_TENSOR.V_MMPROJ_MLP: (
+ "model.mm_projector.mlp.mlp.{bid}",
+ "vision_model.vision_adapter.mlp.fc{bid}", # llama 4
+ "mlp1.{bid}", # InternVL
+ "model.aligner.fc1.hidden_layers.{bid}", # Janus Pro
+ ),
+
+ MODEL_TENSOR.V_MMPROJ_PEG: (
+ "model.mm_projector.peg.peg.{bid}",
+ ),
+
+ MODEL_TENSOR.V_ENC_EMBD_CLS: (
+ "vision_tower.vision_model.embeddings.class_embedding",
+ "model.vision_tower.embeddings.cls_token", # Intern-S1
+ "vision_model.class_embedding", # llama 4
+ "model.vision.patch_embedding.cls_embedding", # cogvlm
+ ),
+
+ MODEL_TENSOR.V_ENC_EMBD_PATCH: (
+ "vision_tower.vision_model.embeddings.patch_embedding",
+ "model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1
+ "vpm.embeddings.patch_embedding",
+ "model.vision_model.embeddings.patch_embedding", # SmolVLM
+ "vision_tower.patch_conv", # pixtral-hf
+ "vision_encoder.patch_conv", # pixtral
+ "vision_model.patch_embedding.linear", # llama 4
+ "visual.patch_embed.proj", # qwen2vl
+ "vision_tower.patch_embed.proj", # kimi-vl
+ "model.vision.patch_embedding.proj", # cogvlm
+ "siglip2.vision_model.embeddings.patch_embedding",
+ ),
+
+ MODEL_TENSOR.V_ENC_EMBD_NORM: (
+ "visual.post_conv_layernorm", # glm4v
+ ),
+
+ MODEL_TENSOR.V_ENC_EMBD_POS: (
+ "vision_tower.vision_model.embeddings.position_embedding",
+ "model.vision_tower.embeddings.position_embeddings", # Intern-S1
+ "vpm.embeddings.position_embedding",
+ "model.vision_model.embeddings.position_embedding", # SmolVLM
+ "vision_model.positional_embedding_vlm", # llama 4
+ "vision_tower.patch_embed.pos_emb", # kimi-vl
+ "visual.pos_embed", # qwen3vl
+ "model.vision.patch_embedding.position_embedding", # cogvlm
+ "visual.embeddings.position_embedding", # glm4v
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_QKV: (
+ "visual.blocks.{bid}.attn.qkv", # qwen3vl
+ "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
+ "vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_Q: (
+ "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
+ "model.vision_tower.encoder.layer.{bid}.attention.q_proj", # Intern-S1
+ "vpm.encoder.layers.{bid}.self_attn.q_proj",
+ "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
+ "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
+ "visual.blocks.{bid}.attn.q", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
+ "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_K: (
+ "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
+ "model.vision_tower.encoder.layer.{bid}.attention.k_proj", # Intern-S1
+ "vpm.encoder.layers.{bid}.self_attn.k_proj",
+ "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
+ "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
+ "visual.blocks.{bid}.attn.k", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj",
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
+ "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_V: (
+ "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
+ "model.vision_tower.encoder.layer.{bid}.attention.v_proj", # Intern-S1
+ "vpm.encoder.layers.{bid}.self_attn.v_proj",
+ "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
+ "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
+ "visual.blocks.{bid}.attn.v", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj",
+ ),
+
+ MODEL_TENSOR.V_ENC_INPUT_NORM: (
+ "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
+ "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
+ "vpm.encoder.layers.{bid}.layer_norm1",
+ "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
+ "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
+ "vision_model.model.layers.{bid}.input_layernorm", # llama4
+ "visual.blocks.{bid}.norm1", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
+ "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
+ ),
+
+ MODEL_TENSOR.V_ENC_ATTN_O: (
+ "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
+ "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1
+ "vpm.encoder.layers.{bid}.self_attn.out_proj",
+ "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+ "model.vision_model.encoder.layers.{bid}.self_attn.projection_layer", # Janus Pro
+ "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
+ "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
+ "visual.blocks.{bid}.attn.proj", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
+ "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
+ ),
+
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
+ "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
+ "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.layernorm_after", # Intern-S1
+ "vpm.encoder.layers.{bid}.layer_norm2",
+ "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
+ "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
+ "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
+ "visual.blocks.{bid}.norm2", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
+ "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
+ ),
+
+ MODEL_TENSOR.V_ENC_FFN_UP: (
+ "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
+ "model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
+ "vpm.encoder.layers.{bid}.mlp.fc1",
+ "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
+ "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w3", # pixtral
+ "vision_model.model.layers.{bid}.mlp.fc1", # llama4
+ "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+ "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
+ "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
+ "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
+ "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
+ ),
+
+ MODEL_TENSOR.V_ENC_FFN_GATE: (
+ "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
+ "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
+ ),
+
+ MODEL_TENSOR.V_ENC_FFN_DOWN: (
+ "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
+ "model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
+ "vpm.encoder.layers.{bid}.mlp.fc2",
+ "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
+ "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w2", # pixtral
+ "vision_model.model.layers.{bid}.mlp.fc2", # llama4
+ "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+ "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
+ "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
+ "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
+ "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
+ "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
+ ),
+
+ MODEL_TENSOR.V_LAYER_SCALE_1: (
+ "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.lambda_1", # Intern-S1
+ ),
+
+ MODEL_TENSOR.V_LAYER_SCALE_2: (
+ "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
+ "model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1
+ ),
+
+ MODEL_TENSOR.V_PRE_NORM: (
+ "vision_tower.vision_model.pre_layrnorm",
+ "vision_tower.ln_pre", # pixtral-hf
+ "vision_encoder.ln_pre", # pixtral
+ "vision_model.layernorm_pre", # llama4
+ ),
+
+ MODEL_TENSOR.V_POST_NORM: (
+ "vision_tower.vision_model.post_layernorm",
+ "model.vision_model.post_layernorm", # SmolVLM
+ "vision_model.layernorm_post", # llama4
+ "visual.merger.ln_q", # qwen2vl
+ "vision_tower.encoder.final_layernorm", # kimi-vl
+ "visual.post_layernorm", # glm4v
+ "siglip2.vision_model.post_layernorm",
+ ),
+
+ MODEL_TENSOR.V_MM_POST_NORM: (
+ "visual.merger.post_projection_norm", # glm4v
+ ),
+
+ MODEL_TENSOR.V_MM_INP_PROJ: (
+ "multi_modal_projector.mm_input_projection",
+ ),
+
+ MODEL_TENSOR.V_MM_INP_NORM: (
+ "multi_modal_projector.norm",
+ "multi_modal_projector.layer_norm",
+ "multi_modal_projector.pre_norm",
+ "mm_projector.pre_norm", # Kimi-K2.5
+ "pre_mm_projector_norm",
+ "model.vision.linear_proj.norm1", # cogvlm
+ "merger.ln_q",
+ ),
+
+ MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
+ "multi_modal_projector.mm_soft_emb_norm",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
+ "resampler.pos_embed_k",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_ATTN_Q: (
+ "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
+ ),
+
+ MODEL_TENSOR.V_RESMPL_ATTN_K: (
+ "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
+ ),
+
+ MODEL_TENSOR.V_RESMPL_ATTN_V: (
+ "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
+ ),
+
+ MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
+ "resampler.attn.out_proj",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_KV: (
+ "resampler.kv_proj",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_POST_NORM: (
+ "resampler.ln_post",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_KV_NORM: (
+ "resampler.ln_kv",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_Q_NORM: (
+ "resampler.ln_q",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_PROJ: (
+ "resampler.proj",
+ ),
+
+ MODEL_TENSOR.V_RESMPL_QUERY: (
+ "resampler.query",
+ ),
+
+ MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
+ "v.token_embd.img_break", # for pixtral, this is a generated vector
+ ),
+
+ MODEL_TENSOR.V_MM_PATCH_MERGER: (
+ "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf
+ "patch_merger.merging_layer", # mistral
+ "visual.downsample", # glm4v
+ ),
+
+ MODEL_TENSOR.V_DS_NORM: (
+ "model.visual.deepstack_merger_list.{bid}.norm", # deepstack in qwen3vl
+ ),
+
+ MODEL_TENSOR.V_DS_FC1: (
+ "model.visual.deepstack_merger_list.{bid}.linear_fc1", # deepstack in qwen3vl
+ ),
+
+ MODEL_TENSOR.V_DS_FC2: (
+ "model.visual.deepstack_merger_list.{bid}.linear_fc2", # deepstack in qwen3vl
+ ),
+
+ MODEL_TENSOR.V_MM_POST_FC_NORM: (
+ "model.vision.linear_proj.norm1", # cogvlm
+ ),
+
+ MODEL_TENSOR.V_MM_UP: (
+ "model.vision.linear_proj.dense_h_to_4h", # cogvlm
+ "visual.merger.up_proj", # glm4v
+ ),
+
+ MODEL_TENSOR.V_MM_DOWN: (
+ "model.vision.linear_proj.dense_4h_to_h", # cogvlm
+ "visual.merger.down_proj", # glm4v
+ ),
+
+ MODEL_TENSOR.V_MM_GATE: (
+ "model.vision.linear_proj.gate_proj", # cogvlm
+ "visual.merger.gate_proj", # glm4v
+ ),
+
+ MODEL_TENSOR.V_TOK_BOI: (
+ "model.vision.boi", # cogvlm
+ ),
+
+ MODEL_TENSOR.V_TOK_EOI: (
+ "model.vision.eoi", # cogvlm
+ ),
+
+ # audio (mtmd)
+
+ MODEL_TENSOR.A_ENC_EMBD_POS: (
+ "audio_tower.embed_positions", # ultravox
+ "audio_embedding.embedding", # lfm2
+ ),
+
+ MODEL_TENSOR.A_ENC_EMBD_NORM: (
+ "audio_embedding.embedding_norm", # lfm2
+ ),
+
+ MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: (
+ "audio_embedding.to_logits", # lfm2
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV1D: (
+ "audio_tower.conv{bid}", # ultravox
+ "conformer.pre_encode.conv.{bid}", # lfm2
+ "model.audio_tower.subsample_conv_projection.conv_{bid}.conv", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV1D_NORM: (
+ "model.audio_tower.subsample_conv_projection.conv_{bid}.norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_PRE_NORM: (),
+
+ MODEL_TENSOR.A_POST_NORM: (
+ "audio_tower.layer_norm", # ultravox
+ "audio_tower.ln_post", # qwen2omni
+ ),
+
+ MODEL_TENSOR.A_ENC_ATTN_Q: (
+ "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
+ "conformer.layers.{bid}.self_attn.linear_q", # lfm2
+ "conformer.layers.{bid}.attention.attn.q_proj", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_ATTN_K: (
+ "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
+ "conformer.layers.{bid}.self_attn.linear_k", # lfm2
+ "conformer.layers.{bid}.attention.attn.k_proj", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_ATTN_V: (
+ "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
+ "conformer.layers.{bid}.self_attn.linear_v", # lfm2
+ "conformer.layers.{bid}.attention.attn.v_proj", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_PER_DIM_SCALE: (
+ "conformer.layers.{bid}.attention.attn.per_dim_scale", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_LAYER_PRE_NORM: (
+ "conformer.layers.{bid}.norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_INPUT_NORM: (
+ "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
+ "conformer.layers.{bid}.norm_self_att", # lfm2
+ "conformer.layers.{bid}.attention.pre_attn_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_OUTPUT: (
+ "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
+ "conformer.layers.{bid}.self_attn.linear_out", # lfm2
+ "conformer.layers.{bid}.attention.post", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
+ "audio_tower.layers.{bid}.final_layer_norm", # ultravox
+ "conformer.layers.{bid}.norm_out", # lfm2
+ "conformer.layers.{bid}.attention.post_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_NORM: (
+ "conformer.layers.{bid}.norm_feed_forward1", # lfm2
+ "conformer.layers.{bid}.ffw_layer_start.pre_layer_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM: (
+ "conformer.layers.{bid}.ffw_layer_start.post_layer_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_SCALE: (
+ "conformer.layers.{bid}.ffw_layer_start.post_layer_scale", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_UP: (
+ "audio_tower.layers.{bid}.fc1", # ultravox
+ "conformer.layers.{bid}.feed_forward1.linear1", # lfm2
+ "conformer.layers.{bid}.ffw_layer_start.ffw_layer_1", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_GATE: (),
+
+ MODEL_TENSOR.A_ENC_FFN_DOWN: (
+ "audio_tower.layers.{bid}.fc2", # ultravox
+ "conformer.layers.{bid}.feed_forward1.linear2", # lfm2
+ "conformer.layers.{bid}.ffw_layer_start.ffw_layer_2", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_UP_1: (
+ "conformer.layers.{bid}.feed_forward2.linear1", # lfm2
+ "conformer.layers.{bid}.ffw_layer_end.ffw_layer_1", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_DOWN_1: (
+ "conformer.layers.{bid}.feed_forward2.linear2", # lfm2
+ "conformer.layers.{bid}.ffw_layer_end.ffw_layer_2", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_NORM_1: (
+ "conformer.layers.{bid}.norm_feed_forward2", # lfm2
+ "conformer.layers.{bid}.ffw_layer_end.pre_layer_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_POST_NORM_1: (
+ "conformer.layers.{bid}.ffw_layer_end.post_layer_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_FFN_SCALE_1: (
+ "conformer.layers.{bid}.ffw_layer_end.post_layer_scale", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_LINEAR_POS: (
+ "conformer.layers.{bid}.self_attn.linear_pos", # lfm2
+ "conformer.layers.{bid}.attention.attn.relative_position_embedding.pos_proj", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_POS_BIAS_U: (
+ "conformer.layers.{bid}.self_attn.pos_bias_u", # lfm2
+ ),
+
+ MODEL_TENSOR.A_ENC_POS_BIAS_V: (
+ "conformer.layers.{bid}.self_attn.pos_bias_v", # lfm2
+ ),
+
+ MODEL_TENSOR.A_ENC_OUT: (
+ "conformer.pre_encode.out", # lfm2
+ "model.audio_tower.subsample_conv_projection.input_proj_linear", # gemma3n
+ ),
+
+ # note: some tensors below have an "audio." pseudo-prefix, to prevent conflicts with vision tensors
+ # this prefix is added in the conversion code in modify_tensors()
+
+ MODEL_TENSOR.A_MMPROJ: (
+ "audio.multi_modal_projector.linear_{bid}", # ultravox
+ "audio_adapter.model.{bid}" # lfm2
+ ),
+
+ MODEL_TENSOR.A_MMPROJ_FC: (
+ "audio.multi_modal_projector.linear", # qwen2audio
+ "audio_tower.proj", # qwen2omni
+ ),
+
+ MODEL_TENSOR.A_MM_NORM_PRE: (
+ "audio.multi_modal_projector.ln_pre", # ultravox
+ ),
+
+ MODEL_TENSOR.A_MM_NORM_MID: (
+ "audio.multi_modal_projector.ln_mid", # ultravox
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV_DW: (
+ "conformer.layers.{bid}.conv.depthwise_conv", # lfm2
+ "conformer.layers.{bid}.lconv1d.depthwise_conv1d", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV_NORM: (
+ "conformer.layers.{bid}.conv.batch_norm", # lfm2
+ "conformer.layers.{bid}.lconv1d.pre_layer_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV_PW1: (
+ "conformer.layers.{bid}.conv.pointwise_conv1", # lfm2
+ "conformer.layers.{bid}.lconv1d.linear_start", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_CONV_PW2: (
+ "conformer.layers.{bid}.conv.pointwise_conv2", # lfm2
+ "conformer.layers.{bid}.lconv1d.linear_end", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_ENC_NORM_CONV: (
+ "conformer.layers.{bid}.norm_conv", # lfm2
+ "conformer.layers.{bid}.lconv1d.conv_norm", # gemma3n
+ ),
+
+ MODEL_TENSOR.A_MM_EMBEDDING: (
+ "model.embed_audio.embedding", # gemma3n
+ ),
+ MODEL_TENSOR.A_MM_HARD_EMB_NORM: (
+ "model.embed_audio.hard_embedding_norm", # gemma3n
+ ),
+ MODEL_TENSOR.A_MM_INP_PROJ: (
+ "model.embed_audio.embedding_projection", # gemma3n
+ ),
+ MODEL_TENSOR.A_MM_SOFT_EMB_NORM: (
+ "model.embed_audio.soft_embedding_norm", # gemma3n
+ ),
+
+ # NextN/MTP tensors
+ MODEL_TENSOR.NEXTN_EH_PROJ: (
+ "model.layers.{bid}.eh_proj",
+ ),
+
+ MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
+ "model.layers.{bid}.embed_tokens",
+ ),
+
+ MODEL_TENSOR.NEXTN_ENORM: (
+ "model.layers.{bid}.enorm",
+ ),
+
+ MODEL_TENSOR.NEXTN_HNORM: (
+ "model.layers.{bid}.hnorm",
+ ),
+
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
+ "model.layers.{bid}.shared_head.head",
+ ),
+
+ MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
+ "model.layers.{bid}.shared_head.norm",
+ ),
+ }
+
+ # architecture-specific block mappings
+ arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
+ MODEL_ARCH.ARCTIC: {
+ MODEL_TENSOR.FFN_NORM: (
+ "model.layers.{bid}.residual_layernorm",
+ ),
+ MODEL_TENSOR.FFN_NORM_EXP: (
+ "model.layers.{bid}.post_attention_layernorm",
+ ),
+ },
+ }
+
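+ # maps each known source tensor name (and the GGUF name itself) to (tensor type, GGUF name)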
+ mapping: dict[str, tuple[MODEL_TENSOR, str]]
+
+ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
+ self.mapping = {}
+ for tensor, keys in self.mappings_cfg.items():
+ if tensor not in MODEL_TENSORS[arch]:
+ continue
+ tensor_name = TENSOR_NAMES[tensor]
+ self.mapping[tensor_name] = (tensor, tensor_name)
+ for key in keys:
+ self.mapping[key] = (tensor, tensor_name)
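+ # overlay architecture-specific overrides onto the generic block mappings
+ # (note: dict.update() here mutates the class-level dict shared by all instances)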
+ if arch in self.arch_block_mappings_cfg:
+ self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
+ for bid in range(n_blocks):
+ for tensor, keys in self.block_mappings_cfg.items():
+ if tensor not in MODEL_TENSORS[arch]:
+ continue
+
+ tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+ self.mapping[tensor_name] = (tensor, tensor_name)
+ for key in keys:
+ key = key.format(bid = bid)
+ self.mapping[key] = (tensor, tensor_name)
+
+ def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
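+ # exact match first; otherwise strip each candidate suffix (e.g. ".weight")
+ # and retry, re-appending the suffix to the mapped name on success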
+ result = self.mapping.get(key)
+ if result is not None:
+ return result
+ for suffix in try_suffixes:
+ if key.endswith(suffix):
+ result = self.mapping.get(key[:-len(suffix)])
+ if result is not None:
+ return result[0], result[1] + suffix
+ return None
+
+ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
+ result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+ if result is None:
+ return None
+ return result[1]
+
+ def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
+ result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+ if result is None:
+ return None
+ return result[0]
+
+ def __getitem__(self, key: str) -> str:
+ try:
+ return self.mapping[key][1]
+ except KeyError:
+ raise KeyError(key)
+
+ def __contains__(self, key: str) -> bool:
+ return key in self.mapping
+
+ def __repr__(self) -> str:
+ return repr(self.mapping)
+
+
+def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
+ return TensorNameMap(arch, n_blocks)
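+
+
+ # Example (a minimal sketch; MODEL_ARCH.LLAMA and the HF tensor name are illustrative):
+ #   tmap = get_tensor_name_map(MODEL_ARCH.LLAMA, n_blocks=32)
+ #   tmap.get_name("model.layers.0.mlp.gate_proj.weight", try_suffixes=(".weight", ".bias"))
+ #   # -> "blk.0.ffn_gate.weight", assuming TENSOR_NAMES maps FFN_GATE to "blk.{bid}.ffn_gate"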
diff --git a/llama.cpp/gguf-py/gguf/utility.py b/llama.cpp/gguf-py/gguf/utility.py
new file mode 100644
index 0000000..154351d
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/utility.py
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal
+
+import os
+import json
+import numpy as np
+
+
+def fill_templated_filename(filename: str, output_type: str | None) -> str:
+ # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+ ftype_lowercase: str = output_type.lower() if output_type is not None else ""
+ ftype_uppercase: str = output_type.upper() if output_type is not None else ""
+ return filename.format(ftype_lowercase,
+ outtype=ftype_lowercase, ftype=ftype_lowercase,
+ OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
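+ # e.g. fill_templated_filename("some-model.{ftype}.gguf", "Q8_0") -> "some-model.q8_0.gguf"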
+
+
+def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+ if model_params_count > 1e12:
+ # Trillions Of Parameters
+ scaled_model_params = model_params_count * 1e-12
+ scale_suffix = "T"
+ elif model_params_count > 1e9:
+ # Billions Of Parameters
+ scaled_model_params = model_params_count * 1e-9
+ scale_suffix = "B"
+ elif model_params_count > 1e6:
+ # Millions Of Parameters
+ scaled_model_params = model_params_count * 1e-6
+ scale_suffix = "M"
+ else:
+ # Thousands Of Parameters
+ scaled_model_params = model_params_count * 1e-3
+ scale_suffix = "K"
+
+ fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+
+ return f"{scaled_model_params:.{fix}f}{scale_suffix}"
+
+
+def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
+
+ if expert_count > 0:
+ pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
+ size_class = f"{expert_count}x{pretty_size}"
+ else:
+ size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
+
+ return size_class
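+ # e.g. size_label(46_700_000_000, 1_600_000_000, 5_600_000_000, 8) -> "8x7.2B" (MoE),
+ #      size_label(13_015_864_320, 0, 0, 0) -> "13B" (dense)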
+
+
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
+ # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention
+
+ if base_name is not None:
+ name = base_name.strip().replace(' ', '-').replace('/', '-')
+ elif model_name is not None:
+ name = model_name.strip().replace(' ', '-').replace('/', '-')
+ else:
+ name = "ggml-model"
+
+ parameters = f"-{size_label}" if size_label is not None else ""
+
+ finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
+
+ version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
+
+ encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
+
+ kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
+
+ return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
+
+
+@dataclass
+class RemoteTensor:
+ dtype: str
+ shape: tuple[int, ...]
+ offset_start: int
+ size: int
+ url: str
+
+ def data(self) -> bytearray:
+ # TODO: handle request errors (maybe with limited retries?)
+ # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+ data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+ return data
+
+
+class SafetensorRemote:
+ """
+ Uility class to handle remote safetensor files.
+ This class is designed to work with Hugging Face model repositories.
+
+ Example (one model has single safetensor file, the other has multiple):
+ for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
+ tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+ print(tensors)
+
+ Example reading tensor data:
+ tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
+ for name, meta in tensors.items():
+ dtype, shape, offset_start, size, remote_safetensor_url = meta
+ # read the tensor data
+ data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
+ print(data)
+ """
+
+ BASE_DOMAIN = "https://huggingface.co"
+
+ @classmethod
+ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
+ """
+ Get list of tensors from a Hugging Face model repository.
+
+ Returns a dictionary of tensor names and their metadata.
+ Each tensor is represented as a RemoteTensor with (dtype, shape, offset_start, size, url)
+ """
+ # case 1: model has only one single model.safetensor file
+ is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
+ if is_single_file:
+ url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
+ return cls.get_list_tensors(url)
+
+ # case 2: model has multiple files
+ index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
+ is_multiple_files = cls.check_file_exist(index_url)
+ if is_multiple_files:
+ # read the index file
+ index_data = cls.get_data_by_range(index_url, 0)
+ index_str = index_data.decode('utf-8')
+ index_json = json.loads(index_str)
+ assert index_json.get("weight_map") is not None, "weight_map not found in index file"
+ weight_map = index_json["weight_map"]
+ # get the list of files
+ all_files = list(set(weight_map.values()))
+ all_files.sort() # make sure we load shard files in order
+ # get the list of tensors
+ tensors: dict[str, RemoteTensor] = {}
+ for file in all_files:
+ url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
+ for key, val in cls.get_list_tensors(url).items():
+ tensors[key] = val
+ return tensors
+
+ raise ValueError(
+ f"No safetensor file has been found for model {model_id}."
+ "If the repo has safetensor files, make sure the model is public or you have a "
+ "valid Hugging Face token set in the environment variable HF_TOKEN."
+ )
+
+ @classmethod
+ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
+ """
+ Get list of tensors from a remote safetensor file.
+
+ Returns a dictionary of tensor names and their metadata.
+ Each tensor is represented as a RemoteTensor with (dtype, shape, offset_start, size, url)
+ """
+ metadata, data_start_offset = cls.get_metadata(url)
+ res: dict[str, RemoteTensor] = {}
+
+ for name, meta in metadata.items():
+ if name == "__metadata__":
+ continue
+ if not isinstance(meta, dict):
+ raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
+ try:
+ dtype = meta["dtype"]
+ shape = meta["shape"]
+ offset_start_relative, offset_end_relative = meta["data_offsets"]
+ size = offset_end_relative - offset_start_relative
+ offset_start = data_start_offset + offset_start_relative
+ res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+ except KeyError as e:
+ raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
+
+ # order by name (same as default safetensors behavior)
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+ res = dict(sorted(res.items(), key=lambda t: t[0]))
+
+ return res
+
+ @classmethod
+ def get_metadata(cls, url: str) -> tuple[dict, int]:
+ """
+ Get JSON metadata from a remote safetensor file.
+
+ Returns tuple of (metadata, data_start_offset)
+ """
+ # Request first 5MB of the file (hopefully enough for metadata)
+ read_size = 5 * 1024 * 1024
+ raw_data = cls.get_data_by_range(url, 0, read_size)
+
+ # Parse header
+ # First 8 bytes contain the metadata length as u64 little-endian
+ if len(raw_data) < 8:
+ raise ValueError("Not enough data to read metadata size")
+ metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
+
+ # Calculate the data start offset
+ data_start_offset = 8 + metadata_length
+
+ # Check if we have enough data to read the metadata
+ if len(raw_data) < 8 + metadata_length:
+ raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
+
+ # Extract metadata bytes and parse as JSON
+ metadata_bytes = raw_data[8:8 + metadata_length]
+ metadata_str = metadata_bytes.decode('utf-8')
+ try:
+ metadata = json.loads(metadata_str)
+ return metadata, data_start_offset
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
+
+ @classmethod
+ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
+ """
+ Get raw byte data from a remote file by range.
+ If size is not specified, it will read the entire file.
+ """
+ import requests
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(url)
+ if not parsed_url.scheme or not parsed_url.netloc:
+ raise ValueError(f"Invalid URL: {url}")
+
+ headers = cls._get_request_headers()
+ if size > -1:
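+ # NOTE: HTTP byte ranges are inclusive, so this asks for one extra byte;
+ # the slice on the returned content trims it back to exactly `size`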
+ headers["Range"] = f"bytes={start}-{start + size}"
+ response = requests.get(url, allow_redirects=True, headers=headers)
+ response.raise_for_status()
+
+ # Get raw byte data
+ return response.content[slice(size if size > -1 else None)]
+
+ @classmethod
+ def check_file_exist(cls, url: str) -> bool:
+ """
+ Check if a file exists at the given URL.
+ Returns True if the file exists, False otherwise.
+ """
+ import requests
+ from urllib.parse import urlparse
+
+ parsed_url = urlparse(url)
+ if not parsed_url.scheme or not parsed_url.netloc:
+ raise ValueError(f"Invalid URL: {url}")
+
+ try:
+ headers = cls._get_request_headers()
+ headers["Range"] = "bytes=0-0"
+ response = requests.head(url, allow_redirects=True, headers=headers)
+ # Success (2xx) or redirect (3xx)
+ return 200 <= response.status_code < 400
+ except requests.RequestException:
+ return False
+
+ @classmethod
+ def _get_request_headers(cls) -> dict[str, str]:
+ """Prepare common headers for requests."""
+ headers = {"User-Agent": "convert_hf_to_gguf"}
+ if os.environ.get("HF_TOKEN"):
+ headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+ return headers
+
+
+@dataclass
+class LocalTensorRange:
+ filename: Path
+ offset: int
+ size: int
+
+
+@dataclass
+class LocalTensor:
+ dtype: str
+ shape: tuple[int, ...]
+ data_range: LocalTensorRange
+
+ def mmap_bytes(self) -> np.ndarray:
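+ # np.memmap defaults to dtype uint8, so shape=size yields a copy-on-write
+ # (mode='c') byte view covering exactly this tensor's range in the file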
+ return np.memmap(self.data_range.filename, mode='c', offset=self.data_range.offset, shape=self.data_range.size)
+
+
+class SafetensorsLocal:
+ """
+ Read a safetensors file from the local filesystem.
+
+ Custom parsing gives a bit more control over the memory usage.
+ The official safetensors library doesn't expose file ranges.
+ """
+
+ tensors: dict[str, LocalTensor]
+
+ def __init__(self, filename: Path):
+ with open(filename, "rb") as f:
+ metadata_length = int.from_bytes(f.read(8), byteorder='little')
+ file_size = os.stat(filename).st_size
+ if file_size < 8 + metadata_length:
+ raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
+
+ metadata_str = f.read(metadata_length).decode('utf-8')
+ try:
+ metadata = json.loads(metadata_str)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
+
+ data_start_offset = f.tell()
+
+ tensors: dict[str, LocalTensor] = {}
+ for name, meta in metadata.items():
+ if name == "__metadata__":
+ # ignore metadata, it's not a tensor
+ continue
+
+ tensors[name] = LocalTensor(
+ dtype=meta["dtype"],
+ shape=tuple(meta["shape"]),
+ data_range=LocalTensorRange(
+ filename,
+ data_start_offset + meta["data_offsets"][0],
+ meta["data_offsets"][1] - meta["data_offsets"][0],
+ ),
+ )
+
+ # order by name (same as default safetensors behavior)
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+ self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
+
+ def __enter__(self, *args, **kwargs):
+ del args, kwargs # unused
+ return self.tensors
+
+ def __exit__(self, *args, **kwargs):
+ del args, kwargs # unused
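+
+
+ # Example usage (a minimal sketch; the file path is illustrative):
+ #   with SafetensorsLocal(Path("model.safetensors")) as tensors:
+ #       for name, t in tensors.items():
+ #           print(name, t.dtype, t.shape, t.data_range.size)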
diff --git a/llama.cpp/gguf-py/gguf/vocab.py b/llama.cpp/gguf-py/gguf/vocab.py
new file mode 100644
index 0000000..028e574
--- /dev/null
+++ b/llama.cpp/gguf-py/gguf/vocab.py
@@ -0,0 +1,891 @@
+from __future__ import annotations
+
+from enum import Enum
+import re
+import logging
+import json
+import os
+from pathlib import Path
+from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
+
+try:
+ from sentencepiece import SentencePieceProcessor
+except ImportError:
+ SentencePieceProcessor = None
+
+try:
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
+ from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
+ from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
+ _filter_valid_tokenizer_files,
+ )
+ from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
+ SentencePieceTokenizer,
+ )
+except ImportError:
+ _mistral_common_installed = False
+ MistralTokenizer = None
+ Tekkenizer = None
+ SentencePieceTokenizer = None
+ _filter_valid_tokenizer_files = None
+else:
+ _mistral_common_installed = True
+
+try:
+ from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
+ get_one_valid_tokenizer_file,
+ )
+except ImportError:
+ # We still want the conversion to work with older mistral-common versions.
+ get_one_valid_tokenizer_file = None
+
+
+import gguf
+
+from .gguf_writer import GGUFWriter
+
+logger = logging.getLogger(__name__)
+
+
+class SpecialVocab:
+ merges: list[str]
+ add_special_token: dict[str, bool]
+ special_token_ids: dict[str, int]
+ chat_template: str | Sequence[Mapping[str, str]] | None
+
+ def __init__(
+ self, path: str | os.PathLike[str], load_merges: bool = False,
+ special_token_types: Iterable[str] | None = None,
+ n_vocab: int | None = None,
+ ):
+ self.special_token_ids = {}
+ self.add_special_token = {}
+ self.n_vocab = n_vocab
+ self.load_merges = load_merges
+ self.merges = []
+ self.chat_template = None
+ if special_token_types is not None:
+ self.special_token_types = special_token_types
+ else:
+ self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
+ self._load(Path(path))
+
+ def __repr__(self) -> str:
+ return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
+ len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
+ )
+
+ def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
+ if self.merges:
+ if not quiet:
+ logger.info(f'Adding {len(self.merges)} merge(s).')
+ gw.add_token_merges(self.merges)
+ elif self.load_merges:
+ logger.warning('Adding merges requested but no merges found, output may be non-functional.')
+ for typ, tokid in self.special_token_ids.items():
+ id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
+ if id_handler is None:
+ logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
+ continue
+ if not quiet:
+ logger.info(f'Setting special token type {typ} to {tokid}')
+ id_handler(tokid)
+ for typ, value in self.add_special_token.items():
+ add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
+ if add_handler is None:
+ logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
+ continue
+ if not quiet:
+ logger.info(f'Setting add_{typ}_token to {value}')
+ add_handler(value)
+ if self.chat_template is not None:
+ if not quiet:
+ logger.info(f'Setting chat_template to {self.chat_template}')
+ gw.add_chat_template(self.chat_template)
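+
+ # Typical converter usage (an illustrative sketch; `writer` is an open GGUFWriter
+ # and `model_dir` a hypothetical checkpoint directory):
+ #
+ #     svocab = SpecialVocab(model_dir, load_merges=True, n_vocab=32000)
+ #     svocab.add_to_gguf(writer)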
+
+ def _load(self, path: Path) -> None:
+ self._try_load_from_tokenizer_json(path)
+ self._try_load_from_config_json(path)
+ if self.load_merges and not self.merges:
+ self._try_load_merges_txt(path)
+
+ def _try_load_merges_txt(self, path: Path) -> bool:
+ merges_file = path / 'merges.txt'
+ if not merges_file.is_file():
+ return False
+ with open(merges_file, 'r', encoding = 'utf-8') as fp:
+ first_line = next(fp, '').strip()
+ if not first_line.startswith('#'):
+ fp.seek(0)
+ line_num = 0
+ else:
+ line_num = 1
+ merges = []
+ for line in fp:
+ line_num += 1
+ line = line.strip()
+ if not line:
+ continue
+ parts = line.split(None, 3)
+ if len(parts) != 2:
+ logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
+ continue
+ merges.append(f'{parts[0]} {parts[1]}')
+ self.merges = merges
+ return True
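+
+ # merges.txt (GPT-2 style) may start with a '#version' header line and then lists
+ # one merge per line as two space-separated parts, e.g.:
+ #
+ #     Ġ t
+ #     h e
+ #     i n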
+
+ def _set_special_token(self, typ: str, tid: Any) -> None:
+ if not isinstance(tid, int):
+ return
+ if tid < 0:
+ raise ValueError(f'invalid value for special token type {typ}: {tid}')
+ if self.n_vocab is None or tid < self.n_vocab:
+ if typ in self.special_token_ids:
+ return
+ self.special_token_ids[typ] = tid
+ return
+ logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
+
+ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+ tokenizer = None
+ tokenizer_file = path / 'tokenizer.json'
+ if tokenizer_file.is_file():
+ with open(tokenizer_file, encoding = 'utf-8') as f:
+ tokenizer = json.load(f)
+ if self.load_merges:
+ merges = tokenizer.get('model', {}).get('merges')
+ if isinstance(merges, list) and merges:
+ if isinstance(merges[0], str):
+ self.merges = merges
+ elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
+ # New format since transformers 4.45 to support spaces in merges
+ # ref: https://github.com/ggml-org/llama.cpp/issues/9692
+ # TODO: internally store as the new format instead of converting to old
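+ # e.g. the pair ["e f", "g"] becomes "eĠf g": a literal space inside a part is
+ # re-encoded as chr(ord(' ') + 256) == '\u0120' ('Ġ'), while the separator
+ # between the two parts stays a plain space.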
+ if any(' ' in s for pair in merges for s in pair):
+ logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
+ self.merges = [
+ ' '.join(
+ [
+ # ensure the spaces are properly encoded
+ ''.join(
+ chr(ord(c) + 256) if c == ' ' else c
+ for c in part
+ )
+ for part in pair
+ ]
+ )
+ for pair in merges
+ ]
+ else:
+ raise ValueError("Unknown tokenizer merges format")
+ added_tokens = tokenizer.get('added_tokens', {})
+ else:
+ added_tokens = {}
+ tokenizer_config = None
+ tokenizer_config_file = path / 'tokenizer_config.json'
+ if tokenizer_config_file.is_file():
+ with open(tokenizer_config_file, encoding = 'utf-8') as f:
+ tokenizer_config = json.load(f)
+ if tokenizer:
+ special_bos = (tokenizer_config or {}).get('bos_token')
+ special_cls = (tokenizer_config or {}).get('cls_token')
+ special_eos = (tokenizer_config or {}).get('eos_token')
+ special_sep = (tokenizer_config or {}).get('sep_token')
+ if not special_bos and special_cls and tokenizer_config:
+ tokenizer_config['bos_token'] = special_bos = special_cls
+ if not special_eos and special_sep and tokenizer_config:
+ tokenizer_config['eos_token'] = special_eos = special_sep
+ if post_processor := tokenizer.get('post_processor'):
+ for processor in post_processor.get('processors', [post_processor]):
+ if processor.get('type') == 'RobertaProcessing':
+ self.add_special_token['bos'] = True
+ self.add_special_token['eos'] = True
+ self.add_special_token['sep'] = True
+ if not special_cls and tokenizer_config:
+ special_cls = processor.get('cls', [special_bos])[0]
+ tokenizer_config['cls_token'] = special_cls
+ if not special_sep and tokenizer_config:
+ special_sep = processor.get('sep', [special_eos])[0]
+ tokenizer_config['sep_token'] = special_sep
+ continue
+ # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+ # Only works with simple templates, **will** get it wrong on unusual sequences
+ if processor.get('type') == 'TemplateProcessing':
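+ # e.g. "single": [{"SpecialToken": {"id": "<s>"}}, {"Sequence": {"id": "A"}},
+ # {"SpecialToken": {"id": "</s>"}}] would record "<s>"/"</s>" as the leading and
+ # trailing special tokens, matched against the configured BOS/EOS (or CLS/SEP).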
+ tmpl_single = processor.get('single', [])
+ tmpl_pair = processor.get('pair', [])
+ special_first = None
+ special_last = None
+ if len(tmpl_single) > 1:
+ if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_bos = special_first
+ self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+ if special_first not in (special_bos, special_cls):
+ logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+ if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+ if not tokenizer_config:
+ special_eos = special_last
+ elif special_last != special_eos:
+ if 'eot' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eot', )
+ tokenizer_config['eot_token'] = special_last
+ elif 'eom' not in self.special_token_types:
+ self.special_token_types = tuple(self.special_token_types) + ('eom', )
+ tokenizer_config['eom_token'] = special_last
+ else:
+ logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
+ tokenizer_config['eos_token'] = special_eos = special_last
+ self.add_special_token['eos'] = True if special_last == special_eos else False
+ if special_last != special_eos:
+ logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+ if tmpl_pair:
+ seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+ seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+ if (special_first and seq_start == 0) or (special_last and seq_stop is None):
+ logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+ if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+ tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+ tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+ if tmpl_a != 'A' or tmpl_b != 'B':
+ logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+ # A [sep] [eos] B
+ if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+ add_sep = False
+ if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos) and not special_last:
+ add_sep = True
+ if special_entry not in (special_sep, special_eos):
+ logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
+ else:
+ logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+ if len(tmpl_pair) == 2:
+ if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+ if special_entry in (special_sep, special_eos):
+ add_sep = True
+ if special_entry not in (special_sep, special_eos):
+ logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+ else:
+ logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+ self.add_special_token['sep'] = add_sep
+ if add_sep and not special_sep and tokenizer_config:
+ tokenizer_config['sep_token'] = special_eos
+ continue
+ if not tokenizer_config:
+ return True
+ chat_template_alt = None
+ chat_template_json = path / 'chat_template.json'
+ chat_template_jinja = path / 'chat_template.jinja'
+ if chat_template_jinja.is_file():
+ with open(chat_template_jinja, encoding = 'utf-8') as f:
+ chat_template_alt = f.read()
+ if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
+ chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
+ for template_path in additional_templates:
+ with open(template_path, encoding = 'utf-8') as fp:
+ chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
+ elif chat_template_json.is_file():
+ with open(chat_template_json, encoding = 'utf-8') as f:
+ chat_template_alt = json.load(f).get('chat_template')
+ chat_template = tokenizer_config.get('chat_template', chat_template_alt)
+ if chat_template is None or isinstance(chat_template, (str, list)):
+ self.chat_template = chat_template
+ else:
+ logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
+ for typ in self.special_token_types:
+ add_entry = tokenizer_config.get(f'add_{typ}_token')
+ if isinstance(add_entry, bool):
+ self.add_special_token[typ] = add_entry
+ entry = tokenizer_config.get(f'{typ}_token')
+ if isinstance(entry, str):
+ tc_content = entry
+ elif isinstance(entry, dict):
+ entry_content = entry.get('content')
+ if not isinstance(entry_content, str):
+ continue
+ tc_content = entry_content
+ else:
+ continue
+ # We only need the first match here.
+ maybe_token_id = next(
+ (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
+ None,
+ )
+ self._set_special_token(typ, maybe_token_id)
+ return True
+
+ def _try_load_from_config_json(self, path: Path) -> bool:
+ config_file = path / 'config.json'
+ if not config_file.is_file():
+ return False
+ with open(config_file, encoding = 'utf-8') as f:
+ config = json.load(f)
+ for typ in self.special_token_types:
+ token_id = config.get(f'{typ}_token_id')
+ # If not found at root, check in text_config (for multimodal models like Kimi-VL)
+ if token_id is None and 'text_config' in config:
+ token_id = config['text_config'].get(f'{typ}_token_id')
+ self._set_special_token(typ, token_id)
+ return True
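+
+ # An illustrative config.json fragment that this reads:
+ #
+ #     { "bos_token_id": 1, "eos_token_id": 2,
+ #       "text_config": { "eos_token_id": 2 } }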
+
+
+@runtime_checkable
+class BaseVocab(Protocol):
+ tokenizer_model: ClassVar[str]
+ name: ClassVar[str]
+
+
+@runtime_checkable
+class Vocab(BaseVocab, Protocol):
+ vocab_size: int
+ added_tokens_dict: dict[str, int]
+ added_tokens_list: list[str]
+ fname_tokenizer: Path
+
+ def __init__(self, base_path: Path): ...
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
+
+
+class NoVocab(BaseVocab):
+ tokenizer_model = "no_vocab"
+ name = "no_vocab"
+
+ def __repr__(self) -> str:
+ return "<NoVocab for a model without integrated vocabulary>"
+
+
+class BpeVocab(Vocab):
+ tokenizer_model = "gpt2"
+ name = "bpe"
+
+ def __init__(self, base_path: Path):
+ added_tokens: dict[str, int] = {}
+
+ if (fname_tokenizer := base_path / 'vocab.json').exists():
+ # "slow" tokenizer
+ with open(fname_tokenizer, encoding="utf-8") as f:
+ self.vocab = json.load(f)
+
+ try:
+ # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+ with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
+ added_tokens = json.load(f)
+ except FileNotFoundError:
+ pass
+ else:
+ # "fast" tokenizer
+ fname_tokenizer = base_path / 'tokenizer.json'
+
+ # if this fails, FileNotFoundError propagates to caller
+ with open(fname_tokenizer, encoding="utf-8") as f:
+ tokenizer_json = json.load(f)
+
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
+ if (
+ tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
+ or tokenizer_json['decoder']['type'] != 'ByteLevel'
+ ):
+ raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')
+
+ self.vocab = tokenizer_model["vocab"]
+
+ if (added := tokenizer_json.get('added_tokens')) is not None:
+ # Added tokens here can be duplicates of the main vocabulary.
+ added_tokens = {item['content']: item['id']
+ for item in added
+ if item['content'] not in self.vocab}
+
+ vocab_size = len(self.vocab)
+ expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+ actual_ids = sorted(added_tokens.values())
+ if expected_ids != actual_ids:
+ expected_end_id = vocab_size + len(actual_ids) - 1
+ raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
+ f"{vocab_size} - {expected_end_id}; got {actual_ids}")
+
+ items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
+ self.added_tokens_dict = added_tokens
+ self.added_tokens_list = [text for (text, idx) in items]
+ self.vocab_size_base = vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+ self.fname_tokenizer = fname_tokenizer
+
+ def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}
+
+ for i, _ in enumerate(self.vocab):
+ yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
+
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ for text in self.added_tokens_list:
+ score = -1000.0
+ yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
+
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ yield from self.bpe_tokens()
+ yield from self.added_tokens()
+
+ def __repr__(self) -> str:
+ return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+
+
+class SentencePieceVocab(Vocab):
+ tokenizer_model = "llama"
+ name = "spm"
+
+ def __init__(self, base_path: Path):
+ if SentencePieceProcessor is None:
+ raise RuntimeError("sentencepiece is not installed")
+
+ added_tokens: dict[str, int] = {}
+ if (fname_tokenizer := base_path / 'tokenizer.model').exists():
+ # normal location
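+ # added_tokens.json, when present, maps token text to id, e.g. {"<pad>": 32000}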
+ try:
+ with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
+ added_tokens = json.load(f)
+ except FileNotFoundError:
+ pass
+ elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
+ # not found in alternate location either
+ raise FileNotFoundError('Cannot find tokenizer.model')
+
+ self.sentencepiece_tokenizer = SentencePieceProcessor()
+ self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
+ vocab_size = self.sentencepiece_tokenizer.vocab_size()
+
+ new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
+ expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
+ actual_new_ids = sorted(new_tokens.keys())
+
+ if expected_new_ids != actual_new_ids:
+ raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
+
+ # Token pieces that were added to the base vocabulary.
+ self.added_tokens_dict = added_tokens
+ self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
+ self.vocab_size_base = vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+ self.fname_tokenizer = fname_tokenizer
+
+ def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ tokenizer = self.sentencepiece_tokenizer
+ for i in range(tokenizer.vocab_size()):
+ piece = tokenizer.IdToPiece(i)
+ text = piece.encode("utf-8")
+ score: float = tokenizer.GetScore(i)
+
+ toktype = gguf.TokenType.NORMAL
+ if tokenizer.IsUnknown(i):
+ toktype = gguf.TokenType.UNKNOWN
+ if tokenizer.IsControl(i):
+ toktype = gguf.TokenType.CONTROL
+
+ # NOTE: I think added_tokens are user defined.
+ # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
+ # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
+
+ if tokenizer.IsUnused(i):
+ toktype = gguf.TokenType.UNUSED
+ if tokenizer.IsByte(i):
+ toktype = gguf.TokenType.BYTE
+
+ yield text, score, toktype
+
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ for text in self.added_tokens_list:
+ score = -1000.0
+ yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ yield from self.sentencepiece_tokens()
+ yield from self.added_tokens()
+
+ def __repr__(self) -> str:
+ return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+
+
+class LlamaHfVocab(Vocab):
+ tokenizer_model = "llama"
+ name = "hfft"
+
+ def __init__(self, base_path: Path):
+ fname_tokenizer = base_path / 'tokenizer.json'
+ # if this fails, FileNotFoundError propagates to caller
+ with open(fname_tokenizer, encoding='utf-8') as f:
+ tokenizer_json = json.load(f)
+
+ # pre-check so we know if we need transformers
+ tokenizer_model: dict[str, Any] = tokenizer_json['model']
+ is_llama3 = (
+ tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
+ and not tokenizer_model.get('byte_fallback', True)
+ )
+ if is_llama3:
+ raise TypeError('Llama 3 must be converted with BpeVocab')
+
+ if not is_llama3 and (
+ tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
+ or tokenizer_json['decoder']['type'] != 'Sequence'
+ ):
+ raise FileNotFoundError('Cannot find Llama BPE tokenizer')
+
+ try:
+ from transformers import AutoTokenizer
+ except ImportError as e:
+ raise ImportError(
+ "To use LlamaHfVocab, please install the `transformers` package. "
+ "You can install it with `pip install transformers`."
+ ) from e
+
+ # Allow the tokenizer to default to slow or fast versions.
+ # Explicitly set tokenizer to use local paths.
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ base_path,
+ cache_dir=base_path,
+ local_files_only=True,
+ )
+ assert self.tokenizer.is_fast # assume tokenizer.json is used
+
+ # Initialize lists and dictionaries for added tokens
+ self.added_tokens_list = []
+ self.added_tokens_dict = dict()
+ self.added_tokens_ids = set()
+
+ # Process added tokens
+ for tok, tokidx in sorted(
+ self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
+ ):
+ # Only consider added tokens that are not in the base vocabulary
+ if tokidx >= self.tokenizer.vocab_size:
+ self.added_tokens_list.append(tok)
+ self.added_tokens_dict[tok] = tokidx
+ self.added_tokens_ids.add(tokidx)
+
+ # Store special tokens and their IDs
+ self.specials = {
+ tok: self.tokenizer.get_vocab()[tok]
+ for tok in self.tokenizer.all_special_tokens
+ }
+ self.special_ids = set(self.tokenizer.all_special_ids)
+
+ # Set vocabulary sizes
+ self.vocab_size_base = self.tokenizer.vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+
+ self.fname_tokenizer = fname_tokenizer
+
+ def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ reverse_vocab = {
+ id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
+ }
+
+ for token_id in range(self.vocab_size_base):
+ # Skip processing added tokens here
+ if token_id in self.added_tokens_ids:
+ continue
+
+ # Convert token text to bytes
+ token_text = reverse_vocab[token_id].encode("utf-8")
+
+ # Yield token text, score, and type
+ yield token_text, self.get_token_score(token_id), self.get_token_type(
+ token_id, token_text, self.special_ids # Reuse already stored special IDs
+ )
+
+ def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
+ # Special case for byte tokens
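+ # (e.g. b"<0x0A>", the newline byte as encoded by SentencePiece byte-fallback)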
+ if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+ return gguf.TokenType.BYTE
+
+ # Determine token type based on whether it's a special token
+ return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
+
+ def get_token_score(self, token_id: int) -> float:
+ # Placeholder for actual logic to determine the token's score
+ # This needs to be implemented based on specific requirements
+ return -1000.0 # Default score
+
+ def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ for text in self.added_tokens_list:
+ if text in self.specials:
+ toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
+ score = self.get_token_score(self.specials[text])
+ else:
+ toktype = gguf.TokenType.USER_DEFINED
+ score = -1000.0
+
+ yield text.encode("utf-8"), score, toktype
+
+ def has_newline_token(self):
+ return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
+
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ yield from self.hf_tokens()
+ yield from self.added_tokens()
+
+ def __repr__(self) -> str:
+ return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+
+
+class MistralTokenizerType(str, Enum):
+ spm = "spm"
+ tekken = "tekken"
+
+
+# Copied from Transformers (Apache 2.0)
+# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
+
+def bytes_to_unicode() -> dict[int, str]:
+ """
+ Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
+ characters that the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs_str = [chr(n) for n in cs]
+ return dict(zip(bs, cs_str))
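+
+# For example, visible ASCII characters map to themselves, while the remaining bytes
+# (including space) are shifted above U+00FF: bytes_to_unicode()[ord(' ')] == '\u0120' ('Ġ').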
+
+
+class MistralVocab(Vocab):
+ tokenizer_model = "mistral"
+ name = "mistral"
+
+ added_tokens_dict: dict[str, int] = {}
+ added_tokens_list: list[str] = []
+
+ def __init__(self, base_path: Path):
+ if not _mistral_common_installed:
+ raise ImportError(
+ "To use MistralVocab, please install the `mistral-common` package. "
+ "You can install it with `pip install mistral-common`."
+ )
+ assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
+ assert MistralTokenizer is not None, "mistral_common is not installed"
+ assert Tekkenizer is not None, "mistral_common is not installed"
+
+ logger.info(f"Loading Mistral tokenizer from {base_path}")
+
+ # Find the tokenizer files
+ all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
+
+ if get_one_valid_tokenizer_file is not None:
+ tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
+ else:
+ valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
+
+ if len(valid_tokenizer_files) == 0:
+ raise ValueError(f"No tokenizer file found in the directory: {base_path}")
+ # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
+ if len(valid_tokenizer_files) > 1:
+ if "tekken.json" in valid_tokenizer_files:
+ tokenizer_file = "tekken.json"
+ else:
+ tokenizer_file = sorted(valid_tokenizer_files)[-1]
+ logger.warning(
+ f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
+ )
+ else:
+ tokenizer_file = valid_tokenizer_files[0]
+
+ tokenizer_file_path = base_path / tokenizer_file
+
+ self.tokenizer = MistralTokenizer.from_file(
+ tokenizer_file_path
+ ).instruct_tokenizer.tokenizer
+ self.tokenizer_type = (
+ MistralTokenizerType.tekken
+ if isinstance(self.tokenizer, Tekkenizer)
+ else MistralTokenizerType.spm
+ )
+ self.vocab_size = self.tokenizer.n_words
+ self.fname_tokenizer = tokenizer_file_path
+ self._name = (
+ "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
+ )
+
+ @property
+ def tokenizer_name(self) -> str:
+ return self._name
+
+ @property
+ def gguf_tokenizer_model(self) -> str:
+ return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
+
+ def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ assert SentencePieceTokenizer is not None, "mistral_common is not installed"
+ assert isinstance(self.tokenizer, SentencePieceTokenizer), (
+ f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
+ )
+
+ for i in range(self.tokenizer._model.vocab_size()):
+ piece = self.tokenizer._model.IdToPiece(i)
+ text = piece.encode("utf-8")
+ score: float = self.tokenizer._model.GetScore(i)
+
+ toktype = gguf.TokenType.NORMAL
+ if self.tokenizer._model.IsUnknown(i):
+ toktype = gguf.TokenType.UNKNOWN
+ if self.tokenizer._model.IsControl(i):
+ toktype = gguf.TokenType.CONTROL
+
+ if self.tokenizer._model.IsUnused(i):
+ toktype = gguf.TokenType.UNUSED
+ if self.tokenizer._model.IsByte(i):
+ toktype = gguf.TokenType.BYTE
+
+ yield text, score, toktype
+
+ def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ assert Tekkenizer is not None, "mistral_common is not installed"
+ assert isinstance(self.tokenizer, Tekkenizer), (
+ f"Expected Tekkenizer, got {type(self.tokenizer)}"
+ )
+
+ byte_encoder = bytes_to_unicode()
+ for token_id in range(self.tokenizer.num_special_tokens):
+ yield (
+ self.tokenizer.id_to_piece(token_id).encode("utf-8"),
+ 0,
+ gguf.TokenType.CONTROL
+ )
+ for token in self.tokenizer._tekken_token2id_nospecial:
+ yield (
+ self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
+ 0,
+ gguf.TokenType.NORMAL,
+ )
+
+ def get_token_id(self, token: str) -> int:
+ assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
+ if self.tokenizer_type == MistralTokenizerType.spm:
+ assert isinstance(self.tokenizer, SentencePieceTokenizer)
+ return self.tokenizer._vocab.index(token)
+ elif self.tokenizer_type == MistralTokenizerType.tekken:
+ assert isinstance(self.tokenizer, Tekkenizer)
+ return (
+ self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
+ )
+ else:
+ raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
+
+ @property
+ def bos_id(self) -> int:
+ return self.tokenizer.bos_id
+
+ @property
+ def eos_id(self) -> int:
+ return self.tokenizer.eos_id
+
+ @property
+ def pad_id(self) -> int:
+ if self.tokenizer.pad_id == -1:
+ return self.eos_id
+ return self.tokenizer.pad_id
+
+ @property
+ def unk_id(self) -> int:
+ return self.tokenizer.unk_id
+
+ @property
+ def bos_token(self) -> str:
+ return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
+
+ @property
+ def eos_token(self) -> str:
+ return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
+
+ @property
+ def pad_token(self) -> str:
+ return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
+
+ @property
+ def unk_token(self) -> str:
+ return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
+
+ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+ if self.tokenizer_type == MistralTokenizerType.spm:
+ yield from self._sentencepiece_tokens()
+
+ elif self.tokenizer_type == MistralTokenizerType.tekken:
+ yield from self._tekken_tokens()
+
+ else:
+ raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
+
+ @staticmethod
+ def token_bytes_to_string(b, byte_encoder):
+ return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+ def extract_vocab_merges_from_model(self):
+ # Adapted from Transformers (Apache 2.0)
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
+ assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
+ f"Expected Tekkenizer, got {type(self.tokenizer)}"
+ )
+ mergeable_ranks = self.tokenizer._model._mergeable_ranks
+ token_bytes_map = {
+ rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
+ }
+ merge_pairs = []
+
+ # Sort vocab by rank to ensure correct merge order
+ for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
+ merged_token = token_bytes_map[i]
+ local = []
+ for j in range(1, len(merged_token)):
+ left = merged_token[:j]
+ right = merged_token[j:]
+ if (
+ left in mergeable_ranks
+ and right in mergeable_ranks
+ and (left + right) in mergeable_ranks
+ ):
+ local.append((left, right, i))
+ if not local:
+ raise ValueError(
+ f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
+ )
+ local = sorted(
+ local,
+ key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
+ reverse=False,
+ )
+ merge_pairs.extend(local)
+ merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)
+
+ byte_encoder = bytes_to_unicode()
+
+ decoded_merge_pairs = [
+ [
+ self.token_bytes_to_string(val[0], byte_encoder),
+ self.token_bytes_to_string(val[1], byte_encoder),
+ ]
+ for val in merge_pairs
+ ]
+
+ merges = [
+ " ".join(
+ [
+ # ensure the spaces are properly encoded
+ "".join(chr(ord(c) + 256) if c == " " else c for c in part)
+ for part in pair
+ ]
+ )
+ for pair in decoded_merge_pairs
+ ]
+
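+ # e.g. a merged token b"er" yields the pair "e r" (assuming b"e" and b"r" are
+ # themselves in the vocabulary); literal spaces inside a part are re-encoded as
+ # 'Ġ', mirroring the tokenizer.json merges handling above.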
+ return merges