1#include "models.h"
2
3// Helpers for MobileNetV5 Blocks
4// RMS Norm 2D - normalizes over channels for each spatial position
5ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) {
6 // inp: [W, H, C, B]
7
8 ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3);
9 cur = ggml_cont(ctx0, cur);
10 cur = ggml_rms_norm(ctx0, cur, eps);
11
12 if (weight) {
13 cur = ggml_mul(ctx0, cur, weight);
14 }
15
16 cur = ggml_permute(ctx0, cur, 2, 1, 0, 3);
17 cur = ggml_cont(ctx0, cur);
18
19 return cur;
20}
21
22// Conv2dSame padding - asymmetric SAME padding like PyTorch/TF
23ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) {
24 const int64_t ih = inp->ne[1]; // height
25 const int64_t iw = inp->ne[0]; // width
26
27 // Calculate output size (ceil division)
28 const int64_t oh = (ih + stride_h - 1) / stride_h;
29 const int64_t ow = (iw + stride_w - 1) / stride_w;
30
31 // Calculate padding needed
32 const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih);
33 const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw);
34
35 // Split padding asymmetrically
36 const int pad_h_top = pad_h / 2;
37 const int pad_h_bottom = pad_h - pad_h_top;
38 const int pad_w_left = pad_w / 2;
39 const int pad_w_right = pad_w - pad_w_left;
40
41 // Apply padding if needed
42 // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
43 // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch
44 if (pad_h > 0 || pad_w > 0) {
45 inp = ggml_pad_ext(ctx0, inp,
46 pad_w_left, pad_w_right, // width padding (dim 0)
47 pad_h_top, pad_h_bottom, // height padding (dim 1)
48 0, 0, // no channel padding (dim 2)
49 0, 0); // no batch padding (dim 3)
50 }
51
52 return inp;
53}
54
55
56// Edge Residual Block (Stage 0)
57ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
58 ggml_tensor * cur = inp;
59
60 // 1. Expansion Conv (3x3)
61 if (stride == 2) {
62 // Case: Downsampling (Block 0)
63 // Replicates Conv2dSame(kernel=3, stride=2)
64 cur = pad_same_2d(cur, 3, 3, stride, stride);
65 cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1);
66 } else {
67 // Case: Normal 3x3 Block (Block 1, 2)
68 // Replicates Conv2d(kernel=3, stride=1, padding=1)
69 cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1);
70 }
71
72 // BN + Activation
73 if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w);
74 cur = ggml_gelu(ctx0, cur);
75
76 // 2. Pointwise Linear Conv (1x1)
77 // 1x1 Convs usually have padding=0 and stride=1
78 cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1);
79 if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w);
80
81 // 3. Residual Connection
82 // Only apply residual if spatial dimensions and channels match (stride 1)
83 if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) {
84 cur = ggml_add(ctx0, cur, inp);
85 }
86
87 return cur;
88}
89
90// Universal Inverted Residual Block (Stage 1+)
91ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) {
92 ggml_tensor * cur = inp;
93
94 // 1. Depthwise Start (Optional)
95 // NOTE: dw_start always has stride=1 (no downsampling here)
96 if (block.dw_start_w) {
97 int k = block.dw_start_w->ne[0]; // 3 or 5
98 int p = k / 2;
99 cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1);
100 if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w);
101 }
102
103 // 2. Pointwise Expansion (1x1)
104 if (block.pw_exp_w) {
105 // Standard 1x1 conv, pad=0, stride=1
106 cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1);
107 if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w);
108 cur = ggml_gelu(ctx0, cur);
109 }
110
111 // 3. Depthwise Mid (Optional)
112 // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage)
113 if (block.dw_mid_w) {
114 int k = block.dw_mid_w->ne[0]; // 3 or 5
115
116 if (stride > 1) {
117 // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding
118 cur = pad_same_2d(cur, k, k, stride, stride);
119 cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0
120 } else {
121 // Case: Stride 1 -> Use Standard Symmetric Padding
122 int p = k / 2;
123 cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1);
124 }
125
126 if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w);
127 cur = ggml_gelu(ctx0, cur);
128 }
129
130 // 4. Pointwise Projection (1x1)
131 if (block.pw_proj_w) {
132 cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1);
133 if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w);
134 }
135
136 // Apply Layer Scaling if present
137 if (block.layer_scale_w) {
138 cur = ggml_mul(ctx0, cur, block.layer_scale_w);
139 }
140
141 // 5. Residual Connection
142 bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]);
143 bool same_channel = (inp->ne[2] == cur->ne[2]);
144 if (same_spatial && same_channel) {
145 cur = ggml_add(ctx0, cur, inp);
146 }
147
148 return cur;
149}
150
151// Attention Block (MQA)
152ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) {
153 ggml_tensor * cur = inp;
154
155 // Norm
156 if (block.attn_norm_w) {
157 cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f);
158 }
159
160 // 1. Q Calculation
161 ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1);
162
163 // 2. K Calculation (Downsampled)
164 // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
165 ggml_tensor * k_inp = cur;
166 if (block.attn_k_dw_w) {
167 int k_size = block.attn_k_dw_w->ne[0]; // Usually 3
168 k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding
169 k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0
170 if (block.attn_k_norm_w) {
171 k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f);
172 }
173 }
174 ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1);
175
176 // 3. V Calculation (Downsampled)
177 // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640)
178 ggml_tensor * v_inp = cur;
179 if (block.attn_v_dw_w) {
180 int v_size = block.attn_v_dw_w->ne[0]; // Usually 3
181 v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding
182 v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0
183 if (block.attn_v_norm_w) {
184 v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f);
185 }
186 }
187 ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1);
188
189 const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3];
190 const int D = k->ne[2]; // Head dimension
191 const int n_head = q->ne[2] / D;
192 const int N = W * H;
193
194 // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B]
195 q = ggml_reshape_3d(ctx0, q, N, D*n_head, B);
196 q = ggml_reshape_4d(ctx0, q, N, D, n_head, B);
197 q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B]
198 q = ggml_cont(ctx0, q);
199
200 const int Wk = k->ne[0]; const int Hk = k->ne[1];
201 const int M = Wk * Hk;
202
203 // Process K: [Wk, Hk, D, B] -> [D, M, 1, B]
204 k = ggml_reshape_3d(ctx0, k, M, D, B);
205 k = ggml_reshape_4d(ctx0, k, M, D, 1, B);
206 k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B]
207 k = ggml_cont(ctx0, k);
208
209 // Process V: [Wk, Hk, D, B] -> [M, D, 1, B]
210 v = ggml_reshape_3d(ctx0, v, M, D, B);
211 v = ggml_reshape_4d(ctx0, v, M, D, 1, B);
212 v = ggml_cont(ctx0, v); // [M, D, 1, B]
213
214 // Multi-Query Attention
215 float scale = 1.0f / sqrtf((float)D);
216
217 // Step 1: Compute Q @ K.T
218 ggml_tensor * scores = ggml_mul_mat(ctx0, k, q);
219
220 scores = ggml_scale(ctx0, scores, scale);
221
222 scores = ggml_soft_max(ctx0, scores);
223
224 ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores);
225
226 kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3);
227 kqv = ggml_cont(ctx0, kqv);
228
229
230 kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B);
231 kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B);
232 kqv = ggml_cont(ctx0, kqv);
233
234 // Output projection
235 cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1);
236
237 // Residual & Layer Scale
238 if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) {
239 if (block.layer_scale_w) {
240 cur = ggml_mul(ctx0, cur, block.layer_scale_w);
241 }
242 cur = ggml_add(ctx0, cur, inp);
243 }
244
245 return cur;
246}
247
248ggml_cgraph * clip_graph_mobilenetv5::build() {
249 ggml_tensor * inp = build_inp_raw();
250
251 // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2))
252 ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding
253
254 cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0
255 if (model.mobilenet_stem_conv_b) {
256 cur = ggml_add(ctx0, cur, model.mobilenet_stem_conv_b);
257 }
258 if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w);
259 cur = ggml_gelu(ctx0, cur);
260
261
262 // 2. Blocks
263 std::vector<ggml_tensor*> intermediate_features;
264 const int total_blocks = model.mobilenet_blocks.size();
265
266 auto is_stage_start = [&](int i) {
267 if (i == 0) return true;
268 for (int end_idx : model.mobilenet_stage_ends) {
269 if (i == end_idx + 1) return true;
270 }
271 return false;
272 };
273
274 auto is_fusion_point = [&](int i) {
275 if (model.mobilenet_stage_ends.size() >= 4) {
276 if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2
277 if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3
278 } else {
279 if (i == total_blocks - 1) return true;
280 }
281 return false;
282 };
283
284 for (int i = 0; i < total_blocks; i++) {
285 const auto & block = model.mobilenet_blocks[i];
286 int stride = is_stage_start(i) ? 2 : 1;
287
288 if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride);
289 else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block);
290 else cur = build_inverted_residual(cur, block, stride);
291
292 if (is_fusion_point(i)) {
293
294 intermediate_features.push_back(cur);
295 }
296 }
297
298 // 3. Multi-Scale Fusion Adapter (MSFA)
299 if (!intermediate_features.empty()) {
300
301 // A. Reference Resolution: PyTorch implementation uses inputs[0]
302 // We assume intermediate_features[0] is the "High Resolution" target.
303 // In MobileNet designs, this is typically the feature map with the smallest stride (e.g. 32x32).
304 ggml_tensor* target_feat = intermediate_features[0];
305 int high_res_w = target_feat->ne[0];
306 int high_res_h = target_feat->ne[1];
307
308 std::vector<ggml_tensor*> resized_feats;
309
310 // B. Resize inputs to match inputs[0] (High Resolution)
311 for (auto feat : intermediate_features) {
312 int feat_w = feat->ne[0];
313 int feat_h = feat->ne[1];
314
315 // PyTorch: if feat_size < high_resolution: interpolate
316 if (feat_w < high_res_w || feat_h < high_res_h) {
317 // Calculate scale factor.
318 // Note: PyTorch 'nearest' works on arbitrary float scales.
319 // ggml_upscale generally takes integer factors or target sizes depending on helper.
320 // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2).
321 int scale_w = high_res_w / feat_w;
322 // int scale_h = high_res_h / feat_h;
323
324 // Safety check for non-integer scaling if strictly replicating
325 GGML_ASSERT(high_res_w % feat_w == 0);
326
327 // Upsample (Nearest Neighbor)
328 // 2 is the scale factor
329 feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST);
330 }
331 resized_feats.push_back(feat);
332 }
333
334 // C. Concatenate at High Resolution (Channel Dim = 2 in ggml)
335 cur = resized_feats[0];
336 for (size_t k = 1; k < resized_feats.size(); ++k) {
337 cur = ggml_concat(ctx0, cur, resized_feats[k], 2);
338 }
339
340 // D. FFN (UniversalInvertedResidual)
341 // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm
342
343 // 1. Expansion
344 if (model.msfa_ffn_expand_w) {
345 // 1x1 Conv
346 cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1);
347
348 if (model.msfa_ffn_expand_bn) {
349 cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn);
350 }
351
352 cur = ggml_gelu(ctx0, cur);
353
354 }
355
356 // 2. Projection (No DW because kernel_size=0)
357 if (model.msfa_ffn_project_w) {
358 // 1x1 Conv
359 cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1);
360
361 // UniversalInvertedResidual typically has a norm after projection
362 if (model.msfa_ffn_project_bn) {
363 cur = rms_norm_2d(cur, model.msfa_ffn_project_bn);
364 }
365
366 }
367
368 // E. Final Downsample to Target Resolution (Output Resolution)
369 // PyTorch: matches self.output_resolution (e.g. 16x16)
370 const int target_out_res = 16;
371 int current_w = cur->ne[0];
372
373 if (current_w > target_out_res) {
374 int s = current_w / target_out_res;
375
376 GGML_ASSERT(current_w % target_out_res == 0);
377
378 // Avg Pool: Kernel=s, Stride=s
379 cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0);
380
381 }
382
383 // F. Final Norm
384 if (model.msfa_concat_norm_w) {
385 cur = rms_norm_2d(cur, model.msfa_concat_norm_w);
386
387 }
388 }
389
390 // 4. Gemma 3n Multimodal Projection (Embedder)
391 // Input: 'cur' is [Width, Height, Channels, Batch]
392 int W = cur->ne[0];
393 int H = cur->ne[1];
394 int C = cur->ne[2];
395 int B = cur->ne[3];
396
397 GGML_ASSERT(C == hparams.n_embd);
398
399 // 1. Permute and Flatten to [Channels, Tokens, Batch]
400 // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch)
401 cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B]
402 cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B]
403 cur = ggml_cont(ctx0, cur);
404 cur = ggml_reshape_3d(ctx0, cur, C, W*H, B);
405 cur = ggml_cont(ctx0, cur);
406
407
408 // 2. FEATURE SCALING
409 // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5
410 const float scale_factor = sqrtf((float)C);
411 cur = ggml_scale(ctx0, cur, scale_factor);
412
413
414 // 3. SOFT EMBEDDING NORM
415 // PyTorch: self._norm(x) * self.weight
416 // We must normalize regardless, then multiply if weight exists.
417 {
418 const float eps = 1e-6f; // Gemma3n uses 1e-6
419 cur = ggml_rms_norm(ctx0, cur, eps);
420
421 if (model.mm_soft_emb_norm_w) {
422 // Weight shape is (2048,) -> Element-wise broadcast multiply
423 cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
424 }
425
426 }
427
428 // 4. PROJECTION
429 // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False)
430 // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size]
431 if (model.mm_input_proj_w) {
432 cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur);
433 }
434
435 // 5. POST PROJECTION NORM
436 // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False)
437 // with_scale=False means weight is registered as buffer with value 1.0
438 // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1
439 {
440 const float eps = 1e-6f;
441 cur = ggml_rms_norm(ctx0, cur, eps);
442
443 if (model.mm_post_proj_norm_w) {
444 // If weight is loaded, multiply (should be ~1.0 anyway)
445 cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w);
446 }
447 }
448
449 ggml_build_forward_expand(gf, cur);
450 return gf;
451}