1#include "models.h"
  2
  3llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
  4
// Builds the compute graph for one Mamba (v1) SSM layer.
//
// Outline of what the body does:
//   1. gathers this ubatch's cached conv states via build_rs,
//   2. projects the input with ssm_in and splits it into x and the gate z,
//   3. prepends the cached conv state to x, writes the last (d_conv - 1) columns
//      back into the conv-state cache, then applies ggml_ssm_conv + bias + SiLU,
//   4. projects x with ssm_x into dt / B / C (optionally RMS-normed), expands dt
//      with ssm_dt + ssm_dt_b, and runs ggml_ssm_scan against the cached states,
//   5. writes the final SSM states back into the cache, adds the D skip term,
//      gates with z via ggml_swiglu_split, and projects with ssm_out.
//
// inp:    recurrent-state input; inp->mctx supplies the cache head, the per-layer
//         conv/SSM state tensors and the cache size.
// cur:    layer input {n_embd, n_tokens}; reshaped to {n_embd, n_seq_tokens, n_seqs}.
// model:  source of layer il's weights.
// ubatch: must consist of equal-length sequences (asserted), with
//         n_tokens == n_seq_tokens * n_seqs (asserted).
// il:     layer index.
// Returns the layer output as {n_embd, n_tokens}. Side effect: the conv-state and
// SSM-state write-back copies are added to gf via ggml_build_forward_expand.
ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         const llama_model &  model,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) {
    const auto * mctx_cur = inp->mctx;

    // first cache slot for this ubatch; both state write-back views below are offset by it
    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    const int64_t d_conv         = hparams.ssm_d_conv;
    const int64_t d_inner        = hparams.ssm_d_inner;
    const int64_t d_state        = hparams.ssm_d_state;
    const int64_t dt_rank        = hparams.ssm_dt_rank;
    // Mamba1 is expressed as d_inner heads of size 1, so that x and the states can be
    // reshaped into the {head_dim, n_head, ...} layout handed to ggml_ssm_scan
    const int64_t n_head         = d_inner;
    const int64_t head_dim       = 1;
    const int64_t n_seqs         = ubatch.n_seqs;
    // Some Mamba variants (e.g. FalconMamba) apply RMS norm to dt, B and C (see the ssm block below)
    const bool    ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // this ubatch's cached conv states => {d_conv - 1, d_inner, n_seqs}
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
    // split the above in two: x is the first d_inner rows, z (the gate) the next d_inner
    // => {d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * x  = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
    ggml_tensor * z =
        ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);

        // copy last (d_conv - 1) columns back into the state cache
        // (skip the first n_seq_tokens columns of conv_x; the write starts at slot kv_head)
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, last_conv,
                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
                                      kv_head * (d_conv - 1) * (d_inner) *ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);

        // bias
        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);

        x = ggml_silu(ctx0, x);
    }

    // ssm
    {
        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
        // split: rows [0, dt_rank) are dt, then d_state rows of B, then d_state rows of C
        ggml_tensor * dt   = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
        ggml_tensor * B =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * dt_rank);
        ggml_tensor * C =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));

        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
        // (enabled either by the hparams flag or by the presence of all three norm weights)
        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
            B  = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
            C  = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
        }

        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(layer.ssm_dt, dt);
        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);

        // keep the 3D post-conv x for the D skip term added after the scan
        cur = x;
        x   = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);

        ggml_tensor * A = layer.ssm_a;

        // use the states and the indices that build_rs passes to this callback
        // (this is necessary in order to properly use the states before they are overwritten,
        //  while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            // NOTE(review): get_size() is presumably the number of state slots in the cache — confirm
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in the Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        // y_ssm holds the scan outputs first (as many elements as x) followed by the final
        // states; the source view starts right after the outputs (x->nb[3] * x->ne[3] bytes,
        // x being contiguous here). NOTE(review): assumes y_ssm and x share an element size.
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        // scan outputs => {d_inner, n_seq_tokens, n_seqs}
        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

        // TODO: skip computing output earlier for unused tokens

        // y += x * D, then gate with z
        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(layer.ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);

    return cur;
}
145
// Builds the compute graph for one Mamba2 SSM layer.
//
// Outline of what the body does:
//   1. gathers this ubatch's cached conv states via build_rs,
//   2. projects the input with ssm_in into z | xBC | dt,
//   3. prepends the cached conv state to xBC, writes the last (d_conv - 1) columns
//      back into the conv-state cache, then applies ggml_ssm_conv + bias + SiLU,
//   4. splits xBC into x (per-head), B and C (per-group), adds ssm_dt_b to dt, and
//      runs ggml_ssm_scan against the cached SSM states,
//   5. writes the final SSM states back into the cache, adds the D skip term,
//      gates with z via ggml_swiglu_split, optionally applies a grouped RMS norm,
//      and projects with ssm_out.
//
// inp:    recurrent-state input; inp->mctx supplies the cache head, the per-layer
//         conv/SSM state tensors and the cache size.
// cur:    layer input {n_embd, n_tokens}; reshaped to {n_embd, n_seq_tokens, n_seqs}.
// model:  source of layer il's weights.
// ubatch: must consist of equal-length sequences (asserted), with
//         n_tokens == n_seq_tokens * n_seqs (asserted).
// il:     layer index.
// Returns the layer output as {n_embd, n_tokens} (named "mamba_out" via cb). Side
// effect: the conv-state and SSM-state write-back copies are added to gf.
ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
                                                          ggml_tensor *        cur,
                                                          const llama_model &  model,
                                                          const llama_ubatch & ubatch,
                                                          int                  il) const {
    const auto * mctx_cur = inp->mctx;

    // first cache slot for this ubatch; both state write-back views below are offset by it
    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv   = hparams.ssm_d_conv;
    const int64_t d_inner  = hparams.ssm_d_inner;
    const int64_t d_state  = hparams.ssm_d_state;
    // in Mamba2 the ssm_dt_rank hparam holds the number of heads
    const int64_t n_head   = hparams.ssm_dt_rank;
    // NOTE(review): assumes d_inner is divisible by n_head (and by n_group, see the
    // grouped RMS norm below); not checked here
    const int64_t head_dim = d_inner / n_head;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // this ubatch's cached conv states => {d_conv - 1, d_inner + 2*n_group*d_state, n_seqs}
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads

    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);

    // split the above in three: z (d_inner rows, viewed per head), xBC (the next
    // d_inner + 2*n_group*d_state rows, convolved together), dt (the last n_head rows)
    ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
                                     zxBCdt->nb[1], zxBCdt->nb[2], 0);
    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
                                     zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
    ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
                                     (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));

    // conv
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);

        // copy last (d_conv - 1) columns back into the state cache
        // (skip the first n_seq_tokens columns of conv_x; the write starts at slot kv_head)
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);

        // bias
        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);

        xBC = ggml_silu(ctx0, xBC);
    }

    // ssm
    {
        // These correspond to V K Q in SSM/attention duality
        // x: first d_inner rows of xBC, per head; B and C: the following n_group*d_state rows each
        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], 0);
        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));

        // {n_head, n_seq_tokens, n_seqs}
        // (dt is a strided view into zxBCdt, so it is made contiguous before the bias add)
        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);

        ggml_tensor * A = model.layers[il].ssm_a;

        // use the states and the indices that build_rs passes to this callback
        // (this is necessary in order to properly use the states before they are overwritten,
        //  while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            // NOTE(review): get_size() is presumably the number of state slots in the cache — confirm
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // TODO: use semistructured matrices to implement state-space duality
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        // y_ssm holds the scan outputs first (as many elements as x) followed by the final
        // states. x is a strided view of xBC (it excludes B and C), so its byte span would
        // overshoot; ggml_nelements(x) * x->nb[0] is the size of the outputs instead.
        // NOTE(review): assumes y_ssm and x share an element size.
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        // scan outputs as a dense {head_dim, n_head, n_seq_tokens, n_seqs} tensor
        // (x->nb[1] == head_dim * element size, from the x view above)
        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
                                       n_seq_tokens * n_head * x->nb[1], 0);

        // TODO: skip computing output earlier for unused tokens

        // y += x * D, then gate with z
        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
        cb(y, "mamba2_y_add_d", il);
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // grouped RMS norm
        // (reshaping to {d_inner / n_group, n_group, ...} makes the RMS norm, which acts
        //  along the first dimension, apply to each group's slice separately)
        if (model.layers[il].ssm_norm) {
            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
        }

        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(model.layers[il].ssm_out, y);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}