1#pragma once
  2
  3#include "ggml.h"
  4#include "clip.h"
  5#include "clip-impl.h"
  6
  7#include <array>
  8#include <vector>
  9#include <unordered_set>
 10#include <cstdint>
 11#include <cmath>
 12
 13enum ffn_op_type {
 14    FFN_GELU,
 15    FFN_GELU_ERF,
 16    FFN_SILU,
 17    FFN_GELU_QUICK,
 18};
 19
 20enum norm_type {
 21    NORM_TYPE_NORMAL,
 22    NORM_TYPE_RMS,
 23};
 24
 25enum patch_merge_type {
 26    PATCH_MERGE_FLAT,
 27    PATCH_MERGE_SPATIAL_UNPAD,
 28};
 29
 30struct clip_hparams {
 31    int32_t image_size = 0;
 32    int32_t patch_size = 0;
 33    int32_t n_embd = 0;
 34    int32_t n_ff = 0;
 35    int32_t projection_dim = 0;
 36    int32_t n_head = 0;
 37    int32_t n_layer = 0;
 38    // idefics3
 39    int32_t image_longest_edge = 0;
 40    int32_t image_min_pixels = -1;
 41    int32_t image_max_pixels = -1;
 42    int32_t n_merge = 0; // number of patch merges **per-side**
 43
 44    float image_mean[3];
 45    float image_std[3];
 46
 47    // for models using dynamic image size, we need to have a smaller image size to warmup
 48    // otherwise, user will get OOM everytime they load the model
 49    int32_t warmup_image_size = 0;
 50    int32_t warmup_audio_size = 3000;
 51
 52    ffn_op_type ffn_op = FFN_GELU;
 53
 54    patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 55
 56    float eps = 1e-6;
 57    float rope_theta = 0.0;
 58
 59    std::vector<clip_image_size> image_res_candidates; // for llava-uhd style models
 60    int32_t image_crop_resolution;
 61    std::unordered_set<int32_t> vision_feature_layer;
 62    int32_t attn_window_size = 0;
 63    int32_t n_wa_pattern = 0;
 64    std::unordered_set<int32_t> wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL)
 65
 66    // audio
 67    int32_t n_mel_bins = 0; // whisper preprocessor
 68    int32_t proj_stack_factor = 0; // ultravox
 69
 70    // audio-to-mel preprocessor params
 71    int32_t audio_chunk_len   = -1; // in seconds
 72    int32_t audio_sample_rate = -1;
 73    int32_t audio_n_fft       = -1;
 74    int32_t audio_window_len  = -1;
 75    int32_t audio_hop_len     = -1;
 76
 77    // legacy
 78    bool has_llava_projector = false;
 79    int minicpmv_version = 0;
 80    int32_t minicpmv_query_num = 0;         // MiniCPM-V query number
 81
 82    // custom value provided by user, can be undefined if not set
 83    int32_t custom_image_min_tokens = -1;
 84    int32_t custom_image_max_tokens = -1;
 85
 86    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
 87        const int cur_merge = n_merge == 0 ? 1 : n_merge;
 88        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
 89        image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
 90        image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
 91        warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
 92    }
 93
 94    void set_warmup_n_tokens(int n_tokens) {
 95        int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
 96        GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
 97        const int cur_merge = n_merge == 0 ? 1 : n_merge;
 98        warmup_image_size = n_tok_per_side * patch_size * cur_merge;
 99        // TODO: support warmup size for custom token numbers
100    }
101};
102
103struct clip_layer {
104    // attention
105    ggml_tensor * k_w = nullptr;
106    ggml_tensor * k_b = nullptr;
107    ggml_tensor * q_w = nullptr;
108    ggml_tensor * q_b = nullptr;
109    ggml_tensor * v_w = nullptr;
110    ggml_tensor * v_b = nullptr;
111    ggml_tensor * qkv_w = nullptr;
112    ggml_tensor * qkv_b = nullptr;
113
114    ggml_tensor * o_w = nullptr;
115    ggml_tensor * o_b = nullptr;
116
117    ggml_tensor * k_norm = nullptr;
118    ggml_tensor * q_norm = nullptr;
119
120    // layernorm 1
121    ggml_tensor * ln_1_w = nullptr;
122    ggml_tensor * ln_1_b = nullptr;
123
124    ggml_tensor * ff_up_w = nullptr;
125    ggml_tensor * ff_up_b = nullptr;
126    ggml_tensor * ff_gate_w = nullptr;
127    ggml_tensor * ff_gate_b = nullptr;
128    ggml_tensor * ff_down_w = nullptr;
129    ggml_tensor * ff_down_b = nullptr;
130
131    // layernorm 2
132    ggml_tensor * ln_2_w = nullptr;
133    ggml_tensor * ln_2_b = nullptr;
134
135    // layer scale (no bias)
136    ggml_tensor * ls_1_w = nullptr;
137    ggml_tensor * ls_2_w = nullptr;
138
139    // qwen3vl deepstack merger
140    ggml_tensor * deepstack_norm_w = nullptr;
141    ggml_tensor * deepstack_norm_b = nullptr;
142    ggml_tensor * deepstack_fc1_w = nullptr;
143    ggml_tensor * deepstack_fc1_b = nullptr;
144    ggml_tensor * deepstack_fc2_w = nullptr;
145    ggml_tensor * deepstack_fc2_b = nullptr;
146
147    // lfm2
148    ggml_tensor * ff_norm_w     = nullptr;
149    ggml_tensor * ff_norm_b     = nullptr;
150    ggml_tensor * ff_norm_1_w   = nullptr;
151    ggml_tensor * ff_norm_1_b   = nullptr;
152    ggml_tensor * ff_up_1_w     = nullptr;
153    ggml_tensor * ff_up_1_b     = nullptr;
154    ggml_tensor * ff_down_1_w   = nullptr;
155    ggml_tensor * ff_down_1_b   = nullptr;
156    ggml_tensor * pos_bias_u    = nullptr;
157    ggml_tensor * pos_bias_v    = nullptr;
158    ggml_tensor * norm_conv_w   = nullptr;
159    ggml_tensor * norm_conv_b   = nullptr;
160    ggml_tensor * linear_pos_w  = nullptr;
161
162    ggml_tensor * conv_norm_w   = nullptr;
163    ggml_tensor * conv_norm_b   = nullptr;
164    ggml_tensor * conv_dw_w     = nullptr;
165    ggml_tensor * conv_dw_b     = nullptr;
166    ggml_tensor * conv_pw1_w    = nullptr;
167    ggml_tensor * conv_pw1_b    = nullptr;
168    ggml_tensor * conv_pw2_w    = nullptr;
169    ggml_tensor * conv_pw2_b    = nullptr;
170
171    bool has_deepstack() const {
172        return deepstack_fc1_w != nullptr;
173    }
174};
175
176// Expanded MobileNetV5 block structure for Gemma3n vision encoder
177struct mobilenetv5_block {
178    // Stage 0 (Edge Residual)
179    ggml_tensor * s0_conv_exp_w = nullptr;
180    ggml_tensor * s0_bn1_w      = nullptr;
181    ggml_tensor * s0_conv_pwl_w = nullptr;
182    ggml_tensor * s0_bn2_w      = nullptr;
183
184    // Stage 1+ (Universal Inverted Residual)
185    ggml_tensor * dw_start_w    = nullptr;
186    ggml_tensor * dw_start_bn_w = nullptr;
187
188    ggml_tensor * pw_exp_w      = nullptr;
189    ggml_tensor * pw_exp_bn_w   = nullptr;
190
191    ggml_tensor * dw_mid_w      = nullptr;
192    ggml_tensor * dw_mid_bn_w   = nullptr;
193
194    ggml_tensor * pw_proj_w     = nullptr;
195    ggml_tensor * pw_proj_bn_w  = nullptr;
196
197    ggml_tensor * layer_scale_w = nullptr;
198
199    // Attention (MQA) components
200    ggml_tensor * attn_q_w = nullptr;
201    ggml_tensor * attn_k_w = nullptr;
202    ggml_tensor * attn_v_w = nullptr;
203    ggml_tensor * attn_o_w = nullptr;
204
205    // Optional downsampling/norm in attention
206    ggml_tensor * attn_k_dw_w   = nullptr;
207    ggml_tensor * attn_k_norm_w = nullptr;
208    ggml_tensor * attn_v_dw_w   = nullptr;
209    ggml_tensor * attn_v_norm_w = nullptr;
210
211    // Block norm (often present in attention blocks)
212    ggml_tensor * attn_norm_w   = nullptr;
213};
214
215struct clip_model {
216    clip_modality modality = CLIP_MODALITY_VISION;
217    projector_type proj_type = PROJECTOR_TYPE_MLP;
218    clip_hparams hparams;
219
220    // embeddings
221    ggml_tensor * class_embedding = nullptr;
222    ggml_tensor * patch_embeddings_0 = nullptr;
223    ggml_tensor * patch_embeddings_1 = nullptr;  // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
224    ggml_tensor * patch_bias = nullptr;
225    ggml_tensor * position_embeddings = nullptr;
226    ggml_tensor * norm_embd_w = nullptr;
227    ggml_tensor * norm_embd_b = nullptr;
228
229    ggml_tensor * pre_ln_w = nullptr;
230    ggml_tensor * pre_ln_b = nullptr;
231
232    std::vector<clip_layer> layers;
233
234    int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer
235
236    ggml_tensor * post_ln_w;
237    ggml_tensor * post_ln_b;
238
239    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
240    ggml_tensor * mm_fc_w;
241    ggml_tensor * mm_fc_b;
242    ggml_tensor * mm_ffn_up_w = nullptr;
243    ggml_tensor * mm_ffn_up_b = nullptr;
244    ggml_tensor * mm_ffn_gate_w = nullptr;
245    ggml_tensor * mm_ffn_gate_b = nullptr;
246    ggml_tensor * mm_ffn_down_w = nullptr;
247    ggml_tensor * mm_ffn_down_b = nullptr;
248    ggml_tensor * mm_post_norm_w = nullptr;
249    ggml_tensor * mm_post_norm_b = nullptr;
250
251    // LLaVA projection
252    ggml_tensor * mm_input_norm_w = nullptr;
253    ggml_tensor * mm_input_norm_b = nullptr;
254    ggml_tensor * mm_0_w = nullptr;
255    ggml_tensor * mm_0_b = nullptr;
256    ggml_tensor * mm_2_w = nullptr;
257    ggml_tensor * mm_2_b = nullptr;
258
259    ggml_tensor * image_newline = nullptr;
260
261    // Yi type models with mlp+normalization projection
262    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
263    ggml_tensor * mm_1_b = nullptr;
264    ggml_tensor * mm_3_w = nullptr;
265    ggml_tensor * mm_3_b = nullptr;
266    ggml_tensor * mm_4_w = nullptr;
267    ggml_tensor * mm_4_b = nullptr;
268
269    // GLMV-Edge projection
270    ggml_tensor * mm_model_adapter_conv_w = nullptr;
271    ggml_tensor * mm_model_adapter_conv_b = nullptr;
272
273    // MobileVLM projection
274    ggml_tensor * mm_model_mlp_1_w = nullptr;
275    ggml_tensor * mm_model_mlp_1_b = nullptr;
276    ggml_tensor * mm_model_mlp_3_w = nullptr;
277    ggml_tensor * mm_model_mlp_3_b = nullptr;
278    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
279    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
280    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
281    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
282    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
283    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
284    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
285    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
286    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
287    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
288    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
289    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
290    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
291    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
292    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
293    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
294    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
295    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
296    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
297    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
298
299    // MobileVLM_V2 projection
300    ggml_tensor * mm_model_mlp_0_w = nullptr;
301    ggml_tensor * mm_model_mlp_0_b = nullptr;
302    ggml_tensor * mm_model_mlp_2_w = nullptr;
303    ggml_tensor * mm_model_mlp_2_b = nullptr;
304    ggml_tensor * mm_model_peg_0_w = nullptr;
305    ggml_tensor * mm_model_peg_0_b = nullptr;
306
307    // MINICPMV projection
308    ggml_tensor * mm_model_pos_embed_k = nullptr;
309    ggml_tensor * mm_model_query = nullptr;
310    ggml_tensor * mm_model_proj = nullptr;
311    ggml_tensor * mm_model_kv_proj = nullptr;
312    ggml_tensor * mm_model_attn_q_w = nullptr;
313    ggml_tensor * mm_model_attn_q_b = nullptr;
314    ggml_tensor * mm_model_attn_k_w = nullptr;
315    ggml_tensor * mm_model_attn_k_b = nullptr;
316    ggml_tensor * mm_model_attn_v_w = nullptr;
317    ggml_tensor * mm_model_attn_v_b = nullptr;
318    ggml_tensor * mm_model_attn_o_w = nullptr;
319    ggml_tensor * mm_model_attn_o_b = nullptr;
320    ggml_tensor * mm_model_ln_q_w = nullptr;
321    ggml_tensor * mm_model_ln_q_b = nullptr;
322    ggml_tensor * mm_model_ln_kv_w = nullptr;
323    ggml_tensor * mm_model_ln_kv_b = nullptr;
324    ggml_tensor * mm_model_ln_post_w = nullptr;
325    ggml_tensor * mm_model_ln_post_b = nullptr;
326
327    // gemma3
328    ggml_tensor * mm_input_proj_w = nullptr;
329    ggml_tensor * mm_soft_emb_norm_w = nullptr;
330
331    // mobilenetv5 for gemma3n
332    std::vector<mobilenetv5_block> mobilenet_blocks;
333    std::vector<int> mobilenet_stage_ends;
334    ggml_tensor * mobilenet_stem_conv_w = nullptr;
335    ggml_tensor * mobilenet_stem_conv_b = nullptr;
336    ggml_tensor * mobilenet_stem_norm_w = nullptr;
337    ggml_tensor * mm_post_proj_norm_w = nullptr;
338
339    // Multi-Scale Fusion Adapter (MSFA) components
340    ggml_tensor * msfa_concat_conv_w = nullptr;
341    ggml_tensor * msfa_concat_norm_w = nullptr;
342    ggml_tensor * msfa_ffn_expand_w = nullptr;
343    ggml_tensor * msfa_ffn_project_w = nullptr;
344    ggml_tensor * msfa_ffn_expand_bn = nullptr;
345    ggml_tensor * msfa_ffn_project_bn = nullptr;
346
347
348    // pixtral, glm4v
349    ggml_tensor * token_embd_img_break = nullptr;
350    ggml_tensor * mm_patch_merger_w = nullptr;
351    ggml_tensor * mm_patch_merger_b = nullptr;
352
353    // ultravox / whisper encoder
354    ggml_tensor * conv1d_1_w = nullptr;
355    ggml_tensor * conv1d_1_b = nullptr;
356    ggml_tensor * conv1d_2_w = nullptr;
357    ggml_tensor * conv1d_2_b = nullptr;
358    ggml_tensor * mm_norm_pre_w = nullptr;
359    ggml_tensor * mm_norm_pre_b = nullptr;
360    ggml_tensor * mm_norm_mid_w = nullptr;
361
362    // cogvlm
363    ggml_tensor * mm_post_fc_norm_w = nullptr;
364    ggml_tensor * mm_post_fc_norm_b = nullptr;
365    ggml_tensor * mm_h_to_4h_w = nullptr;
366    ggml_tensor * mm_gate_w = nullptr;
367    ggml_tensor * mm_4h_to_h_w = nullptr;
368    ggml_tensor * mm_boi = nullptr;
369    ggml_tensor * mm_eoi = nullptr;
370
371    // lfm2 audio
372    std::array<ggml_tensor *, 7> pre_encode_conv_X_w = {nullptr};
373    std::array<ggml_tensor *, 7> pre_encode_conv_X_b = {nullptr};
374    ggml_tensor * pre_encode_out_w = nullptr;
375    ggml_tensor * pre_encode_out_b = nullptr;
376
377    bool audio_has_avgpool() const {
378        return proj_type == PROJECTOR_TYPE_QWEN2A
379            || proj_type == PROJECTOR_TYPE_VOXTRAL
380            || proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO;
381    }
382
383    bool audio_has_stack_frames() const {
384        return proj_type == PROJECTOR_TYPE_ULTRAVOX
385            || proj_type == PROJECTOR_TYPE_VOXTRAL;
386    }
387};
388
// Accessor for the hyper-parameters associated with a clip context.
// Declared here only; the definition lives elsewhere. Presumably the returned pointer
// refers to data owned by `ctx` and must not outlive it — TODO confirm at the definition.
const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx);