1#include "llama-hparams.h"
  2
  3#include "ggml.h"
  4
  5#include <algorithm>
  6#include <cassert>
  7
  8void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
  9    if (dense_first) {
 10        for (uint32_t il = 0; il < n_layer; ++il) {
 11            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
 12        }
 13    } else {
 14        for (uint32_t il = 0; il < n_layer; ++il) {
 15            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
 16        }
 17    }
 18}
 19
 20bool llama_hparams::is_swa_any() const {
 21    for (uint32_t il = 0; il < n_layer; ++il) {
 22        if (swa_layers[il]) {
 23            return true;
 24        }
 25    }
 26
 27    return false;
 28}
 29
 30uint32_t llama_hparams::n_head(uint32_t il) const {
 31    if (il < n_layer) {
 32        return n_head_arr[il];
 33    }
 34
 35    GGML_ABORT("fatal error");
 36}
 37
 38uint32_t llama_hparams::n_head_kv(uint32_t il) const {
 39    if (il < n_layer) {
 40        return n_head_kv_arr[il];
 41    }
 42
 43    GGML_ABORT("fatal error");
 44}
 45
 46uint32_t llama_hparams::n_ff(uint32_t il) const {
 47    if (il < n_layer) {
 48        return n_ff_arr[il];
 49    }
 50
 51    GGML_ABORT("fatal error");
 52}
 53
 54uint32_t llama_hparams::n_gqa(uint32_t il) const {
 55    const uint32_t n_head    = this->n_head(il);
 56    const uint32_t n_head_kv = this->n_head_kv(il);
 57
 58    if (n_head_kv == 0) {
 59        return 0;
 60    }
 61
 62    return n_head/n_head_kv;
 63}
 64
 65uint32_t llama_hparams::n_embd_inp() const {
 66    uint32_t n_embd_inp = n_embd;
 67
 68    if (n_deepstack_layers > 0) {
 69        n_embd_inp += n_embd * n_deepstack_layers;
 70    }
 71
 72    return n_embd_inp;
 73}
 74
 75uint32_t llama_hparams::n_embd_out() const {
 76    return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
 77}
 78
 79uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
 80    const uint32_t n_head_kv = this->n_head_kv(il);
 81
 82    return n_embd_head_k * n_head_kv;
 83}
 84
 85uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
 86    const uint32_t n_head_kv = this->n_head_kv(il);
 87
 88    return n_embd_head_v * n_head_kv;
 89}
 90
 91bool llama_hparams::is_n_embd_k_gqa_variable() const {
 92    const uint32_t val = n_embd_k_gqa();
 93    for (uint32_t il = 0; il < n_layer; ++il) {
 94        if (val != n_embd_k_gqa(il)) {
 95            return true;
 96        }
 97    }
 98
 99    return false;
100}
101
102bool llama_hparams::is_n_embd_v_gqa_variable() const {
103    const uint32_t val = n_embd_v_gqa();
104    for (uint32_t il = 0; il < n_layer; ++il) {
105        if (val != n_embd_v_gqa(il)) {
106            return true;
107        }
108    }
109
110    return false;
111}
112
113uint32_t llama_hparams::n_embd_k_gqa_max() const {
114    uint32_t val = n_embd_k_gqa();
115    for (uint32_t il = 0; il < n_layer; ++il) {
116        val = std::max(val, n_embd_k_gqa(il));
117    }
118
119    return val;
120}
121
122uint32_t llama_hparams::n_embd_v_gqa_max() const {
123    uint32_t val = n_embd_v_gqa();
124    for (uint32_t il = 0; il < n_layer; ++il) {
125        val = std::max(val, n_embd_v_gqa(il));
126    }
127
128    return val;
129}
130
131uint32_t llama_hparams::n_embd_r() const {
132    if (wkv_head_size != 0) {
133        // for RWKV models
134        return token_shift_count * n_embd;
135    }
136
137    if (n_shortconv_l_cache != 0) {
138        // for LFM2 models
139        return n_embd * (n_shortconv_l_cache - 1);
140    }
141
142    if (n_embd_head_kda != 0) {
143        // for Kimi KDA layers
144        // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim
145        const uint32_t d_inner = n_head() * n_embd_head_kda;  // 32 * 128 = 4096
146        return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner;
147    }
148
149    // TODO: maybe support other convolution strides than 1
150    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
151    // Corresponds to Mamba's conv_states size
152    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
153}
154
155uint32_t llama_hparams::n_embd_s() const {
156    if (wkv_head_size != 0) {
157        // corresponds to RWKV's wkv_states size
158        return n_embd * wkv_head_size;
159    }
160
161    if (n_embd_head_kda != 0) {
162        // for Kimi KDA layers
163        // Full recurrent state: head_dim * head_dim * n_head
164        // h tensor shape for delta attention: [head_dim, head_dim, n_head]
165        return n_embd_head_kda * n_embd_head_kda * n_head();  // 128 * 128 * 32 = 524288
166    }
167
168    // corresponds to Mamba's ssm_states size
169    return ssm_d_state * ssm_d_inner;
170}
171
172bool llama_hparams::is_recurrent(uint32_t il) const {
173    if (il < n_layer) {
174        return recurrent_layer_arr[il];
175    }
176
177    GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
178}
179
180uint32_t llama_hparams::n_pos_per_embd() const {
181    return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1;
182}
183
184bool llama_hparams::is_swa(uint32_t il) const {
185    if (il < n_layer) {
186        return swa_layers[il];
187    }
188
189    GGML_ABORT("fatal error");
190}
191
192bool llama_hparams::is_mla() const {
193    assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
194           (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));
195
196    return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
197}
198
199uint32_t llama_hparams::n_embd_head_k_mla() const {
200    return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
201}
202
203uint32_t llama_hparams::n_embd_head_v_mla() const {
204    return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
205}
206
207bool llama_hparams::has_kv(uint32_t il) const {
208    if (n_layer_kv_from_start >= 0) {
209        if (il < (uint32_t) n_layer_kv_from_start) {
210            return true;
211        }
212
213        return false;
214    }
215
216    // by default, all layers have kv
217    return true;
218}
219
220uint32_t llama_hparams::n_layer_kv() const {
221    uint32_t res = 0;
222
223    for (uint32_t il = 0; il < n_layer; ++il) {
224        if (has_kv(il)) {
225            res++;
226        }
227    }
228
229    return res;
230}
231
232bool llama_hparams::use_mrope() const {
233    return rope_sections[0] > 0 && rope_sections[1] > 0;
234}