# Add a new model architecture to `llama.cpp`

Adding a model requires a few steps:

1. Convert the model to GGUF
2. Define the model architecture in `llama.cpp`
3. Build the GGML graph implementation

After following these steps, you can open a PR.

Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially:
- [cli](/tools/cli/)
- [completion](/tools/completion/)
- [imatrix](/tools/imatrix/)
- [quantize](/tools/quantize/)
- [server](/tools/server/)

### 1. Convert the model to GGUF

This step is done in Python with a `convert` script using the [gguf](https://pypi.org/project/gguf/) library.
Depending on the model architecture, you can use either [convert_hf_to_gguf.py](/convert_hf_to_gguf.py) or [examples/convert_legacy_llama.py](/examples/convert_legacy_llama.py) (for `llama/llama2` models in `.pth` format).

The convert script reads the model configuration, tokenizer, and tensor names and data, and converts them to GGUF metadata and tensors.
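
Under the hood, the script uses the `gguf` package's writer to emit the metadata key-values followed by the tensor data. A minimal sketch of that flow, assuming the `gguf-py` writer API (the hyperparameter values and tensor shape below are placeholders, not real model data):

```python
import numpy as np
import gguf

# create a writer for the output file and the target architecture
writer = gguf.GGUFWriter("model.gguf", arch="mymodel")

# metadata key-values, read from the model config (placeholder values)
writer.add_block_count(2)
writer.add_context_length(2048)

# tensor data, converted from the original checkpoint (placeholder shape)
writer.add_tensor("token_embd.weight", np.zeros((16, 32), dtype=np.float32))

# GGUF layout: header, then KV metadata, then tensor data
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```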

The required steps to implement for an HF model are:

1. Add the `ModelBase.register` annotation to a new `TextModel` or `MmprojModel` subclass, for example:

```python
@ModelBase.register("MyModelForCausalLM")
class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```

or

```python
@ModelBase.register("MyModelForConditionalGeneration")
class MyModel(MmprojModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```

2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)

Add an enum entry in `MODEL_ARCH`, the model's human-friendly name in `MODEL_ARCH_NAMES` and the GGUF tensor names in `MODEL_TENSORS`.

Example for the `falcon` model:
```python
    MODEL_ARCH.FALCON: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_NORM_2,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ]
```
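
The enum entry and the human-friendly name live in the same file. A sketch of what those two additions look like, assuming the existing `IntEnum`/dict layout of `constants.py` (the `MYMODEL`/`"mymodel"` identifiers are placeholders):

```python
class MODEL_ARCH(IntEnum):
    ...
    MYMODEL = auto()  # new enum entry for the architecture

MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    ...
    MODEL_ARCH.MYMODEL: "mymodel",  # human-friendly name written into the GGUF
}
```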

3. Map the original tensor names to the standardized equivalents in GGUF

As a general rule, before adding a new tensor name to GGUF, be sure an equivalent name does not already exist.

Once you have found the GGUF tensor name equivalent, add it to the [tensor_mapping.py](/gguf-py/gguf/tensor_mapping.py) file.

If the tensor name is part of a repetitive layer/block, the placeholder `bid` substitutes for the block index.

Example for the normalization tensor in attention layers:

```python
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
    # Attention norm
    MODEL_TENSOR.ATTN_NORM: (
        "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
        "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact qwen
        "transformer.blocks.{bid}.norm_1",        # mpt
        ...
    )
}
```

`transformer.blocks.{bid}.norm_1` will be mapped to `blk.{bid}.attn_norm` in GGUF.
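
You can sanity-check a mapping from Python; a hedged example using the `gguf-py` helpers (the exact API may differ between versions):

```python
import gguf

# build the name map for an architecture with a known block count
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MPT, n_blocks=32)

# original HF name -> standardized GGUF name
print(tmap.get_name("transformer.blocks.0.norm_1.weight", try_suffixes=(".weight", ".bias")))
# -> blk.0.attn_norm.weight
```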

Depending on the model configuration, tokenizer, code and tensor layout, you may have to override (see the sketch after this list):
- `TextModel#set_gguf_parameters`
- `MmprojModel#set_gguf_parameters`
- `ModelBase#set_vocab`
- `ModelBase#modify_tensors`
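
A hedged sketch of what such overrides typically look like in a `convert_hf_to_gguf.py` subclass (the hyperparameter key and the tensor handling comment are illustrative, not taken from a real model):

```python
@ModelBase.register("MyModelForCausalLM")
class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL

    def set_gguf_parameters(self):
        super().set_gguf_parameters()  # writes the common hyperparameters
        # extra, model-specific metadata ("rms_norm_eps" is an example key):
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

    def modify_tensors(self, data_torch, name, bid):
        # split fused tensors, transpose, or rename here if needed
        return [(self.map_tensor_name(name), data_torch)]
```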

NOTE: Tensor names must end with the `.weight` or `.bias` suffix. That is the convention, and several tools like `quantize` expect it in order to process the weights.

### 2. Define the model architecture in `llama.cpp`

The model params and tensor layout must be defined in the `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
   - Add the architecture name to the `LLM_ARCH_NAMES` map.
   - Add the list of model tensors to `llm_get_tensor_names` (you may also need to update `LLM_TENSOR_NAMES`).
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in the `llama_model_rope_type` function in `src/llama-model.cpp`.

NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
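
For example, a weight stored in PyTorch with shape `(n_out, n_in)` shows up in `ggml` with `ne = [n_in, n_out]`, since `ne[0]` is the innermost (fastest-varying) dimension. A quick illustration in Python (the shapes are arbitrary):

```python
import numpy as np

w = np.zeros((4096, 11008))      # PyTorch/NumPy convention: (n_out, n_in)
print(tuple(reversed(w.shape)))  # ggml's ne order: (11008, 4096)
```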

### 3. Build the GGML graph implementation

This is the fun part: you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.

Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.

NOTE: To debug the inference graph, you can use [llama-eval-callback](/examples/eval-callback/).

## GGUF specification

https://github.com/ggml-org/ggml/blob/master/docs/gguf.md

## Resources

- YaRN RoPE scaling https://github.com/ggml-org/llama.cpp/pull/2268
- support Baichuan serial models https://github.com/ggml-org/llama.cpp/pull/3009
- support attention bias https://github.com/ggml-org/llama.cpp/pull/4283
- Mixtral support https://github.com/ggml-org/llama.cpp/pull/4406
- BERT embeddings https://github.com/ggml-org/llama.cpp/pull/5423
- Grok-1 support https://github.com/ggml-org/llama.cpp/pull/6204
- Command R Plus support https://github.com/ggml-org/llama.cpp/pull/6491
- support arch DBRX https://github.com/ggml-org/llama.cpp/pull/6515
- How to convert HuggingFace model to GGUF format https://github.com/ggml-org/llama.cpp/discussions/2948