#include "models.h"

// Helpers for MobileNetV5 Blocks
// RMS Norm 2D - normalizes over channels for each spatial position
ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) {
    // inp: [W, H, C, B]

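    // ggml_rms_norm normalizes over ne[0] (the innermost dim), but in this
    // conv activation layout the channels live in ne[2], so we move C into
    // dim 0, normalize, then permute back.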
    ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); // -> [C, H, W, B]
    cur = ggml_cont(ctx0, cur);
    cur = ggml_rms_norm(ctx0, cur, eps);

    if (weight) {
        cur = ggml_mul(ctx0, cur, weight);
    }

    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [W, H, C, B]
    cur = ggml_cont(ctx0, cur);

    return cur;
}

// Conv2dSame padding - asymmetric SAME padding, as in TF and timm's Conv2dSame
ggml_tensor * clip_graph_mobilenetv5::pad_same_2d(ggml_tensor * inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) {
    const int64_t ih = inp->ne[1];  // height
    const int64_t iw = inp->ne[0];  // width

    // Calculate output size (ceil division)
    const int64_t oh = (ih + stride_h - 1) / stride_h;
    const int64_t ow = (iw + stride_w - 1) / stride_w;

    // Calculate padding needed
    const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih);
    const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw);
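
    // Worked example: iw = 12, stride = 2, kernel = 3, dilation = 1:
    //   ow    = ceil(12 / 2) = 6
    //   pad_w = (6 - 1) * 2 + (3 - 1) * 1 + 1 - 12 = 1  -> left 0, right 1
    // An odd total is what makes the split asymmetric: the extra pixel goes
    // bottom/right, matching TF "SAME" semantics.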
    // Split padding asymmetrically
    const int pad_h_top    = pad_h / 2;
    const int pad_h_bottom = pad_h - pad_h_top;
    const int pad_w_left   = pad_w / 2;
    const int pad_w_right  = pad_w - pad_w_left;

    // Apply padding if needed
    // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
    // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch
    if (pad_h > 0 || pad_w > 0) {
        inp = ggml_pad_ext(ctx0, inp,
            pad_w_left, pad_w_right,      // width padding (dim 0)
            pad_h_top, pad_h_bottom,      // height padding (dim 1)
            0, 0,                         // no channel padding (dim 2)
            0, 0);                        // no batch padding (dim 3)
    }

    return inp;
}

// Edge Residual Block (Stage 0)
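// Corresponds to timm's EdgeResidual (a.k.a. FusedMBConv): the usual
// 1x1-expansion + depthwise pair is fused into a single full 3x3 convolution,
// followed by a 1x1 pointwise-linear projection.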
ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
    ggml_tensor * cur = inp;

    // 1. Expansion Conv (3x3)
    if (stride == 2) {
        // Case: Downsampling (Block 0)
        // Replicates Conv2dSame(kernel=3, stride=2)
        cur = pad_same_2d(cur, 3, 3, stride, stride);
        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1);
    } else {
        // Case: Normal 3x3 Block (Block 1, 2)
        // Replicates Conv2d(kernel=3, stride=1, padding=1)
        cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1);
    }

    // BN + Activation
    if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w);
    cur = ggml_gelu(ctx0, cur);

    // 2. Pointwise Linear Conv (1x1)
    // 1x1 convs use padding=0 and stride=1
    cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1);
    if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w);

    // 3. Residual Connection
    // Only valid when spatial dimensions and channels match (stride 1)
    if (stride == 1 && inp->ne[0] == cur->ne[0] && inp->ne[1] == cur->ne[1] && inp->ne[2] == cur->ne[2]) {
        cur = ggml_add(ctx0, cur, inp);
    }

    return cur;
}

// Universal Inverted Residual Block (Stage 1+)
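// Mirrors timm's UniversalInvertedResidual (MobileNetV4): each sub-block
// (dw_start, pw_exp, dw_mid, pw_proj) is optional and keyed off whether the
// corresponding weight tensor was loaded.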
ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
    ggml_tensor * cur = inp;

    // 1. Depthwise Start (Optional)
    // NOTE: dw_start always has stride=1 (no downsampling here)
    if (block.dw_start_w) {
        int k = block.dw_start_w->ne[0]; // 3 or 5
        int p = k / 2;
        cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1);
        if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w);
    }

    // 2. Pointwise Expansion (1x1)
    if (block.pw_exp_w) {
        // Standard 1x1 conv, pad=0, stride=1
        cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1);
        if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w);
        cur = ggml_gelu(ctx0, cur);
    }

    // 3. Depthwise Mid (Optional)
    // NOTE: dw_mid is where downsampling happens (stride=2 for the first block of a stage)
    if (block.dw_mid_w) {
        int k = block.dw_mid_w->ne[0]; // 3 or 5

        if (stride > 1) {
            // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding
            cur = pad_same_2d(cur, k, k, stride, stride);
            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0
        } else {
            // Case: Stride 1 -> Use Standard Symmetric Padding
            int p = k / 2;
            cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1);
        }

        if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w);
        cur = ggml_gelu(ctx0, cur);
    }

    // 4. Pointwise Projection (1x1)
    if (block.pw_proj_w) {
        cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1);
        if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w);
    }

    // Apply Layer Scaling if present
    if (block.layer_scale_w) {
        cur = ggml_mul(ctx0, cur, block.layer_scale_w);
    }

    // 5. Residual Connection
    bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]);
    bool same_channel = (inp->ne[2] == cur->ne[2]);
    if (same_spatial && same_channel) {
        cur = ggml_add(ctx0, cur, inp);
    }

    return cur;
}

// Attention Block (MQA)
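// Multi-Query Attention: all query heads share a single K/V head, and K/V are
// additionally downsampled 2x spatially by strided depthwise convs, so the
// attention matrix is computed against a quarter of the positions.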
ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) {
    ggml_tensor * cur = inp;

    // Norm
    if (block.attn_norm_w) {
        cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f);
    }

    // 1. Q Calculation
    ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1);

    // 2. K Calculation (Downsampled)
    // Replicates Conv2dSame(e.g. 640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
    ggml_tensor * k_inp = cur;
    if (block.attn_k_dw_w) {
        int k_size = block.attn_k_dw_w->ne[0];  // Usually 3
        k_inp = pad_same_2d(cur, k_size, k_size, 2, 2);  // Apply SAME padding
        k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1);  // padding=0
        if (block.attn_k_norm_w) {
            k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f);
        }
    }
    ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1);

    // 3. V Calculation (Downsampled)
    // Replicates Conv2dSame(e.g. 640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
    ggml_tensor * v_inp = cur;
    if (block.attn_v_dw_w) {
        int v_size = block.attn_v_dw_w->ne[0];  // Usually 3
        v_inp = pad_same_2d(cur, v_size, v_size, 2, 2);  // Apply SAME padding
        v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1);  // padding=0
        if (block.attn_v_norm_w) {
            v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f);
        }
    }
    ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1);

    const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3];
    const int D = k->ne[2]; // Head dimension
    const int n_head = q->ne[2] / D;
    const int N = W * H;

    // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B]
    q = ggml_reshape_4d(ctx0, q, N, D, n_head, B);
    q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B]
    q = ggml_cont(ctx0, q);

    const int Wk = k->ne[0]; const int Hk = k->ne[1];
    const int M = Wk * Hk;

    // Process K: [Wk, Hk, D, B] -> [D, M, 1, B]
    k = ggml_reshape_4d(ctx0, k, M, D, 1, B);
    k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B]
    k = ggml_cont(ctx0, k);

    // Process V: [Wk, Hk, D, B] -> [M, D, 1, B]
    v = ggml_reshape_4d(ctx0, v, M, D, 1, B);
    v = ggml_cont(ctx0, v); // [M, D, 1, B]

    // Multi-Query Attention
    float scale = 1.0f / sqrtf((float)D);

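    // ggml_mul_mat(A, B) contracts over ne[0] (requires A->ne[0] == B->ne[0])
    // and yields [A->ne[1], B->ne[1], B->ne[2], B->ne[3]], broadcasting A over
    // the higher dims. Here the single K/V head (ne[2] = 1) is broadcast
    // across all n_head query heads - the "multi-query" part:
    //   scores = k x q      : [D, M, 1, B] x [D, N, n_head, B] -> [M, N, n_head, B]
    //   kqv    = v x scores : [M, D, 1, B] x [M, N, n_head, B] -> [D, N, n_head, B]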
    // Step 1: Compute Q @ K^T
    ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); // [M, N, n_head, B]

    scores = ggml_scale(ctx0, scores, scale);

    // Step 2: Softmax over the key dimension (ne[0] = M)
    scores = ggml_soft_max(ctx0, scores);

    // Step 3: Weighted sum of values
    ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); // [D, N, n_head, B]

    kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); // [N, D, n_head, B]
    kqv = ggml_cont(ctx0, kqv);

    // Back to spatial layout: [N, D, n_head, B] -> [W, H, D*n_head, B]
    kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B);

    // Output projection
    cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1);

    // Residual & Layer Scale
    if (inp->ne[0] == cur->ne[0] && inp->ne[1] == cur->ne[1] && inp->ne[2] == cur->ne[2]) {
        if (block.layer_scale_w) {
            cur = ggml_mul(ctx0, cur, block.layer_scale_w);
        }
        cur = ggml_add(ctx0, cur, inp);
    }

    return cur;
}

ggml_cgraph * clip_graph_mobilenetv5::build() {
    ggml_tensor * inp = build_inp_raw();

    // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2))
    ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2);  // Apply SAME padding

    cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1);  // padding=0
    if (model.mobilenet_stem_conv_b) {
        cur = ggml_add(ctx0, cur, model.mobilenet_stem_conv_b);
    }
    if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w);
    cur = ggml_gelu(ctx0, cur);

    // 2. Blocks
    std::vector<ggml_tensor *> intermediate_features;
    const int total_blocks = (int) model.mobilenet_blocks.size();

    // A stage starts at block 0 and right after each recorded stage end
    auto is_stage_start = [&](int i) {
        if (i == 0) return true;
        for (int end_idx : model.mobilenet_stage_ends) {
            if (i == end_idx + 1) return true;
        }
        return false;
    };

    // Features are tapped at the ends of stages 2 and 3 for the MSFA
    auto is_fusion_point = [&](int i) {
        if (model.mobilenet_stage_ends.size() >= 4) {
            if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2
            if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3
        } else {
            if (i == total_blocks - 1) return true;
        }
        return false;
    };

    for (int i = 0; i < total_blocks; i++) {
        const auto & block = model.mobilenet_blocks[i];
        int stride = is_stage_start(i) ? 2 : 1;

        if (block.s0_conv_exp_w)      cur = build_edge_residual(cur, block, stride);
        else if (block.attn_q_w)      cur = build_mobilenet_attn(cur, block);
        else                          cur = build_inverted_residual(cur, block, stride);

        if (is_fusion_point(i)) {
            intermediate_features.push_back(cur);
        }
    }

    // 3. Multi-Scale Fusion Adapter (MSFA)
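    // Concatenates the tapped feature maps (typically stride 16 and 32) at the
    // higher resolution, mixes them with a pointwise FFN, then average-pools
    // down to the fixed output grid.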
    if (!intermediate_features.empty()) {
        // A. Reference Resolution: the PyTorch implementation uses inputs[0].
        // We assume intermediate_features[0] is the "High Resolution" target;
        // among the tapped features this is the one with the smallest stride
        // (e.g. 32x32 spatial).
        ggml_tensor * target_feat = intermediate_features[0];
        int high_res_w = target_feat->ne[0];
        int high_res_h = target_feat->ne[1];

        std::vector<ggml_tensor *> resized_feats;

        // B. Resize inputs to match inputs[0] (High Resolution)
        for (auto feat : intermediate_features) {
            int feat_w = feat->ne[0];
            int feat_h = feat->ne[1];

            // PyTorch: if feat_size < high_resolution: interpolate
            if (feat_w < high_res_w || feat_h < high_res_h) {
                // PyTorch 'nearest' interpolation accepts arbitrary float
                // scales; ggml_upscale scales both W and H by one integer
                // factor, so we assume uniform power-of-2 scaling
                // (e.g. 16 -> 32 means scale = 2).
                int scale_w = high_res_w / feat_w;

                // Safety check for non-integer / non-uniform scaling
                GGML_ASSERT(high_res_w % feat_w == 0);
                GGML_ASSERT(high_res_h % feat_h == 0);

                // Upsample (Nearest Neighbor)
                feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST);
            }
            resized_feats.push_back(feat);
        }

        // C. Concatenate at High Resolution (Channel Dim = 2 in ggml)
        cur = resized_feats[0];
        for (size_t k = 1; k < resized_feats.size(); ++k) {
            cur = ggml_concat(ctx0, cur, resized_feats[k], 2);
        }

        // D. FFN (UniversalInvertedResidual)
        // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm

        // 1. Expansion
        if (model.msfa_ffn_expand_w) {
            // 1x1 Conv
            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1);
            if (model.msfa_ffn_expand_bn) {
                cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn);
            }
            cur = ggml_gelu(ctx0, cur);
        }

        // 2. Projection (no depthwise stage: kernel_size=0 in this config)
        if (model.msfa_ffn_project_w) {
            // 1x1 Conv
            cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1);

            // UniversalInvertedResidual typically has a norm after projection
            if (model.msfa_ffn_project_bn) {
                cur = rms_norm_2d(cur, model.msfa_ffn_project_bn);
            }
        }

        // E. Final Downsample to Target Resolution (Output Resolution)
        // PyTorch: matches self.output_resolution (e.g. 16x16)
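        // A 16x16 grid yields 256 visual tokens per image, which is what the
        // Gemma 3n text model expects (assumption based on the published
        // Gemma 3n config; adjust if output_resolution differs).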
        const int target_out_res = 16;
        int current_w = cur->ne[0];

        if (current_w > target_out_res) {
            int s = current_w / target_out_res;

            GGML_ASSERT(current_w % target_out_res == 0);

            // Avg Pool: Kernel=s, Stride=s (e.g. 32 -> 16 uses a 2x2 window)
            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0);
        }

        // F. Final Norm
        if (model.msfa_concat_norm_w) {
            cur = rms_norm_2d(cur, model.msfa_concat_norm_w);
        }
    }

    // 4. Gemma 3n Multimodal Projection (Embedder)
    // Input: 'cur' is [Width, Height, Channels, Batch]
    int W = cur->ne[0];
    int H = cur->ne[1];
    int C = cur->ne[2];
    int B = cur->ne[3];

    GGML_ASSERT(C == hparams.n_embd);

    // 1. Permute and Flatten to [Channels, Tokens, Batch]
    // PyTorch expects (Batch, Seq, Hidden); ggml processes (Hidden, Seq, Batch)
    cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B]
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B]
    cur = ggml_cont(ctx0, cur);
    cur = ggml_reshape_3d(ctx0, cur, C, W*H, B);

    // 2. FEATURE SCALING
    // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5
    const float scale_factor = sqrtf((float)C);
    cur = ggml_scale(ctx0, cur, scale_factor);
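    // e.g. with C = n_embd = 2048 this scales features by sqrt(2048) ~= 45.25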

    // 3. SOFT EMBEDDING NORM
    // PyTorch: self._norm(x) * self.weight
    // We must normalize regardless, then multiply if the weight exists.
    {
        const float eps = 1e-6f; // Gemma3n uses 1e-6
        cur = ggml_rms_norm(ctx0, cur, eps);

        if (model.mm_soft_emb_norm_w) {
            // Weight shape is (2048,) -> element-wise broadcast multiply
            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
        }
    }

    // 4. PROJECTION
    // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
    // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
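    // In ggml this weight has ne[0] = in_features, and ggml_mul_mat contracts
    // over ne[0], so cur goes [vision_hidden, tokens, B] -> [text_hidden, tokens, B].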
    if (model.mm_input_proj_w) {
        cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
    }

    // 5. POST PROJECTION NORM
    // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False)
    // with_scale=False means the weight is registered as a buffer with value 1.0,
    // so output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1
    {
        const float eps = 1e-6f;
        cur = ggml_rms_norm(ctx0, cur, eps);

        if (model.mm_post_proj_norm_w) {
            // If the weight is loaded, multiply (should be ~1.0 anyway)
            cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w);
        }
    }

    ggml_build_forward_expand(gf, cur);
    return gf;
}