#include "models.h"

ggml_cgraph * clip_graph_youtuvl::build() {
    GGML_ASSERT(model.class_embedding == nullptr);

    const int  batch_size      = 1;
    const bool use_window_attn = !hparams.wa_layer_indexes.empty();

    const int n_pos            = n_patches;
    const int num_position_ids = n_pos * 4; // m-rope consumes 4 position ids per token

    // spatial merge parameters: m x m patches are merged into one output token
    const int m  = 2;
    const int Wp = n_patches_x;
    const int Hp = n_patches_y;
    const int Hm = Hp / m;
    const int Wm = Wp / m;

    norm_type norm_t = NORM_TYPE_NORMAL;

    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
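    // note: the multi-section rope splits the rotary dims across the 4 equal
    // sections above, one per position-id stream (hence num_position_ids =
    // n_pos * 4); for vision models these presumably carry the
    // temporal/height/width (plus one unused) coordinates, as in
    // Qwen2-VL-style m-rope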

    ggml_tensor * inp = build_inp_raw();

    // implement the conv3d patch embedding as an equivalent linear projection:
    // rearrange the raw input from (patch_size, m, Wm, patch_size, m, Hm, C)
    // to (C, patch_size, patch_size, m, m, Wm, Hm) so that each row holds one
    // flattened patch
    {
        inp = ggml_reshape_4d(
            ctx0, inp,
            Wm * m * patch_size, m * patch_size, Hm, 3);
        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, Wm, m * patch_size, Hm);

        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            m * patch_size * 3, patch_size, m, Hm * Wm);

        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
        inp = ggml_cont_4d(
            ctx0, inp,
            patch_size, 3, patch_size, Hm * Wm * m * m);

        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
        inp = ggml_cont_3d(
            ctx0, inp,
            3 * patch_size * patch_size, Hm * Wm * m * m, 1);
    }
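
    // shape trace of the rearrangement above (ggml ne order [ne0, ne1, ne2, ne3],
    // ps = patch_size):
    //   reshape_4d            -> [Wm*m*ps, m*ps, Hm, 3]
    //   permute(1,2,3,0)+cont -> [m*ps*3, Wm, m*ps, Hm]
    //   permute(0,2,1,3)+cont -> [m*ps*3, ps, m, Hm*Wm]
    //   permute(1,0,2,3)+cont -> [ps, 3, ps, Hm*Wm*m*m]
    //   permute(2,0,1,3)+cont -> [3*ps*ps, Hm*Wm*m*m, 1]
    // i.e. one flattened patch per row, Hm*Wm*m*m = n_patches rows, ready for
    // the linear projection below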
    inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);

    if (model.patch_bias) {
        inp = ggml_add(ctx0, inp, model.patch_bias);
    }

    inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);

    ggml_tensor * inpL           = inp;
    ggml_tensor * window_mask    = nullptr;
    ggml_tensor * window_idx     = nullptr;
    ggml_tensor * inv_window_idx = nullptr;

    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    // pre-layernorm
    if (model.pre_ln_w) {
        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
    }

    if (use_window_attn) {
        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
        ggml_set_name(inv_window_idx, "inv_window_idx");
        ggml_set_input(inv_window_idx);

        // mask for window attention
        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
        ggml_set_name(window_mask, "window_mask");
        ggml_set_input(window_mask);

        // if flash attn is used, cast the mask to f16
        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
        }

        // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
        // gather rows in groups of 4 so that tokens are ordered window-by-window
        GGML_ASSERT(batch_size == 1);
        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
    }
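
    // from this point the token order is window order: inv_window_idx gathers
    // groups of 4 merged patches so that each attention window occupies a
    // contiguous range of rows, which the (presumably block-diagonal over the
    // windows) window_mask then restricts in the window-attention layers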

    // loop over layers
    for (int il = 0; il < n_layer; il++) {
        const auto & layer = model.layers[il];
        // layers listed in wa_layer_indexes use full attention; all others use window attention
        const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true;

        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

        // layernorm1
        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);

        // self-attention
        {
            ggml_tensor * Qcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
            ggml_tensor * Kcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
            ggml_tensor * Vcur = ggml_add(ctx0,
                ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);

            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);

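            // apply m-rope to Q and K; argument order per ggml_rope_multi:
            //   n_dims = d_head/2, sections = mrope_sections, mode = GGML_ROPE_TYPE_VISION,
            //   n_ctx_orig = 32768, freq_base = 10000, freq_scale = 1,
            //   ext_factor = 0, attn_factor = 1, beta_fast = 32, beta_slow = 1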
            Qcur = ggml_rope_multi(
                ctx0, Qcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
            Kcur = ggml_rope_multi(
                ctx0, Kcur, positions, nullptr,
                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);

            // full-attention layers see no mask; window-attention layers use window_mask
            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;

            cur = build_attn(layer.o_w, layer.o_b,
                Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
        }

        // re-add the layer input (residual connection)
        cur = ggml_add(ctx0, cur, inpL);

        inpL = cur; // inpL = residual, cur = hidden_states

        // layernorm2
        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);

        // ffn
        cur = build_ffn(cur,
            layer.ff_up_w, layer.ff_up_b,
            nullptr, nullptr,
            layer.ff_down_w, layer.ff_down_b,
            hparams.ffn_op, il);

        // residual 2
        cur = ggml_add(ctx0, inpL, cur);

        inpL = cur;
    }

    ggml_tensor * embeddings = inpL;
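
    // undo the window-order shuffle applied before the encoder: window_idx is
    // the inverse permutation of inv_window_idx, restoring the patches to
    // their original spatial order before the merger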
    if (use_window_attn) {
        const int spatial_merge_unit = 4;
        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit);
        ggml_set_name(window_idx, "window_idx");
        ggml_set_input(window_idx);

        GGML_ASSERT(batch_size == 1);
        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit);
        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size);
        cb(embeddings, "window_order_restored", -1);
    }

    // post-layernorm (part of Siglip2VisionTransformer, applied after the encoder)
    if (model.post_ln_w) {
        embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
    }

    // apply the merger (VLPatchMerger):
    // 1. RMS norm (ln_q in VLPatchMerger)
    embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
    cb(embeddings, "merger_normed", -1);

    // 2. reshape for the spatial merge (concatenate each 2x2 group of patches)
    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
    cb(embeddings, "merger_reshaped", -1);

    // 3. two-layer MLP projection with GELU
    embeddings = build_ffn(embeddings,
        model.mm_0_w, model.mm_0_b,
        nullptr, nullptr,
        model.mm_1_w, model.mm_1_b,
        FFN_GELU,
        -1);

    ggml_build_forward_expand(gf, embeddings);

    return gf;
}